feat: Add https proxying
This commit is contained in:
265
src/proxy.c
265
src/proxy.c
@@ -10,6 +10,7 @@
|
||||
#include <pthread.h>
|
||||
#include <netdb.h>
|
||||
#include <ctype.h>
|
||||
#include <signal.h>
|
||||
|
||||
#define MAX_REQUEST_SIZE 8192
|
||||
#define MAX_URL_LENGTH 2048
|
||||
@@ -22,12 +23,29 @@ typedef struct {
|
||||
// Extract hostname from HTTP request
|
||||
char* extract_host(const char* request) {
|
||||
static char host[MAX_URL_LENGTH];
|
||||
const char* host_header = strstr(request, "Host: ");
|
||||
const char* host_header = NULL;
|
||||
|
||||
// Check if this is a CONNECT request for HTTPS
|
||||
if (strncmp(request, "CONNECT ", 8) == 0) {
|
||||
// Extract hostname from CONNECT line
|
||||
host_header = request + 8;
|
||||
int i = 0;
|
||||
while (host_header[i] && host_header[i] != ' ' && host_header[i] != ':' && i < MAX_URL_LENGTH - 1) {
|
||||
host[i] = host_header[i];
|
||||
i++;
|
||||
}
|
||||
host[i] = '\0';
|
||||
return host;
|
||||
}
|
||||
|
||||
// For regular HTTP requests, extract from Host header
|
||||
host_header = strstr(request, "Host: ");
|
||||
|
||||
if (host_header) {
|
||||
host_header += 6; // Skip "Host: "
|
||||
int i = 0;
|
||||
while (host_header[i] && host_header[i] != '\r' && host_header[i] != '\n' && i < MAX_URL_LENGTH - 1) {
|
||||
while (host_header[i] && host_header[i] != '\r' && host_header[i] != '\n'
|
||||
&& host_header[i] != ':' && i < MAX_URL_LENGTH - 1) {
|
||||
host[i] = host_header[i];
|
||||
i++;
|
||||
}
|
||||
@@ -38,6 +56,139 @@ char* extract_host(const char* request) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Extract port from HTTP request
|
||||
int extract_port(const char* request) {
|
||||
// Default ports
|
||||
int default_port = 80;
|
||||
|
||||
// Check if this is a CONNECT request (likely HTTPS)
|
||||
if (strncmp(request, "CONNECT ", 8) == 0) {
|
||||
default_port = 443;
|
||||
const char* connect_line = request + 8;
|
||||
const char* port_start = strchr(connect_line, ':');
|
||||
|
||||
if (port_start) {
|
||||
int port = atoi(port_start + 1);
|
||||
return port > 0 ? port : default_port;
|
||||
}
|
||||
|
||||
return default_port;
|
||||
}
|
||||
|
||||
// For regular HTTP, check Host header for port
|
||||
const char* host_header = strstr(request, "Host: ");
|
||||
if (host_header) {
|
||||
host_header += 6; // Skip "Host: "
|
||||
const char* port_start = strchr(host_header, ':');
|
||||
|
||||
if (port_start) {
|
||||
int port = atoi(port_start + 1);
|
||||
return port > 0 ? port : default_port;
|
||||
}
|
||||
}
|
||||
|
||||
// Always return valid default port for HTTP
|
||||
return default_port;
|
||||
}
|
||||
|
||||
// Handle HTTPS CONNECT tunneling
|
||||
void handle_https_tunnel(int client_sock, int server_sock) {
|
||||
fd_set read_fds;
|
||||
char buffer[MAX_REQUEST_SIZE];
|
||||
int max_fd = (client_sock > server_sock) ? client_sock : server_sock;
|
||||
|
||||
// Send 200 Connection Established to the client
|
||||
const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
|
||||
if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
|
||||
perror("Failed to send connection established response");
|
||||
return;
|
||||
}
|
||||
|
||||
// Tunnel data between client and server
|
||||
while (1) {
|
||||
FD_ZERO(&read_fds);
|
||||
FD_SET(client_sock, &read_fds);
|
||||
FD_SET(server_sock, &read_fds);
|
||||
|
||||
if (select(max_fd + 1, &read_fds, NULL, NULL, NULL) < 0) {
|
||||
perror("Select failed");
|
||||
break;
|
||||
}
|
||||
|
||||
if (FD_ISSET(client_sock, &read_fds)) {
|
||||
ssize_t bytes_received = recv(client_sock, buffer, sizeof(buffer), 0);
|
||||
if (bytes_received <= 0) break;
|
||||
if (send(server_sock, buffer, bytes_received, 0) <= 0) break;
|
||||
}
|
||||
|
||||
if (FD_ISSET(server_sock, &read_fds)) {
|
||||
ssize_t bytes_received = recv(server_sock, buffer, sizeof(buffer), 0);
|
||||
if (bytes_received <= 0) break;
|
||||
if (send(client_sock, buffer, bytes_received, 0) <= 0) break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Modify HTTP request for direct server communication
|
||||
char* rewrite_http_request(const char* original_request, size_t* new_length) {
|
||||
// Find the first line ending
|
||||
const char* first_line_end = strstr(original_request, "\r\n");
|
||||
if (!first_line_end) {
|
||||
printf("Malformed HTTP request: no line ending found\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Parse the request line
|
||||
char method[16] = {0};
|
||||
char url[MAX_URL_LENGTH] = {0};
|
||||
char version[16] = {0};
|
||||
|
||||
sscanf(original_request, "%15s %2047s %15s", method, url, version);
|
||||
printf("Original request: %s %s %s\n", method, url, version);
|
||||
|
||||
// Extract path from URL (remove http://hostname)
|
||||
char* path = NULL;
|
||||
if (strncmp(url, "http://", 7) == 0) {
|
||||
path = strchr(url + 7, '/');
|
||||
if (!path) {
|
||||
path = "/"; // Default to root path if not found
|
||||
printf("No path in URL, using default: %s\n", path);
|
||||
} else {
|
||||
printf("Extracted path from URL: %s\n", path);
|
||||
}
|
||||
} else {
|
||||
path = url; // Already a path or malformed
|
||||
printf("Using URL as path: %s\n", path);
|
||||
}
|
||||
|
||||
// Calculate new request size
|
||||
const char* headers_start = first_line_end + 2; // Skip \r\n
|
||||
size_t headers_length = strlen(headers_start);
|
||||
size_t request_line_length = strlen(method) + strlen(path) + strlen(version) + 4; // +4 for spaces and \r\n
|
||||
size_t total_length = request_line_length + headers_length;
|
||||
|
||||
// Allocate memory for new request
|
||||
char* new_request = malloc(total_length + 1);
|
||||
if (!new_request) {
|
||||
printf("Failed to allocate memory for HTTP request rewrite\n");
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Write the new request
|
||||
int written = snprintf(new_request, total_length + 1, "%s %s %s\r\n%s",
|
||||
method, path, version, headers_start);
|
||||
|
||||
if (written < 0 || (size_t)written > total_length) {
|
||||
printf("Failed to rewrite HTTP request\n");
|
||||
free(new_request);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
printf("Rewritten request first line: %s %s %s\n", method, path, version);
|
||||
*new_length = written;
|
||||
return new_request;
|
||||
}
|
||||
|
||||
// Handle client connection in a separate thread
|
||||
void* handle_client(void* arg) {
|
||||
thread_arg_t* thread_arg = (thread_arg_t*)arg;
|
||||
@@ -55,6 +206,9 @@ void* handle_client(void* arg) {
|
||||
|
||||
request[bytes_received] = '\0';
|
||||
|
||||
// Check if this is an HTTPS CONNECT request
|
||||
int is_connect = (strncmp(request, "CONNECT ", 8) == 0);
|
||||
|
||||
// Extract host from request
|
||||
char* host = extract_host(request);
|
||||
if (!host) {
|
||||
@@ -63,7 +217,17 @@ void* handle_client(void* arg) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
printf("Proxying request to: %s\n", host);
|
||||
// Extract port from request
|
||||
int port = extract_port(request);
|
||||
|
||||
// Ensure we always have a valid port
|
||||
if (port <= 0 || port > 65535) {
|
||||
port = is_connect ? 443 : 80;
|
||||
printf("Invalid port detected, using default port: %d\n", port);
|
||||
}
|
||||
|
||||
printf("Proxying %s request to: %s (port %d)\n",
|
||||
is_connect ? "HTTPS" : "HTTP", host, port);
|
||||
|
||||
// Resolve hostname using DoH
|
||||
char ip_addr[INET6_ADDRSTRLEN];
|
||||
@@ -87,7 +251,7 @@ void* handle_client(void* arg) {
|
||||
memset(&server_addr, 0, sizeof(server_addr));
|
||||
server_addr.sin_family = AF_INET;
|
||||
server_addr.sin_addr.s_addr = inet_addr(ip_addr);
|
||||
server_addr.sin_port = htons(80); // Default to HTTP port
|
||||
server_addr.sin_port = htons(port);
|
||||
|
||||
if (connect(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
|
||||
perror("Cannot connect to server");
|
||||
@@ -96,22 +260,77 @@ void* handle_client(void* arg) {
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Forward the request to the server
|
||||
if (send(server_sock, request, bytes_received, 0) < 0) {
|
||||
perror("Failed to send request to server");
|
||||
close(server_sock);
|
||||
close(client_sock);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
// Receive response from server and forward to client
|
||||
while ((bytes_received = recv(server_sock, buffer, sizeof(buffer), 0)) > 0) {
|
||||
if (send(client_sock, buffer, bytes_received, 0) < 0) {
|
||||
perror("Failed to send response to client");
|
||||
break;
|
||||
if (is_connect) {
|
||||
// HTTPS: Handle CONNECT tunnel
|
||||
handle_https_tunnel(client_sock, server_sock);
|
||||
} else {
|
||||
// HTTP: Rewrite the request and forward it
|
||||
printf("HTTP request received, rewriting for server...\n");
|
||||
|
||||
size_t new_length = 0;
|
||||
char* modified_request = rewrite_http_request(request, &new_length);
|
||||
|
||||
if (!modified_request) {
|
||||
printf("Failed to rewrite HTTP request\n");
|
||||
close(server_sock);
|
||||
close(client_sock);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
printf("Forwarding modified HTTP request (%zu bytes) to server...\n", new_length);
|
||||
|
||||
// Set timeouts for sending and receiving
|
||||
struct timeval tv;
|
||||
tv.tv_sec = 5; // 5 seconds timeout
|
||||
tv.tv_usec = 0;
|
||||
setsockopt(server_sock, SOL_SOCKET, SO_SNDTIMEO, (const char*)&tv, sizeof tv);
|
||||
setsockopt(server_sock, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv);
|
||||
|
||||
if (send(server_sock, modified_request, new_length, 0) < 0) {
|
||||
perror("Failed to send request to server");
|
||||
free(modified_request);
|
||||
close(server_sock);
|
||||
close(client_sock);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
free(modified_request);
|
||||
|
||||
// Force-flush any pending data
|
||||
int flag = 1;
|
||||
setsockopt(server_sock, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(int));
|
||||
|
||||
// Receive response from server and forward to client
|
||||
printf("Waiting for server response (timeout: 5s)...\n");
|
||||
|
||||
fd_set read_fds;
|
||||
FD_ZERO(&read_fds);
|
||||
FD_SET(server_sock, &read_fds);
|
||||
|
||||
// Use select with timeout to wait for response
|
||||
if (select(server_sock + 1, &read_fds, NULL, NULL, &tv) <= 0) {
|
||||
printf("Timeout or error waiting for server response\n");
|
||||
close(server_sock);
|
||||
close(client_sock);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
while ((bytes_received = recv(server_sock, buffer, sizeof(buffer), 0)) > 0) {
|
||||
printf("Received %zd bytes from server\n", bytes_received);
|
||||
if (send(client_sock, buffer, bytes_received, 0) < 0) {
|
||||
perror("Failed to send response to client");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (bytes_received < 0) {
|
||||
perror("Error receiving from server");
|
||||
} else if (bytes_received == 0) {
|
||||
printf("Server closed connection normally\n");
|
||||
}
|
||||
}
|
||||
|
||||
printf("Connection closed: %s (port %d)\n", host, port);
|
||||
close(server_sock);
|
||||
close(client_sock);
|
||||
return NULL;
|
||||
@@ -157,7 +376,11 @@ int start_proxy_server(int port) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Set up signal handler for graceful termination
|
||||
signal(SIGINT, handle_signal);
|
||||
|
||||
printf("Proxy server listening on port %d\n", port);
|
||||
printf("Press Ctrl+C to stop the server\n");
|
||||
|
||||
// Accept and handle client connections
|
||||
while (1) {
|
||||
@@ -195,3 +418,11 @@ int start_proxy_server(int port) {
|
||||
close(server_sock);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Signal handler for graceful termination
|
||||
void handle_signal(int sig) {
|
||||
if (sig == SIGINT) {
|
||||
printf("\nShutting down proxy server...\n");
|
||||
exit(0);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user