feat: Add https proxying

This commit is contained in:
2025-04-23 17:33:23 +10:00
parent aa3da9d5c3
commit 92f4f19d32
3 changed files with 303 additions and 17 deletions

View File

@@ -10,6 +10,7 @@
#include <pthread.h>
#include <netdb.h>
#include <ctype.h>
#include <signal.h>
#define MAX_REQUEST_SIZE 8192
#define MAX_URL_LENGTH 2048
@@ -22,12 +23,29 @@ typedef struct {
// Extract hostname from HTTP request
char* extract_host(const char* request) {
static char host[MAX_URL_LENGTH];
const char* host_header = strstr(request, "Host: ");
const char* host_header = NULL;
// Check if this is a CONNECT request for HTTPS
if (strncmp(request, "CONNECT ", 8) == 0) {
// Extract hostname from CONNECT line
host_header = request + 8;
int i = 0;
while (host_header[i] && host_header[i] != ' ' && host_header[i] != ':' && i < MAX_URL_LENGTH - 1) {
host[i] = host_header[i];
i++;
}
host[i] = '\0';
return host;
}
// For regular HTTP requests, extract from Host header
host_header = strstr(request, "Host: ");
if (host_header) {
host_header += 6; // Skip "Host: "
int i = 0;
while (host_header[i] && host_header[i] != '\r' && host_header[i] != '\n' && i < MAX_URL_LENGTH - 1) {
while (host_header[i] && host_header[i] != '\r' && host_header[i] != '\n'
&& host_header[i] != ':' && i < MAX_URL_LENGTH - 1) {
host[i] = host_header[i];
i++;
}
@@ -38,6 +56,139 @@ char* extract_host(const char* request) {
return NULL;
}
// Extract port from HTTP request
int extract_port(const char* request) {
// Default ports
int default_port = 80;
// Check if this is a CONNECT request (likely HTTPS)
if (strncmp(request, "CONNECT ", 8) == 0) {
default_port = 443;
const char* connect_line = request + 8;
const char* port_start = strchr(connect_line, ':');
if (port_start) {
int port = atoi(port_start + 1);
return port > 0 ? port : default_port;
}
return default_port;
}
// For regular HTTP, check Host header for port
const char* host_header = strstr(request, "Host: ");
if (host_header) {
host_header += 6; // Skip "Host: "
const char* port_start = strchr(host_header, ':');
if (port_start) {
int port = atoi(port_start + 1);
return port > 0 ? port : default_port;
}
}
// Always return valid default port for HTTP
return default_port;
}
// Handle HTTPS CONNECT tunneling
void handle_https_tunnel(int client_sock, int server_sock) {
fd_set read_fds;
char buffer[MAX_REQUEST_SIZE];
int max_fd = (client_sock > server_sock) ? client_sock : server_sock;
// Send 200 Connection Established to the client
const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
perror("Failed to send connection established response");
return;
}
// Tunnel data between client and server
while (1) {
FD_ZERO(&read_fds);
FD_SET(client_sock, &read_fds);
FD_SET(server_sock, &read_fds);
if (select(max_fd + 1, &read_fds, NULL, NULL, NULL) < 0) {
perror("Select failed");
break;
}
if (FD_ISSET(client_sock, &read_fds)) {
ssize_t bytes_received = recv(client_sock, buffer, sizeof(buffer), 0);
if (bytes_received <= 0) break;
if (send(server_sock, buffer, bytes_received, 0) <= 0) break;
}
if (FD_ISSET(server_sock, &read_fds)) {
ssize_t bytes_received = recv(server_sock, buffer, sizeof(buffer), 0);
if (bytes_received <= 0) break;
if (send(client_sock, buffer, bytes_received, 0) <= 0) break;
}
}
}
// Modify HTTP request for direct server communication
char* rewrite_http_request(const char* original_request, size_t* new_length) {
// Find the first line ending
const char* first_line_end = strstr(original_request, "\r\n");
if (!first_line_end) {
printf("Malformed HTTP request: no line ending found\n");
return NULL;
}
// Parse the request line
char method[16] = {0};
char url[MAX_URL_LENGTH] = {0};
char version[16] = {0};
sscanf(original_request, "%15s %2047s %15s", method, url, version);
printf("Original request: %s %s %s\n", method, url, version);
// Extract path from URL (remove http://hostname)
char* path = NULL;
if (strncmp(url, "http://", 7) == 0) {
path = strchr(url + 7, '/');
if (!path) {
path = "/"; // Default to root path if not found
printf("No path in URL, using default: %s\n", path);
} else {
printf("Extracted path from URL: %s\n", path);
}
} else {
path = url; // Already a path or malformed
printf("Using URL as path: %s\n", path);
}
// Calculate new request size
const char* headers_start = first_line_end + 2; // Skip \r\n
size_t headers_length = strlen(headers_start);
size_t request_line_length = strlen(method) + strlen(path) + strlen(version) + 4; // +4 for spaces and \r\n
size_t total_length = request_line_length + headers_length;
// Allocate memory for new request
char* new_request = malloc(total_length + 1);
if (!new_request) {
printf("Failed to allocate memory for HTTP request rewrite\n");
return NULL;
}
// Write the new request
int written = snprintf(new_request, total_length + 1, "%s %s %s\r\n%s",
method, path, version, headers_start);
if (written < 0 || (size_t)written > total_length) {
printf("Failed to rewrite HTTP request\n");
free(new_request);
return NULL;
}
printf("Rewritten request first line: %s %s %s\n", method, path, version);
*new_length = written;
return new_request;
}
// Handle client connection in a separate thread
void* handle_client(void* arg) {
thread_arg_t* thread_arg = (thread_arg_t*)arg;
@@ -55,6 +206,9 @@ void* handle_client(void* arg) {
request[bytes_received] = '\0';
// Check if this is an HTTPS CONNECT request
int is_connect = (strncmp(request, "CONNECT ", 8) == 0);
// Extract host from request
char* host = extract_host(request);
if (!host) {
@@ -63,7 +217,17 @@ void* handle_client(void* arg) {
return NULL;
}
printf("Proxying request to: %s\n", host);
// Extract port from request
int port = extract_port(request);
// Ensure we always have a valid port
if (port <= 0 || port > 65535) {
port = is_connect ? 443 : 80;
printf("Invalid port detected, using default port: %d\n", port);
}
printf("Proxying %s request to: %s (port %d)\n",
is_connect ? "HTTPS" : "HTTP", host, port);
// Resolve hostname using DoH
char ip_addr[INET6_ADDRSTRLEN];
@@ -87,7 +251,7 @@ void* handle_client(void* arg) {
memset(&server_addr, 0, sizeof(server_addr));
server_addr.sin_family = AF_INET;
server_addr.sin_addr.s_addr = inet_addr(ip_addr);
server_addr.sin_port = htons(80); // Default to HTTP port
server_addr.sin_port = htons(port);
if (connect(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
perror("Cannot connect to server");
@@ -96,22 +260,77 @@ void* handle_client(void* arg) {
return NULL;
}
// Forward the request to the server
if (send(server_sock, request, bytes_received, 0) < 0) {
perror("Failed to send request to server");
close(server_sock);
close(client_sock);
return NULL;
}
// Receive response from server and forward to client
while ((bytes_received = recv(server_sock, buffer, sizeof(buffer), 0)) > 0) {
if (send(client_sock, buffer, bytes_received, 0) < 0) {
perror("Failed to send response to client");
break;
if (is_connect) {
// HTTPS: Handle CONNECT tunnel
handle_https_tunnel(client_sock, server_sock);
} else {
// HTTP: Rewrite the request and forward it
printf("HTTP request received, rewriting for server...\n");
size_t new_length = 0;
char* modified_request = rewrite_http_request(request, &new_length);
if (!modified_request) {
printf("Failed to rewrite HTTP request\n");
close(server_sock);
close(client_sock);
return NULL;
}
printf("Forwarding modified HTTP request (%zu bytes) to server...\n", new_length);
// Set timeouts for sending and receiving
struct timeval tv;
tv.tv_sec = 5; // 5 seconds timeout
tv.tv_usec = 0;
setsockopt(server_sock, SOL_SOCKET, SO_SNDTIMEO, (const char*)&tv, sizeof tv);
setsockopt(server_sock, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv);
if (send(server_sock, modified_request, new_length, 0) < 0) {
perror("Failed to send request to server");
free(modified_request);
close(server_sock);
close(client_sock);
return NULL;
}
free(modified_request);
// Force-flush any pending data
int flag = 1;
setsockopt(server_sock, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(int));
// Receive response from server and forward to client
printf("Waiting for server response (timeout: 5s)...\n");
fd_set read_fds;
FD_ZERO(&read_fds);
FD_SET(server_sock, &read_fds);
// Use select with timeout to wait for response
if (select(server_sock + 1, &read_fds, NULL, NULL, &tv) <= 0) {
printf("Timeout or error waiting for server response\n");
close(server_sock);
close(client_sock);
return NULL;
}
while ((bytes_received = recv(server_sock, buffer, sizeof(buffer), 0)) > 0) {
printf("Received %zd bytes from server\n", bytes_received);
if (send(client_sock, buffer, bytes_received, 0) < 0) {
perror("Failed to send response to client");
break;
}
}
if (bytes_received < 0) {
perror("Error receiving from server");
} else if (bytes_received == 0) {
printf("Server closed connection normally\n");
}
}
printf("Connection closed: %s (port %d)\n", host, port);
close(server_sock);
close(client_sock);
return NULL;
@@ -157,7 +376,11 @@ int start_proxy_server(int port) {
return 1;
}
// Set up signal handler for graceful termination
signal(SIGINT, handle_signal);
printf("Proxy server listening on port %d\n", port);
printf("Press Ctrl+C to stop the server\n");
// Accept and handle client connections
while (1) {
@@ -195,3 +418,11 @@ int start_proxy_server(int port) {
close(server_sock);
return 0;
}
// Signal handler for graceful termination
void handle_signal(int sig) {
if (sig == SIGINT) {
printf("\nShutting down proxy server...\n");
exit(0);
}
}