/*
 * HTTPS interception proxy: worker thread pool, upstream connection pooling,
 * DNS/DANE verdict caches and on-demand certificate generation for
 * DANE-verified hosts.
 */
#include "proxy.h"
#include "doh.h"
#include "dane.h"

// NOTE(review): the system header names were missing from this copy of the
// file (the lines read as bare `#include`). The list below is reconstructed
// from the APIs used throughout the file and is deliberately a superset —
// confirm it against version control.
#include <sys/types.h>
#include <sys/select.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <arpa/inet.h>
#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <limits.h>
#include <netdb.h>
#include <netinet/in.h>
#include <poll.h>
#include <pthread.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <time.h>
#include <unistd.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
#include <openssl/x509.h>
#include <openssl/rand.h>   // RAND_bytes() for better random number generation
#include <uthash.h>         // uthash hash tables (UT_hash_handle, HASH_* macros)
#include <stdatomic.h>      // For atomic operations

#define MAX_REQUEST_SIZE 8192
#define MAX_URL_LENGTH 2048
#define THREAD_POOL_SIZE 20                 // Capacity of the client work queue
#define CACHE_EXPIRY_TIME (86400)           // DANE verdict lifetime: 24 hours
#define DNS_CACHE_EXPIRY_TIME (3600)        // DNS answer lifetime: 1 hour
#define IP_ADDR_MAX_LEN 46                  // Maximum length for IPv6 address
#define CERT_RENEWAL_TIME (86400*7)         // Renew certificates after 7 days
#define CONNECTION_POOL_SIZE 32             // Number of connections to keep in the pool
#define CONNECTION_IDLE_TIMEOUT 300         // 5 minutes
#define NUM_WORKER_THREADS 8                // Worker threads for handling connections
#define MAX_PARALLEL_DNS_QUERIES 8          // Number of parallel DNS queries
#define CERT_GEN_THREAD_POOL_SIZE 4         // Threads for certificate generation
#define PRE_GEN_CERT_COUNT 20               // Number of pre-generated certificates for domains with DANE

// Forward declarations of structures
typedef struct dns_task dns_task;
typedef struct cert_gen_task cert_gen_task;

// Function prototypes
void* periodic_cleanup(void* arg);
int generate_unique_cert(const char* hostname, const char* cert_path, const char* key_path);
void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx);
void* worker_thread(void* arg);
void init_thread_pool(void);
void queue_client_connection(int client_sock);
void handle_client_optimized(int client_sock);
void handle_https_tunnel_optimized(int client_sock, int server_sock, const char* hostname,
                                   const char* ip_addr, SSL* existing_ssl, SSL_CTX* existing_ctx);
void init_dns_pool(void);
void* dns_worker(void* arg);
dns_task* queue_dns_resolution(const char* hostname);
int wait_for_dns_resolution(dns_task* task, char* ip_buffer, size_t buffer_size);
void cleanup_dns_pool(void);
void init_cert_gen_pool(void);
void* cert_gen_worker(void* arg);
cert_gen_task* queue_cert_generation(const char* hostname, const char* cert_path, const char* key_path);
int wait_for_cert_generation(cert_gen_task* task);
void cleanup_cert_gen_pool(void);
void pregen_common_certificates(void);
void* pregen_cert_thread(void* arg);

// SSL session and certificate caching
SSL_SESSION* get_ssl_session(const char* hostname);
void cache_ssl_session(const char* hostname, SSL_SESSION* session);
void cleanup_ssl_sessions(void);
int get_cached_cert(const char* hostname, ssl_context_t** ctx);
void cache_cert_in_memory(const char* hostname, const char* cert_path, const char* key_path);
void handle_regular_https_tunnel(int client_sock, int server_sock);
void ssl_tunnel_data(SSL* client_ssl, SSL* server_ssl);

// DANE verification cache entry (guarded by cache_rwlock)
typedef struct {
    char hostname[MAX_URL_LENGTH]; // key
    int verified;                  // 1 = DANE verified, 0 = not verified
    time_t timestamp;              // time of last update, for expiry
    UT_hash_handle hh;             // makes this structure hashable
} dane_cache_entry;

// DNS cache entry (guarded by dns_cache_rwlock)
typedef struct {
    char hostname[MAX_URL_LENGTH]; // key
    char ip_addr[IP_ADDR_MAX_LEN];
    time_t timestamp;
    int is_valid;                  // 1 if valid IP, 0 if resolution failed
    UT_hash_handle hh;             // makes this structure hashable
} dns_cache_entry;

// Certificate generation tracking entry (guarded by cert_rwlock)
typedef struct {
    char hostname[MAX_URL_LENGTH]; // key
    time_t timestamp;              // when the certificate was last generated
    UT_hash_handle hh;             // makes this structure hashable
} cert_serial_entry;

// Upstream connection pool entry (guarded by conn_pool_mutex)
typedef struct {
    char hostname[MAX_URL_LENGTH]; // key
    int port;
    int sock;
    SSL* ssl;
    SSL_CTX* ctx;
    time_t last_used;
    int in_use;                    // 1 while checked out by a thread
    UT_hash_handle hh;
} connection_pool_entry;

// Client work queue (bounded, FIFO)
typedef struct {
    int client_sock;
} client_task;

typedef struct {
    client_task* tasks;
    int task_count;
    int task_capacity;
    pthread_mutex_t mutex;
    pthread_cond_t not_empty;
    pthread_cond_t not_full;
} task_queue;

// Parallel DNS resolution request; the submitter waits on cond until completed
struct dns_task {
    char hostname[MAX_URL_LENGTH];
    char ip_addr[IP_ADDR_MAX_LEN];
    int result;
    int completed;
    pthread_mutex_t mutex;
    pthread_cond_t cond;
};

typedef struct {
    dns_task** tasks;
    int task_count;
    int task_capacity;
    pthread_mutex_t mutex;
    pthread_cond_t not_empty;
    pthread_cond_t not_full;
    int running;
} dns_queue;

// Certificate generation request; the submitter waits on cond until completed
struct cert_gen_task {
    char hostname[MAX_URL_LENGTH];
    char cert_path[256];
    char key_path[256];
    int result;
    int completed;
    pthread_mutex_t mutex;
    pthread_cond_t cond;
};

typedef struct {
    cert_gen_task** tasks;
    int task_count;
    int task_capacity;
    pthread_mutex_t mutex;
    pthread_cond_t not_empty;
    pthread_cond_t not_full;
    int running;
} cert_gen_queue;

// Global caches (uthash hash tables; NULL means empty)
static dane_cache_entry* dane_cache = NULL;
static dns_cache_entry* dns_cache = NULL;
static cert_serial_entry* cert_serials = NULL;
static connection_pool_entry* conn_pool = NULL;

// Reader/writer locks let concurrent lookups proceed in parallel
static pthread_rwlock_t cache_rwlock = PTHREAD_RWLOCK_INITIALIZER;
static pthread_rwlock_t dns_cache_rwlock = PTHREAD_RWLOCK_INITIALIZER;
static pthread_rwlock_t cert_rwlock = PTHREAD_RWLOCK_INITIALIZER;
static pthread_mutex_t conn_pool_mutex = PTHREAD_MUTEX_INITIALIZER;

// Client worker thread pool state
static task_queue work_queue;
static pthread_t worker_threads[NUM_WORKER_THREADS];
static int workers_running = 1;

// DNS resolution thread pool state
static dns_queue doh_queue;
static pthread_t dns_threads[MAX_PARALLEL_DNS_QUERIES];

// Certificate generation thread pool state
static cert_gen_queue cert_queue;
static pthread_t cert_gen_threads[CERT_GEN_THREAD_POOL_SIZE];

// Atomic counter mixed into certificate serials to avoid collisions
static atomic_ullong next_serial = 0;

typedef struct {
    int client_sock;
} thread_arg_t;

// SSL session cache keyed by hostname (see get_ssl_session / cache_ssl_session)
typedef struct {
    char hostname[MAX_URL_LENGTH];
    SSL_SESSION* session;
    time_t timestamp;
    UT_hash_handle hh;
} ssl_session_cache_entry;

static ssl_session_cache_entry* ssl_session_cache = NULL;
static pthread_mutex_t
ssl_session_mutex = PTHREAD_MUTEX_INITIALIZER; // Certificate memory cache typedef struct { char hostname[MAX_URL_LENGTH]; unsigned char* cert_data; size_t cert_len; unsigned char* key_data; size_t key_len; time_t timestamp; UT_hash_handle hh; } cert_cache_entry; static cert_cache_entry* cert_memory_cache = NULL; static pthread_mutex_t cert_cache_mutex = PTHREAD_MUTEX_INITIALIZER; // Initialize the thread pool void init_thread_pool(void) { // Initialize the task queue work_queue.tasks = malloc(sizeof(client_task) * THREAD_POOL_SIZE); work_queue.task_count = 0; work_queue.task_capacity = THREAD_POOL_SIZE; pthread_mutex_init(&work_queue.mutex, NULL); pthread_cond_init(&work_queue.not_empty, NULL); pthread_cond_init(&work_queue.not_full, NULL); // Create worker threads for (int i = 0; i < NUM_WORKER_THREADS; i++) { if (pthread_create(&worker_threads[i], NULL, worker_thread, NULL) != 0) { perror("Failed to create worker thread"); exit(1); } } } // Queue a client connection for processing void queue_client_connection(int client_sock) { pthread_mutex_lock(&work_queue.mutex); // Wait if the queue is full while (work_queue.task_count >= work_queue.task_capacity && workers_running) { pthread_cond_wait(&work_queue.not_full, &work_queue.mutex); } if (!workers_running) { pthread_mutex_unlock(&work_queue.mutex); close(client_sock); return; } // Add the client socket to the queue work_queue.tasks[work_queue.task_count].client_sock = client_sock; work_queue.task_count++; // Signal that the queue is not empty pthread_cond_signal(&work_queue.not_empty); pthread_mutex_unlock(&work_queue.mutex); } // Worker thread to process client connections void* worker_thread(void* arg) { (void)arg; // Unused parameter while (workers_running) { pthread_mutex_lock(&work_queue.mutex); // Wait if the queue is empty while (work_queue.task_count == 0 && workers_running) { pthread_cond_wait(&work_queue.not_empty, &work_queue.mutex); } if (!workers_running) { pthread_mutex_unlock(&work_queue.mutex); 
break; } // Get a client socket from the queue int client_sock = work_queue.tasks[0].client_sock; // Remove task from queue work_queue.task_count--; if (work_queue.task_count > 0) { memmove(&work_queue.tasks[0], &work_queue.tasks[1], sizeof(client_task) * work_queue.task_count); } // Signal that the queue is not full pthread_cond_signal(&work_queue.not_full); pthread_mutex_unlock(&work_queue.mutex); // Process the client connection handle_client_optimized(client_sock); } return NULL; } // Connection pooling void add_to_connection_pool(const char* hostname, int port, int sock, SSL* ssl, SSL_CTX* ctx) { pthread_mutex_lock(&conn_pool_mutex); // Check if we already have a connection for this host connection_pool_entry* entry; HASH_FIND_STR(conn_pool, hostname, entry); if (entry) { // Close the existing connection if (entry->ssl) SSL_free(entry->ssl); if (entry->ctx) SSL_CTX_free(entry->ctx); if (entry->sock > 0) close(entry->sock); // Update with new connection entry->port = port; entry->sock = sock; entry->ssl = ssl; entry->ctx = ctx; entry->last_used = time(NULL); entry->in_use = 0; } else { // Create a new entry entry = malloc(sizeof(connection_pool_entry)); if (entry) { strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); entry->hostname[MAX_URL_LENGTH-1] = '\0'; entry->port = port; entry->sock = sock; entry->ssl = ssl; entry->ctx = ctx; entry->last_used = time(NULL); entry->in_use = 0; HASH_ADD_STR(conn_pool, hostname, entry); } } pthread_mutex_unlock(&conn_pool_mutex); } int get_from_connection_pool(const char* hostname, int port, int* sock, SSL** ssl, SSL_CTX** ctx) { pthread_mutex_lock(&conn_pool_mutex); connection_pool_entry* entry; HASH_FIND_STR(conn_pool, hostname, entry); if (entry && entry->port == port && !entry->in_use) { time_t now = time(NULL); // Check if connection has been idle too long if (now - entry->last_used > CONNECTION_IDLE_TIMEOUT) { // Connection too old, close it if (entry->ssl) SSL_free(entry->ssl); if (entry->ctx) 
SSL_CTX_free(entry->ctx); if (entry->sock > 0) close(entry->sock); HASH_DEL(conn_pool, entry); free(entry); pthread_mutex_unlock(&conn_pool_mutex); return 0; } // Connection is valid and available *sock = entry->sock; *ssl = entry->ssl; *ctx = entry->ctx; entry->in_use = 1; entry->last_used = now; pthread_mutex_unlock(&conn_pool_mutex); return 1; } pthread_mutex_unlock(&conn_pool_mutex); return 0; } void release_connection(const char* hostname) { pthread_mutex_lock(&conn_pool_mutex); connection_pool_entry* entry; HASH_FIND_STR(conn_pool, hostname, entry); if (entry) { entry->in_use = 0; entry->last_used = time(NULL); } pthread_mutex_unlock(&conn_pool_mutex); } void cleanup_connection_pool() { pthread_mutex_lock(&conn_pool_mutex); connection_pool_entry* entry, *tmp; time_t now = time(NULL); HASH_ITER(hh, conn_pool, entry, tmp) { if (now - entry->last_used > CONNECTION_IDLE_TIMEOUT || entry->sock <= 0) { // Close the connection if (entry->ssl) SSL_free(entry->ssl); if (entry->ctx) SSL_CTX_free(entry->ctx); if (entry->sock > 0) close(entry->sock); // Remove from hash table HASH_DEL(conn_pool, entry); free(entry); } } pthread_mutex_unlock(&conn_pool_mutex); } // Certificate serial tracking functions void add_cert_serial(const char* hostname) { pthread_rwlock_wrlock(&cert_rwlock); // Check if entry already exists cert_serial_entry* entry; HASH_FIND_STR(cert_serials, hostname, entry); if (entry) { // Update timestamp for existing entry entry->timestamp = time(NULL); } else { // Create new entry entry = malloc(sizeof(cert_serial_entry)); if (entry) { strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); entry->hostname[MAX_URL_LENGTH-1] = '\0'; entry->timestamp = time(NULL); HASH_ADD_STR(cert_serials, hostname, entry); } } pthread_rwlock_unlock(&cert_rwlock); } int should_renew_cert(const char* hostname) { pthread_rwlock_rdlock(&cert_rwlock); cert_serial_entry* entry; HASH_FIND_STR(cert_serials, hostname, entry); if (entry) { time_t now = time(NULL); int should_renew = 
(now - entry->timestamp > CERT_RENEWAL_TIME); pthread_rwlock_unlock(&cert_rwlock); return should_renew; } pthread_rwlock_unlock(&cert_rwlock); return 1; // No entry found, should generate } void cleanup_cert_serials() { pthread_rwlock_wrlock(&cert_rwlock); cert_serial_entry* entry, *tmp; time_t now = time(NULL); HASH_ITER(hh, cert_serials, entry, tmp) { if (now - entry->timestamp > CERT_RENEWAL_TIME) { // Remove the certificate files char cert_path[256]; char key_path[256]; snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", entry->hostname); snprintf(key_path, sizeof(key_path), "certs/%s.key", entry->hostname); unlink(cert_path); // Ignore errors unlink(key_path); // Remove from hash table HASH_DEL(cert_serials, entry); free(entry); } } pthread_rwlock_unlock(&cert_rwlock); } void free_cert_serials() { pthread_rwlock_wrlock(&cert_rwlock); cert_serial_entry* entry, *tmp; HASH_ITER(hh, cert_serials, entry, tmp) { HASH_DEL(cert_serials, entry); free(entry); } pthread_rwlock_unlock(&cert_rwlock); } // Generate a certificate with a unique serial number int generate_unique_cert(const char* hostname, const char* cert_path, const char* key_path) { // Remove existing certificate files to prevent reuse unlink(cert_path); unlink(key_path); // Get a unique serial number using atomic operations with better entropy struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); // Generate a truly unique serial number combining multiple sources of entropy unsigned long long timestamp = ((unsigned long long)ts.tv_sec << 32) | (ts.tv_nsec & 0xFFFFFFFF); // Add more randomness from OpenSSL unsigned char random_bytes[16]; if (RAND_bytes(random_bytes, sizeof(random_bytes)) != 1) { // Fall back to less secure randomness if OpenSSL random generator fails for (int i = 0; i < 16; i++) { random_bytes[i] = rand() & 0xFF; } } // Combine all sources of randomness for the serial unsigned long long random_component = 0; for (int i = 0; i < 8; i++) { random_component = (random_component << 8) | 
random_bytes[i]; } // Create a unique serial using atomic counter, timestamp, and randomness unsigned long long serial = atomic_fetch_add(&next_serial, 1) ^ timestamp ^ random_component; // Force serial number to be positive and non-zero serial = (serial & 0x7FFFFFFFFFFFFFFFULL) | 0x0000010000000000ULL; printf("Generating certificate for %s with serial: %llu\n", hostname, serial); // Clear any previously cached certificates in Firefox // Firefox caches certificates by [issuer DN + serial number] pthread_mutex_lock(&cert_cache_mutex); cert_cache_entry* entry; HASH_FIND_STR(cert_memory_cache, hostname, entry); if (entry) { // Remove the cached certificate free(entry->cert_data); free(entry->key_data); HASH_DEL(cert_memory_cache, entry); free(entry); printf("Removed cached certificate for %s to avoid serial conflicts\n", hostname); } pthread_mutex_unlock(&cert_cache_mutex); // Use the serial for entropy srand((unsigned int)(serial & 0xFFFFFFFF)); // Set special environment variable to force unique serial in certificate generation char serial_env[64]; snprintf(serial_env, sizeof(serial_env), "%llu", serial); setenv("CERT_SERIAL_OVERRIDE", serial_env, 1); // Queue certificate generation in the thread pool cert_gen_task* task = queue_cert_generation(hostname, cert_path, key_path); if (task) { // Wait for the certificate to be generated int result = wait_for_cert_generation(task); if (result) { // Track this certificate generation add_cert_serial(hostname); // Cache cert and key in memory cache_cert_in_memory(hostname, cert_path, key_path); // Verify the certificate was correctly generated printf("Certificate successfully generated for %s\n", hostname); } else { fprintf(stderr, "Failed to generate certificate for %s\n", hostname); } // Clear the environment variable unsetenv("CERT_SERIAL_OVERRIDE"); return result; } else { // Fall back to synchronous generation if queuing fails printf("Using synchronous certificate generation for %s\n", hostname); int result = 
generate_trusted_cert(hostname, cert_path, key_path); if (result) { add_cert_serial(hostname); // Cache cert and key in memory cache_cert_in_memory(hostname, cert_path, key_path); printf("Certificate successfully generated for %s (sync method)\n", hostname); } else { fprintf(stderr, "Failed to generate certificate for %s (sync method)\n", hostname); } // Clear the environment variable unsetenv("CERT_SERIAL_OVERRIDE"); return result; } } // SSL connection cleanup helper void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx) { if (ssl) { SSL_shutdown(ssl); SSL_free(ssl); } if (ctx) { SSL_CTX_free(ctx); } } // DNS cache functions (hash table versions) void add_to_dns_cache(const char* hostname, const char* ip_addr, int is_valid) { pthread_rwlock_wrlock(&dns_cache_rwlock); // First check if entry already exists dns_cache_entry* existing; HASH_FIND_STR(dns_cache, hostname, existing); if (existing) { // Update existing entry strncpy(existing->ip_addr, ip_addr, IP_ADDR_MAX_LEN-1); existing->ip_addr[IP_ADDR_MAX_LEN-1] = '\0'; existing->is_valid = is_valid; existing->timestamp = time(NULL); } else { // Create new entry dns_cache_entry* new_entry = malloc(sizeof(dns_cache_entry)); if (new_entry) { strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1); new_entry->hostname[MAX_URL_LENGTH-1] = '\0'; strncpy(new_entry->ip_addr, ip_addr, IP_ADDR_MAX_LEN-1); new_entry->ip_addr[IP_ADDR_MAX_LEN-1] = '\0'; new_entry->is_valid = is_valid; new_entry->timestamp = time(NULL); HASH_ADD_STR(dns_cache, hostname, new_entry); } } pthread_rwlock_unlock(&dns_cache_rwlock); } int check_dns_cache(const char* hostname, char* ip_buffer, size_t buffer_size) { pthread_rwlock_rdlock(&dns_cache_rwlock); dns_cache_entry* entry; HASH_FIND_STR(dns_cache, hostname, entry); if (!entry) { pthread_rwlock_unlock(&dns_cache_rwlock); return -1; // Not in cache } // Check if entry has expired time_t now = time(NULL); if (now - entry->timestamp > DNS_CACHE_EXPIRY_TIME) { // Entry expired 
pthread_rwlock_unlock(&dns_cache_rwlock); return -1; } // If entry is valid, copy IP address to buffer if (entry->is_valid) { strncpy(ip_buffer, entry->ip_addr, buffer_size-1); ip_buffer[buffer_size-1] = '\0'; pthread_rwlock_unlock(&dns_cache_rwlock); return 0; // Success } pthread_rwlock_unlock(&dns_cache_rwlock); return 1; // Entry exists but is marked as invalid resolution } void cleanup_dns_cache() { pthread_rwlock_wrlock(&dns_cache_rwlock); dns_cache_entry* entry, *tmp; time_t now = time(NULL); HASH_ITER(hh, dns_cache, entry, tmp) { if (now - entry->timestamp > DNS_CACHE_EXPIRY_TIME) { // Remove from hash table HASH_DEL(dns_cache, entry); free(entry); } } pthread_rwlock_unlock(&dns_cache_rwlock); } void free_dns_cache() { pthread_rwlock_wrlock(&dns_cache_rwlock); dns_cache_entry* entry, *tmp; HASH_ITER(hh, dns_cache, entry, tmp) { HASH_DEL(dns_cache, entry); free(entry); } pthread_rwlock_unlock(&dns_cache_rwlock); } // Helper function to print SSL errors void print_ssl_errors(const char* context) { unsigned long err; char err_buf[256]; while ((err = ERR_get_error()) != 0) { ERR_error_string_n(err, err_buf, sizeof(err_buf)); fprintf(stderr, "SSL Error (%s): %s\n", context, err_buf); } } // DANE cache functions dane_cache_entry* find_cache_entry(const char* hostname) { dane_cache_entry* entry; HASH_FIND_STR(dane_cache, hostname, entry); return entry; } void add_to_cache(const char* hostname, int verified) { pthread_rwlock_wrlock(&cache_rwlock); // First check if entry already exists dane_cache_entry* existing; HASH_FIND_STR(dane_cache, hostname, existing); if (existing) { // Update existing entry existing->verified = verified; existing->timestamp = time(NULL); } else { // Create new entry dane_cache_entry* new_entry = malloc(sizeof(dane_cache_entry)); if (new_entry) { strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1); new_entry->hostname[MAX_URL_LENGTH-1] = '\0'; new_entry->verified = verified; new_entry->timestamp = time(NULL); HASH_ADD_STR(dane_cache, 
hostname, new_entry); } } pthread_rwlock_unlock(&cache_rwlock); } int check_cache(const char* hostname) { pthread_rwlock_rdlock(&cache_rwlock); dane_cache_entry* entry; HASH_FIND_STR(dane_cache, hostname, entry); if (!entry) { pthread_rwlock_unlock(&cache_rwlock); return -1; // Not in cache } // Check if entry has expired time_t now = time(NULL); if (now - entry->timestamp > CACHE_EXPIRY_TIME) { // Entry expired pthread_rwlock_unlock(&cache_rwlock); return -1; } int result = entry->verified; pthread_rwlock_unlock(&cache_rwlock); return result; } void cleanup_cache() { pthread_rwlock_wrlock(&cache_rwlock); dane_cache_entry* entry, *tmp; time_t now = time(NULL); HASH_ITER(hh, dane_cache, entry, tmp) { if (now - entry->timestamp > CACHE_EXPIRY_TIME) { // Remove from hash table HASH_DEL(dane_cache, entry); free(entry); } } pthread_rwlock_unlock(&cache_rwlock); } void free_cache() { pthread_rwlock_wrlock(&cache_rwlock); dane_cache_entry* entry, *tmp; HASH_ITER(hh, dane_cache, entry, tmp) { HASH_DEL(dane_cache, entry); free(entry); } pthread_rwlock_unlock(&cache_rwlock); } // Extract hostname from HTTP request char* extract_host(const char* request) { static char host[MAX_URL_LENGTH]; const char* host_header = NULL; // Check if this is a CONNECT request for HTTPS if (strncmp(request, "CONNECT ", 8) == 0) { // Extract hostname from CONNECT line host_header = request + 8; int i = 0; while (host_header[i] && host_header[i] != ' ' && host_header[i] != ':' && i < MAX_URL_LENGTH - 1) { host[i] = host_header[i]; i++; } host[i] = '\0'; return host; } // For regular HTTP requests, extract from Host header host_header = strstr(request, "Host: "); if (host_header) { host_header += 6; // Skip "Host: " int i = 0; while (host_header[i] && host_header[i] != '\r' && host_header[i] != '\n' && host_header[i] != ':' && i < MAX_URL_LENGTH - 1) { host[i] = host_header[i]; i++; } host[i] = '\0'; return host; } return NULL; } // Extract port from HTTP request int extract_port(const char* 
request) { // Default port int default_port = 80; // Check if this is a CONNECT request (likely HTTPS) if (strncmp(request, "CONNECT ", 8) == 0) { default_port = 443; const char* connect_line = request + 8; const char* port_start = strchr(connect_line, ':'); if (port_start) { int port = atoi(port_start + 1); return port > 0 ? port : default_port; } return default_port; } // For regular HTTP, check Host header for port const char* host_header = strstr(request, "Host: "); if (host_header) { host_header += 6; // Skip "Host: " const char* port_start = strchr(host_header, ':'); if (port_start) { int port = atoi(port_start + 1); return port > 0 ? port : default_port; } } // Always return valid default port for HTTP return default_port; } // Initialize SSL context for intercepting HTTPS ssl_context_t* init_ssl_context(const char* cert_path, const char* key_path) { ssl_context_t* ssl_ctx = malloc(sizeof(ssl_context_t)); if (!ssl_ctx) { fprintf(stderr, "Failed to allocate memory for SSL context\n"); return NULL; } // Initialize SSL context ssl_ctx->ctx = SSL_CTX_new(TLS_server_method()); if (!ssl_ctx->ctx) { fprintf(stderr, "Failed to create SSL context\n"); free(ssl_ctx); return NULL; } // Configure SSL context SSL_CTX_set_options(ssl_ctx->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION); // Load certificate and private key if (SSL_CTX_use_certificate_file(ssl_ctx->ctx, cert_path, SSL_FILETYPE_PEM) <= 0) { fprintf(stderr, "Failed to load certificate: %s\n", cert_path); SSL_CTX_free(ssl_ctx->ctx); free(ssl_ctx); return NULL; } if (SSL_CTX_use_PrivateKey_file(ssl_ctx->ctx, key_path, SSL_FILETYPE_PEM) <= 0) { fprintf(stderr, "Failed to load private key: %s\n", key_path); SSL_CTX_free(ssl_ctx->ctx); free(ssl_ctx); return NULL; } // Verify private key if (!SSL_CTX_check_private_key(ssl_ctx->ctx)) { fprintf(stderr, "Private key does not match the certificate\n"); SSL_CTX_free(ssl_ctx->ctx); free(ssl_ctx); return NULL; } return ssl_ctx; } // Handle HTTPS CONNECT 
// tunneling with DANE verification.
//
// Decides, per hostname, whether to intercept a CONNECT tunnel:
//   - cached DANE verdict 1:
//       regenerate the host certificate, TLS-terminate the client and
//       re-originate TLS to the server, relaying via ssl_tunnel_data();
//   - cached DANE verdict 0:
//       plain tunnel via handle_regular_https_tunnel();
//   - no cache entry or expired entry, when is_dane_available() is true:
//       handshake with the server and check its certificate with
//       verify_cert_against_dane(), then cache the verdict.
//       If verified, intercept as above.
//       If not verified, reconnect and tunnel plainly;
//   - no cache entry or expired entry, without DANE records:
//       cache verdict 0 and tunnel plainly.
//
// client_sock / server_sock: connected sockets to the browser and the origin.
// hostname: used for SNI, for the certs/<hostname>.crt|.key file names, and as
//           the cache key.
// ip_addr: read only by the DANE-failure reconnect path (IPv4, port 443).
//
// NOTE(review): hostname is interpolated into a file path unsanitized; a value
// containing "/" or ".." would escape certs/. Confirm callers validate it.
void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, const char* ip_addr) {
    // NOTE(review): ip_addr IS used below in the DANE-failure reconnect, so
    // this cast is redundant rather than marking an unused parameter.
    (void)ip_addr;

    printf("Starting HTTPS tunnel for %s\n", hostname);

    // Check the cache first: 1 = verified, 0 = not verified, -1 = absent/expired
    int cached_result = check_cache(hostname);

    // Generate certificate paths
    char cert_path[256];
    char key_path[256];
    snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname);
    snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname);

    // If in cache and verified as successful, skip verification
    if (cached_result == 1) {
        printf("Using cached DANE verification for %s (verified)\n", hostname);

        // Always regenerate certificate to avoid browser cache issues
        printf("Regenerating certificate for %s to avoid browser certificate cache issues\n", hostname);

        // Generate trusted certificate BEFORE responding to the client
        if (!generate_unique_cert(hostname, cert_path, key_path)) {
            fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Initialize SSL context with our certificate
        ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path);
        if (!client_ctx) {
            fprintf(stderr, "Failed to initialize SSL context\n");
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Connect to the server.
        // NOTE(review): nothing in this branch checks the upstream certificate.
        // No verify mode or callback is set on server_ctx here, and
        // verify_cert_against_dane() is not called; the cached verdict is
        // trusted as-is for up to CACHE_EXPIRY_TIME. Confirm this is intended.
        SSL_CTX* server_ctx = SSL_CTX_new(TLS_client_method());
        if (!server_ctx) {
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        SSL* server_ssl = SSL_new(server_ctx);
        SSL_set_fd(server_ssl, server_sock);
        SSL_set_tlsext_host_name(server_ssl, hostname); // SNI

        // Offer a cached session for resumption, if one exists
        SSL_SESSION* session = get_ssl_session(hostname);
        if (session) {
            SSL_set_session(server_ssl, session);
            printf("Reusing SSL session for %s\n", hostname);
        }

        if (SSL_connect(server_ssl) <= 0) {
            print_ssl_errors("SSL_connect");
            cleanup_ssl_connection(server_ssl, server_ctx);
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            // NOTE(review): TLS handshake bytes have already been exchanged on
            // server_sock, so a plain tunnel over the same socket is unlikely
            // to work.
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        } else {
            // Cache the new session for future connections
            SSL_SESSION* new_session = SSL_get1_session(server_ssl);
            if (new_session) {
                cache_ssl_session(hostname, new_session);
            }
        }

        // Create the client-facing SSL object before answering the CONNECT
        client_ctx->ssl = SSL_new(client_ctx->ctx);
        if (!client_ctx->ssl) {
            fprintf(stderr, "Failed to create SSL object\n");
            cleanup_ssl_connection(server_ssl, server_ctx);
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Set up the file descriptor
        SSL_set_fd(client_ctx->ssl, client_sock);

        // Send 200 Connection Established to the client
        const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
        if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
            perror("Failed to send connection established response");
            cleanup_ssl_connection(client_ctx->ssl, NULL);
            cleanup_ssl_connection(server_ssl, server_ctx);
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            return;
        }

        // Clear any previous errors so the SSL_accept diagnostics are accurate
        ERR_clear_error();

        // Accept SSL connection from client
        int accept_result = SSL_accept(client_ctx->ssl);
        if (accept_result <= 0) {
            int ssl_err = SSL_get_error(client_ctx->ssl, accept_result);
            fprintf(stderr, "SSL accept failed with error code: %d\n", ssl_err);
            print_ssl_errors("SSL_accept");

            // Try to determine if client closed the connection
            if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) {
                fprintf(stderr, "Client may have closed the connection\n");
            }

            cleanup_ssl_connection(client_ctx->ssl, NULL);
            cleanup_ssl_connection(server_ssl, server_ctx);
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            return;
        }

        printf("SSL connection established with client for %s (using cached verification)\n", hostname);

        // Forward data between them
        ssl_tunnel_data(client_ctx->ssl, server_ssl);

        // Clean up
        cleanup_ssl_connection(client_ctx->ssl, NULL);
        cleanup_ssl_connection(server_ssl, server_ctx);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        return;
    } else if (cached_result == 0) {
        printf("Using cached DANE verification for %s (not verified)\n", hostname);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    // Not in cache or expired, perform full verification
    // Check if we have DANE records for this domain
    int has_dane = is_dane_available(hostname);

    if (has_dane) {
        printf("DANE records found for %s, will verify certificate\n", hostname);

        // First, establish a connection to the server to get its certificate
        SSL_CTX* server_ctx = SSL_CTX_new(TLS_client_method());
        if (!server_ctx) {
            fprintf(stderr, "Failed to create server SSL context\n");
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        SSL* server_ssl = SSL_new(server_ctx);
        SSL_set_fd(server_ssl, server_sock);

        // Set SNI hostname
        SSL_set_tlsext_host_name(server_ssl, hostname);

        if (SSL_connect(server_ssl) <= 0) {
            fprintf(stderr, "SSL connection to server failed\n");
            cleanup_ssl_connection(server_ssl, server_ctx);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Get the server's certificate (owned here; freed with X509_free on every path below)
        X509* server_cert = SSL_get_peer_certificate(server_ssl);
        if (!server_cert) {
            fprintf(stderr, "Failed to get server certificate\n");
            cleanup_ssl_connection(server_ssl, server_ctx);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Verify the certificate against DANE (> 0 means verified)
        int dane_verified = verify_cert_against_dane(hostname, server_cert);

        // Cache the verification result
        add_to_cache(hostname, dane_verified > 0 ? 1 : 0);

        if (dane_verified <= 0) {
            fprintf(stderr, "DANE verification failed for %s - using direct tunneling\n", hostname);

            // Clean up and reconnect without interception
            X509_free(server_cert);
            cleanup_ssl_connection(server_ssl, server_ctx);

            // Create a new socket connection to the server.
            // NOTE(review): this closes the caller's server_sock; a caller that
            // closes it again afterwards would double-close, possibly closing a
            // reused fd. The reconnect is IPv4-only (inet_addr) and always
            // targets port 443, whatever port the client asked for. Whether
            // new_server_sock is closed depends on handle_regular_https_tunnel
            // (not visible here).
            close(server_sock);
            int new_server_sock = socket(AF_INET, SOCK_STREAM, 0);
            if (new_server_sock < 0) {
                perror("Cannot create new socket to server");
                close(client_sock);
                return;
            }

            struct sockaddr_in server_addr;
            memset(&server_addr, 0, sizeof(server_addr));
            server_addr.sin_family = AF_INET;
            server_addr.sin_addr.s_addr = inet_addr(ip_addr);
            server_addr.sin_port = htons(443);

            if (connect(new_server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
                perror("Cannot reconnect to server");
                close(new_server_sock);
                close(client_sock);
                return;
            }

            // Use regular tunneling without interception
            handle_regular_https_tunnel(client_sock, new_server_sock);
            return;
        }

        printf("DANE verification successful for %s - generating trusted certificate\n", hostname);

        // Generate trusted certificate BEFORE responding to the client
        // This ensures the certificate is ready when the client tries to establish the SSL connection
        // NOTE(review): on the failure paths below, server_sock has already
        // completed a TLS handshake, so the plain-tunnel fallback is unlikely
        // to work.
        if (!generate_unique_cert(hostname, cert_path, key_path)) {
            fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname);
            X509_free(server_cert);
            cleanup_ssl_connection(server_ssl, server_ctx);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Initialize SSL context with our certificate
        ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path);
        if (!client_ctx) {
            fprintf(stderr, "Failed to initialize SSL context\n");
            X509_free(server_cert);
            cleanup_ssl_connection(server_ssl, server_ctx);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Create the client-facing SSL object before answering the CONNECT
        client_ctx->ssl = SSL_new(client_ctx->ctx);
        if (!client_ctx->ssl) {
            fprintf(stderr, "Failed to create SSL object\n");
            X509_free(server_cert);
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            cleanup_ssl_connection(server_ssl, server_ctx);
            handle_regular_https_tunnel(client_sock, server_sock);
            return;
        }

        // Only after everything is set up, send the Connection Established response
        const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
        if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
            perror("Failed to send connection established response");
            X509_free(server_cert);
            cleanup_ssl_connection(server_ssl, server_ctx);
            SSL_free(client_ctx->ssl);
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            return;
        }

        // Set up the file descriptor after sending the response
        SSL_set_fd(client_ctx->ssl, client_sock);

        // Add more detailed error information for SSL_accept
        ERR_clear_error(); // Clear any previous errors
        int accept_result = SSL_accept(client_ctx->ssl);
        if (accept_result <= 0) {
            int ssl_err = SSL_get_error(client_ctx->ssl, accept_result);
            fprintf(stderr, "SSL_accept failed for %s with error code: %d\n", hostname, ssl_err);
            print_ssl_errors("SSL_accept");
            X509_free(server_cert);
            cleanup_ssl_connection(client_ctx->ssl, NULL);
            cleanup_ssl_connection(server_ssl, server_ctx);
            SSL_CTX_free(client_ctx->ctx);
            free(client_ctx);
            return;
        }

        printf("SSL connection successfully established with client for %s\n", hostname);

        // Tunnel data
        ssl_tunnel_data(client_ctx->ssl, server_ssl);

        // Add connection to pool for future reuse.
        // NOTE(review): ownership of server_sock/server_ssl/server_ctx passes to
        // the pool here, after the tunnel has finished. Whether the upstream
        // connection is still usable depends on ssl_tunnel_data (not visible).
        // A caller that closes server_sock afterwards would leave the pool
        // holding a closed fd. The port is hard-coded to 443.
        add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx);

        // Clean up
        X509_free(server_cert);
        cleanup_ssl_connection(client_ctx->ssl, NULL);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
    } else {
        // No DANE records
        add_to_cache(hostname, 0);
        handle_regular_https_tunnel(client_sock, server_sock);
    }
}

// Start the proxy: initialize its components, bind and listen on `port`, then
// hand every accepted client socket to the worker thread pool.
int start_proxy_server(int port) {
    int server_sock;
    struct sockaddr_in server_addr, client_addr;
    socklen_t client_len = sizeof(client_addr);

    // Initialize proxy
components including thread pool if (proxy_init() != 0) { fprintf(stderr, "Failed to initialize proxy server\n"); return 1; } // Initialize the thread pool init_thread_pool(); // Create periodic cleanup thread pthread_t cleanup_thread; if (pthread_create(&cleanup_thread, NULL, periodic_cleanup, NULL) != 0) { perror("Failed to create cleanup thread"); // Continue anyway, not critical } else { pthread_detach(cleanup_thread); } // Create socket and set up server as before server_sock = socket(AF_INET, SOCK_STREAM, 0); if (server_sock < 0) { perror("Cannot create socket"); return 1; } // Allow reuse of address int opt = 1; if (setsockopt(server_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { perror("setsockopt failed"); close(server_sock); return 1; } // Initialize server address memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_addr.s_addr = INADDR_ANY; server_addr.sin_port = htons(port); // Bind socket if (bind(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) { perror("Bind failed"); close(server_sock); return 1; } // Listen for connections if (listen(server_sock, 10) < 0) { perror("Listen failed"); close(server_sock); return 1; } // Set up signal handler for graceful termination signal(SIGINT, handle_signal); printf("Proxy server listening on port %d\n", port); printf("Press Ctrl+C to stop the server\n"); // Accept connections and queue them for processing while (1) { int client_sock = accept(server_sock, (struct sockaddr*)&client_addr, &client_len); if (client_sock < 0) { perror("Accept failed"); continue; } printf("New connection from %s:%d\n", inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port)); // Queue client connection for processing by a worker thread queue_client_connection(client_sock); } // Cleanup close(server_sock); return 0; } // Initialize the proxy server int proxy_init() { // Create the certs directory if it doesn't exist mkdir("certs", 0755); // Initialize 
OpenSSL error strings SSL_load_error_strings(); ERR_load_crypto_strings(); // Initialize DANE support if (!dane_init()) { fprintf(stderr, "Failed to initialize DANE support\n"); return 1; } // Initialize DNS thread pool init_dns_pool(); // Initialize certificate generation thread pool init_cert_gen_pool(); // Generate a truly random initial seed for certificate serials unsigned char random_bytes[8]; if (RAND_bytes(random_bytes, sizeof(random_bytes)) == 1) { unsigned long long random_seed = 0; for (int i = 0; i < 8; i++) { random_seed = (random_seed << 8) | random_bytes[i]; } atomic_init(&next_serial, random_seed); } else { // Fallback to time-based initialization with more entropy struct timespec ts; clock_gettime(CLOCK_REALTIME, &ts); atomic_init(&next_serial, ((unsigned long long)ts.tv_sec << 32) | ts.tv_nsec); } // Don't pre-generate certificates, only generate for domains with DANE // pregen_common_certificates(); return 0; } // Clean up proxy resources void proxy_cleanup() { // Clean up caches free_cache(); free_dns_cache(); free_cert_serials(); // Clean up certificate generation thread pool cleanup_cert_gen_pool(); // Clean up DANE resources dane_cleanup(); } // Signal handler for graceful termination void handle_signal(int sig) { if (sig == SIGINT) { printf("\nShutting down proxy server...\n"); printf("Cleaning up cache and temporary certificates...\n"); // Clean up resources proxy_cleanup(); exit(0); } } // Update periodic_cleanup to also clean SSL sessions void* periodic_cleanup(void* arg) { (void)arg; // Unused parameter while (1) { // Sleep for 5 minutes sleep(300); // Clean up expired cache entries cleanup_cache(); cleanup_dns_cache(); cleanup_cert_serials(); cleanup_connection_pool(); cleanup_ssl_sessions(); // Add this call printf("Performed periodic cache cleanup\n"); } return NULL; } // Initialize DNS resolution thread pool void init_dns_pool(void) { // Initialize the task queue doh_queue.tasks = malloc(sizeof(dns_task*) * MAX_PARALLEL_DNS_QUERIES * 
2); doh_queue.task_count = 0; doh_queue.task_capacity = MAX_PARALLEL_DNS_QUERIES * 2; doh_queue.running = 1; pthread_mutex_init(&doh_queue.mutex, NULL); pthread_cond_init(&doh_queue.not_empty, NULL); pthread_cond_init(&doh_queue.not_full, NULL); // Create worker threads for DNS resolution for (int i = 0; i < MAX_PARALLEL_DNS_QUERIES; i++) { if (pthread_create(&dns_threads[i], NULL, dns_worker, NULL) != 0) { perror("Failed to create DNS resolution thread"); // Continue anyway } } } // DNS resolution worker thread void* dns_worker(void* arg) { (void)arg; while (doh_queue.running) { dns_task* task = NULL; pthread_mutex_lock(&doh_queue.mutex); while (doh_queue.task_count == 0 && doh_queue.running) { pthread_cond_wait(&doh_queue.not_empty, &doh_queue.mutex); } if (!doh_queue.running) { pthread_mutex_unlock(&doh_queue.mutex); break; } // Get the task task = doh_queue.tasks[0]; // Remove task from queue doh_queue.task_count--; if (doh_queue.task_count > 0) { memmove(&doh_queue.tasks[0], &doh_queue.tasks[1], sizeof(dns_task*) * doh_queue.task_count); } pthread_cond_signal(&doh_queue.not_full); pthread_mutex_unlock(&doh_queue.mutex); // Perform DNS resolution task->result = resolve_doh(task->hostname, task->ip_addr, IP_ADDR_MAX_LEN); // Mark task as completed and signal waiting threads pthread_mutex_lock(&task->mutex); task->completed = 1; pthread_cond_signal(&task->cond); pthread_mutex_unlock(&task->mutex); } return NULL; } // Queue a DNS resolution task dns_task* queue_dns_resolution(const char* hostname) { // Create a new task dns_task* task = malloc(sizeof(dns_task)); if (!task) return NULL; strncpy(task->hostname, hostname, MAX_URL_LENGTH-1); task->hostname[MAX_URL_LENGTH-1] = '\0'; task->ip_addr[0] = '\0'; task->result = 0; task->completed = 0; pthread_mutex_init(&task->mutex, NULL); pthread_cond_init(&task->cond, NULL); // Add task to queue pthread_mutex_lock(&doh_queue.mutex); while (doh_queue.task_count >= doh_queue.task_capacity && doh_queue.running) { 
pthread_cond_wait(&doh_queue.not_full, &doh_queue.mutex); } if (!doh_queue.running) { pthread_mutex_unlock(&doh_queue.mutex); free(task); return NULL; } // Add the task to the queue doh_queue.tasks[doh_queue.task_count++] = task; // Signal that the queue is not empty pthread_cond_signal(&doh_queue.not_empty); pthread_mutex_unlock(&doh_queue.mutex); return task; } // Wait for DNS resolution to complete and get result int wait_for_dns_resolution(dns_task* task, char* ip_buffer, size_t buffer_size) { if (!task) return -1; pthread_mutex_lock(&task->mutex); while (!task->completed) { pthread_cond_wait(&task->cond, &task->mutex); } int result = task->result; pthread_mutex_unlock(&task->mutex); if (result == 0 && ip_buffer) { strncpy(ip_buffer, task->ip_addr, buffer_size-1); ip_buffer[buffer_size-1] = '\0'; } // Clean up task resources pthread_mutex_destroy(&task->mutex); pthread_cond_destroy(&task->cond); free(task); return result; } // Cleanup DNS resolution pool void cleanup_dns_pool() { // Signal shutdown pthread_mutex_lock(&doh_queue.mutex); doh_queue.running = 0; pthread_cond_broadcast(&doh_queue.not_empty); pthread_cond_broadcast(&doh_queue.not_full); pthread_mutex_unlock(&doh_queue.mutex); // Wait for threads to exit for (int i = 0; i < MAX_PARALLEL_DNS_QUERIES; i++) { pthread_join(dns_threads[i], NULL); } // Free any remaining tasks pthread_mutex_lock(&doh_queue.mutex); for (int i = 0; i < doh_queue.task_count; i++) { free(doh_queue.tasks[i]); } free(doh_queue.tasks); pthread_mutex_unlock(&doh_queue.mutex); // Destroy sync primitives pthread_mutex_destroy(&doh_queue.mutex); pthread_cond_destroy(&doh_queue.not_empty); pthread_cond_destroy(&doh_queue.not_full); } // Initialize certificate generation thread pool void init_cert_gen_pool(void) { // Initialize the task queue cert_queue.tasks = malloc(sizeof(cert_gen_task*) * THREAD_POOL_SIZE); cert_queue.task_count = 0; cert_queue.task_capacity = THREAD_POOL_SIZE; cert_queue.running = 1; 
pthread_mutex_init(&cert_queue.mutex, NULL); pthread_cond_init(&cert_queue.not_empty, NULL); pthread_cond_init(&cert_queue.not_full, NULL); // Create worker threads for certificate generation for (int i = 0; i < CERT_GEN_THREAD_POOL_SIZE; i++) { if (pthread_create(&cert_gen_threads[i], NULL, cert_gen_worker, NULL) != 0) { perror("Failed to create certificate generation thread"); // Continue anyway } } // Initialize the atomic serial counter atomic_init(&next_serial, ((unsigned long long)time(NULL) << 32)); } // Certificate generation worker thread void* cert_gen_worker(void* arg) { (void)arg; while (cert_queue.running) { cert_gen_task* task = NULL; pthread_mutex_lock(&cert_queue.mutex); while (cert_queue.task_count == 0 && cert_queue.running) { pthread_cond_wait(&cert_queue.not_empty, &cert_queue.mutex); } if (!cert_queue.running) { pthread_mutex_unlock(&cert_queue.mutex); break; } // Get the task task = cert_queue.tasks[0]; // Remove task from queue cert_queue.task_count--; if (cert_queue.task_count > 0) { memmove(&cert_queue.tasks[0], &cert_queue.tasks[1], sizeof(cert_gen_task*) * cert_queue.task_count); } pthread_cond_signal(&cert_queue.not_full); pthread_mutex_unlock(&cert_queue.mutex); // Generate the certificate unsigned long long serial = atomic_fetch_add(&next_serial, 1); // Make sure cert and key paths exist unlink(task->cert_path); unlink(task->key_path); // Use the serial for entropy srand((unsigned int)(serial & 0xFFFFFFFF)); // Generate the cert with a unique serial task->result = generate_trusted_cert(task->hostname, task->cert_path, task->key_path); // Mark task as completed and signal waiting threads pthread_mutex_lock(&task->mutex); task->completed = 1; pthread_cond_signal(&task->cond); pthread_mutex_unlock(&task->mutex); } return NULL; } // Queue a certificate generation task cert_gen_task* queue_cert_generation(const char* hostname, const char* cert_path, const char* key_path) { // Create a new task cert_gen_task* task = 
malloc(sizeof(cert_gen_task)); if (!task) return NULL; strncpy(task->hostname, hostname, MAX_URL_LENGTH-1); task->hostname[MAX_URL_LENGTH-1] = '\0'; strncpy(task->cert_path, cert_path, 255); task->cert_path[255] = '\0'; strncpy(task->key_path, key_path, 255); task->key_path[255] = '\0'; task->result = 0; task->completed = 0; pthread_mutex_init(&task->mutex, NULL); pthread_cond_init(&task->cond, NULL); // Add task to queue pthread_mutex_lock(&cert_queue.mutex); while (cert_queue.task_count >= cert_queue.task_capacity && cert_queue.running) { pthread_cond_wait(&cert_queue.not_full, &cert_queue.mutex); } if (!cert_queue.running) { pthread_mutex_unlock(&cert_queue.mutex); free(task); return NULL; } // Add the task to the queue cert_queue.tasks[cert_queue.task_count++] = task; // Signal that the queue is not empty pthread_cond_signal(&cert_queue.not_empty); pthread_mutex_unlock(&cert_queue.mutex); return task; } // Wait for certificate generation to complete and get result int wait_for_cert_generation(cert_gen_task* task) { if (!task) return 0; pthread_mutex_lock(&task->mutex); while (!task->completed) { pthread_cond_wait(&task->cond, &task->mutex); } int result = task->result; pthread_mutex_unlock(&task->mutex); // Clean up task resources pthread_mutex_destroy(&task->mutex); pthread_cond_destroy(&task->cond); free(task); return result; } // Cleanup certificate generation pool void cleanup_cert_gen_pool() { // Signal shutdown pthread_mutex_lock(&cert_queue.mutex); cert_queue.running = 0; pthread_cond_broadcast(&cert_queue.not_empty); pthread_cond_broadcast(&cert_queue.not_full); pthread_mutex_unlock(&cert_queue.mutex); // Wait for threads to exit for (int i = 0; i < CERT_GEN_THREAD_POOL_SIZE; i++) { pthread_join(cert_gen_threads[i], NULL); } // Free any remaining tasks pthread_mutex_lock(&cert_queue.mutex); for (int i = 0; i < cert_queue.task_count; i++) { free(cert_queue.tasks[i]); } free(cert_queue.tasks); pthread_mutex_unlock(&cert_queue.mutex); // Destroy sync 
primitives pthread_mutex_destroy(&cert_queue.mutex); pthread_cond_destroy(&cert_queue.not_empty); pthread_cond_destroy(&cert_queue.not_full); } // Pre-generate certificates for popular domains (call during init) void pregen_common_certificates() { // List of common domains that might be accessed const char* common_domains[] = { "www.google.com", "www.facebook.com", "www.youtube.com", "www.amazon.com", "www.twitter.com", "www.instagram.com", "www.linkedin.com", "www.reddit.com", "www.netflix.com", "www.github.com" // Add more common domains as needed }; int num_domains = sizeof(common_domains) / sizeof(common_domains[0]); int domains_to_pregen = num_domains < PRE_GEN_CERT_COUNT ? num_domains : PRE_GEN_CERT_COUNT; printf("Pre-generating certificates for %d common domains...\n", domains_to_pregen); // Create a thread to pre-generate certificates in the background pthread_t pregen_thread; pthread_create(&pregen_thread, NULL, pregen_cert_thread, (void*)(long)domains_to_pregen); pthread_detach(pregen_thread); } // Thread function for pre-generating certificates void* pregen_cert_thread(void* arg) { int count = (int)(long)arg; // List of common domains that might be accessed const char* common_domains[] = { "www.google.com", "www.facebook.com", "www.youtube.com", "www.amazon.com", "www.twitter.com", "www.instagram.com", "www.linkedin.com", "www.reddit.com", "www.netflix.com", "www.github.com" // Add more common domains as needed }; for (int i = 0; i < count; i++) { char cert_path[256]; char key_path[256]; snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", common_domains[i]); snprintf(key_path, sizeof(key_path), "certs/%s.key", common_domains[i]); // Don't regenerate if it exists and is recent if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0) { // Check if we need to renew if (!should_renew_cert(common_domains[i])) { printf("Certificate for %s already exists and is recent\n", common_domains[i]); continue; } } printf("Pre-generating certificate for %s\n", 
common_domains[i]); generate_unique_cert(common_domains[i], cert_path, key_path); // Sleep a bit to not overwhelm the system usleep(100000); // 100ms } printf("Finished pre-generating certificates\n"); return NULL; } // Regular HTTPS tunneling without interception void handle_regular_https_tunnel(int client_sock, int server_sock) { fd_set read_fds; char buffer[MAX_REQUEST_SIZE]; int max_fd = (client_sock > server_sock) ? client_sock : server_sock; // Send 200 Connection Established to the client const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; if (send(client_sock, success_response, strlen(success_response), 0) < 0) { perror("Failed to send connection established response"); return; } // Tunnel data between client and server while (1) { FD_ZERO(&read_fds); FD_SET(client_sock, &read_fds); FD_SET(server_sock, &read_fds); if (select(max_fd + 1, &read_fds, NULL, NULL, NULL) < 0) { perror("Select failed"); break; } if (FD_ISSET(client_sock, &read_fds)) { ssize_t bytes_received = recv(client_sock, buffer, sizeof(buffer), 0); if (bytes_received <= 0) break; if (send(server_sock, buffer, bytes_received, 0) <= 0) break; } if (FD_ISSET(server_sock, &read_fds)) { ssize_t bytes_received = recv(server_sock, buffer, sizeof(buffer), 0); if (bytes_received <= 0) break; if (send(client_sock, buffer, bytes_received, 0) <= 0) break; } } } // Forward data between SSL connections void ssl_tunnel_data(SSL* client_ssl, SSL* server_ssl) { fd_set read_fds; char buffer[MAX_REQUEST_SIZE]; int client_fd = SSL_get_fd(client_ssl); int server_fd = SSL_get_fd(server_ssl); int max_fd = (client_fd > server_fd) ? 
client_fd : server_fd; while (1) { FD_ZERO(&read_fds); FD_SET(client_fd, &read_fds); FD_SET(server_fd, &read_fds); if (select(max_fd + 1, &read_fds, NULL, NULL, NULL) < 0) { perror("Select failed"); break; } if (FD_ISSET(client_fd, &read_fds)) { int bytes_received = SSL_read(client_ssl, buffer, sizeof(buffer)); if (bytes_received <= 0) { int err = SSL_get_error(client_ssl, bytes_received); if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) { break; } } else { int bytes_sent = SSL_write(server_ssl, buffer, bytes_received); if (bytes_sent <= 0) { int err = SSL_get_error(server_ssl, bytes_sent); if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) { break; } } } } if (FD_ISSET(server_fd, &read_fds)) { int bytes_received = SSL_read(server_ssl, buffer, sizeof(buffer)); if (bytes_received <= 0) { int err = SSL_get_error(server_ssl, bytes_received); if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) { break; } } else { int bytes_sent = SSL_write(client_ssl, buffer, bytes_received); if (bytes_sent <= 0) { int err = SSL_get_error(client_ssl, bytes_sent); if (err != SSL_ERROR_WANT_READ && err != SSL_ERROR_WANT_WRITE) { break; } } } } } } // Modify HTTP request for direct server communication char* rewrite_http_request(const char* original_request, size_t* new_length) { // Find the first line ending const char* first_line_end = strstr(original_request, "\r\n"); if (!first_line_end) { printf("Malformed HTTP request: no line ending found\n"); return NULL; } // Parse the request line char method[16] = {0}; char url[MAX_URL_LENGTH] = {0}; char version[16] = {0}; sscanf(original_request, "%15s %2047s %15s", method, url, version); printf("Original request: %s %s %s\n", method, url, version); // Extract path from URL (remove http://hostname) char* path = NULL; if (strncmp(url, "http://", 7) == 0) { path = strchr(url + 7, '/'); if (!path) { path = "/"; printf("No path in URL, using default: %s\n", path); } else { printf("Extracted path from URL: 
%s\n", path); } } else { path = url; // Already a path or malformed printf("Using URL as path: %s\n", path); } // Calculate new request size const char* headers_start = first_line_end + 2; // Skip \r\n size_t headers_length = strlen(headers_start); size_t request_line_length = strlen(method) + strlen(path) + strlen(version) + 4; // +4 for spaces and \r\n size_t total_length = request_line_length + headers_length; // Allocate memory for new request char* new_request = malloc(total_length + 1); if (!new_request) { printf("Failed to allocate memory for HTTP request rewrite\n"); return NULL; } // Write the new request int written = snprintf(new_request, total_length + 1, "%s %s %s\r\n%s", method, path, version, headers_start); if (written < 0 || (size_t)written > total_length) { printf("Failed to rewrite HTTP request\n"); free(new_request); return NULL; } printf("Rewritten request first line: %s %s %s\n", method, path, version); *new_length = written; return new_request; } // Optimized client handling void handle_client_optimized(int client_sock) { char request[MAX_REQUEST_SIZE]; ssize_t bytes_received = recv(client_sock, request, sizeof(request) - 1, 0); if (bytes_received <= 0) { close(client_sock); return; } request[bytes_received] = '\0'; // Check if this is an HTTPS CONNECT request int is_connect = (strncmp(request, "CONNECT ", 8) == 0); // Extract host from request char* host = extract_host(request); if (!host) { printf("Failed to extract host from request\n"); close(client_sock); return; } // Extract port from request int port = extract_port(request); if (port <= 0 || port > 65535) { port = is_connect ? 443 : 80; printf("Invalid port detected, using default port: %d\n", port); } printf("Proxying %s request to: %s (port %d)\n", is_connect ? 
"HTTPS" : "HTTP", host, port); // Try to resolve hostname from DNS cache first char ip_addr[IP_ADDR_MAX_LEN]; int dns_cache_result = check_dns_cache(host, ip_addr, sizeof(ip_addr)); if (dns_cache_result == 0) { printf("Using cached DNS for %s: %s\n", host, ip_addr); } else { // Resolve hostname using DoH if (resolve_doh(host, ip_addr, sizeof(ip_addr)) != 0) { printf("Failed to resolve hostname using DoH: %s\n", host); // Don't cache failed resolutions anymore close(client_sock); return; } printf("Resolved %s to %s\n", host, ip_addr); // Only cache successful resolutions add_to_dns_cache(host, ip_addr, 1); } // Connect to the target server struct sockaddr_in server_addr; int server_sock = socket(AF_INET, SOCK_STREAM, 0); if (server_sock < 0) { perror("Cannot create socket to server"); close(client_sock); return; } memset(&server_addr, 0, sizeof(server_addr)); server_addr.sin_family = AF_INET; server_addr.sin_addr.s_addr = inet_addr(ip_addr); server_addr.sin_port = htons(port); if (connect(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) { perror("Cannot connect to server"); close(server_sock); close(client_sock); return; } if (is_connect) { // HTTPS: Handle CONNECT tunnel with DANE support handle_https_tunnel(client_sock, server_sock, host, ip_addr); } else { // HTTP: Rewrite the request and forward it printf("HTTP request received, rewriting for server...\n"); size_t new_length = 0; char* modified_request = rewrite_http_request(request, &new_length); if (!modified_request) { printf("Failed to rewrite HTTP request\n"); close(server_sock); close(client_sock); return; } printf("Forwarding modified HTTP request (%zu bytes) to server...\n", new_length); // Set timeouts for sending and receiving struct timeval tv; tv.tv_sec = 5; // 5 seconds timeout tv.tv_usec = 0; setsockopt(server_sock, SOL_SOCKET, SO_SNDTIMEO, (const char*)&tv, sizeof tv); setsockopt(server_sock, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv); if (send(server_sock, 
modified_request, new_length, 0) < 0) { perror("Failed to send request to server"); free(modified_request); close(server_sock); close(client_sock); return; } free(modified_request); // Force-flush any pending data int flag = 1; setsockopt(server_sock, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(int)); // Receive response from server and forward to client printf("Waiting for server response (timeout: 5s)...\n"); char buffer[MAX_REQUEST_SIZE]; fd_set read_fds; FD_ZERO(&read_fds); FD_SET(server_sock, &read_fds); // Use select with timeout to wait for response if (select(server_sock + 1, &read_fds, NULL, NULL, &tv) <= 0) { printf("Timeout or error waiting for server response\n"); close(server_sock); close(client_sock); return; } ssize_t bytes_received; while ((bytes_received = recv(server_sock, buffer, sizeof(buffer), 0)) > 0) { printf("Received %zd bytes from server\n", bytes_received); if (send(client_sock, buffer, bytes_received, 0) < 0) { perror("Failed to send response to client"); break; } } if (bytes_received < 0) { perror("Error receiving from server"); } else if (bytes_received == 0) { printf("Server closed connection normally\n"); } } printf("Connection closed: %s (port %d)\n", host, port); close(server_sock); close(client_sock); } // HTTPS tunnel with DANE verification, optimized version void handle_https_tunnel_optimized(int client_sock, int server_sock, const char* hostname, const char* ip_addr, SSL* existing_ssl, SSL_CTX* existing_ctx) { // Mark ip_addr as used to avoid warning (void)ip_addr; // Use existing SSL session if provided SSL* server_ssl = existing_ssl; SSL_CTX* server_ctx = existing_ctx; int need_ssl_setup = (server_ssl == NULL); // Check the DANE cache first int cached_result = check_cache(hostname); // Generate certificate paths char cert_path[256]; char key_path[256]; snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname); snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname); // If in cache and verified as successful, 
skip verification if (cached_result == 1) { // Cached DANE verification successful // Check if certificate exists and doesn't need renewal if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0 && !should_renew_cert(hostname)) { // Initialize SSL context with our certificate ssl_context_t* client_ctx = NULL; if (!get_cached_cert(hostname, &client_ctx)) { // Fall back to disk client_ctx = init_ssl_context(cert_path, key_path); if (client_ctx) { // Now that we've loaded from disk, cache it for next time cache_cert_in_memory(hostname, cert_path, key_path); } } if (!client_ctx) { // Fall back to regular tunnel handle_regular_https_tunnel(client_sock, server_sock); return; } // If we need to set up SSL, do it if (need_ssl_setup) { server_ctx = SSL_CTX_new(TLS_client_method()); if (!server_ctx) { SSL_CTX_free(client_ctx->ctx); free(client_ctx); handle_regular_https_tunnel(client_sock, server_sock); return; } server_ssl = SSL_new(server_ctx); SSL_set_fd(server_ssl, server_sock); SSL_set_tlsext_host_name(server_ssl, hostname); // Enable session caching for future connections SSL_CTX_set_session_cache_mode(server_ctx, SSL_SESS_CACHE_CLIENT); if (SSL_connect(server_ssl) <= 0) { cleanup_ssl_connection(server_ssl, server_ctx); SSL_CTX_free(client_ctx->ctx); free(client_ctx); handle_regular_https_tunnel(client_sock, server_sock); return; } } // Send 200 Connection Established to client const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; if (send(client_sock, success_response, strlen(success_response), 0) < 0) { if (!need_ssl_setup) { // Don't clean up if we're using a cached connection SSL_CTX_free(client_ctx->ctx); free(client_ctx); } else { cleanup_ssl_connection(server_ssl, server_ctx); SSL_CTX_free(client_ctx->ctx); free(client_ctx); } return; } // Set up client SSL client_ctx->ssl = SSL_new(client_ctx->ctx); SSL_set_fd(client_ctx->ssl, client_sock); if (SSL_accept(client_ctx->ssl) <= 0) { cleanup_ssl_connection(client_ctx->ssl, NULL); if 
(need_ssl_setup) { cleanup_ssl_connection(server_ssl, server_ctx); } SSL_CTX_free(client_ctx->ctx); free(client_ctx); return; } // Tunnel data ssl_tunnel_data(client_ctx->ssl, server_ssl); // Clean up cleanup_ssl_connection(client_ctx->ssl, NULL); if (need_ssl_setup) { // If we didn't use a pooled connection, add this one to the pool add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx); } SSL_CTX_free(client_ctx->ctx); free(client_ctx); return; } } else if (cached_result == 0) { // Cached DANE verification failed, use regular tunnel handle_regular_https_tunnel(client_sock, server_sock); return; } // Need to perform DANE verification // Clean up any existing SSL session since we need to validate from scratch if (!need_ssl_setup) { // Not using server_ssl/ctx directly as they belong to pool cleanup_ssl_connection(server_ssl, NULL); need_ssl_setup = 1; } // Set up a new SSL connection to server server_ctx = SSL_CTX_new(TLS_client_method()); if (!server_ctx) { handle_regular_https_tunnel(client_sock, server_sock); return; } server_ssl = SSL_new(server_ctx); SSL_set_fd(server_ssl, server_sock); SSL_set_tlsext_host_name(server_ssl, hostname); if (SSL_connect(server_ssl) <= 0) { cleanup_ssl_connection(server_ssl, server_ctx); handle_regular_https_tunnel(client_sock, server_sock); return; } // Check for DANE records int has_dane = is_dane_available(hostname); if (has_dane) { // Get server certificate and verify against DANE X509* server_cert = SSL_get_peer_certificate(server_ssl); if (!server_cert) { cleanup_ssl_connection(server_ssl, server_ctx); handle_regular_https_tunnel(client_sock, server_sock); return; } int dane_verified = verify_cert_against_dane(hostname, server_cert); add_to_cache(hostname, dane_verified > 0 ? 
1 : 0); if (dane_verified <= 0) { // DANE verification failed X509_free(server_cert); cleanup_ssl_connection(server_ssl, server_ctx); handle_regular_https_tunnel(client_sock, server_sock); return; } // Generate trusted certificate if (!generate_unique_cert(hostname, cert_path, key_path)) { X509_free(server_cert); cleanup_ssl_connection(server_ssl, server_ctx); handle_regular_https_tunnel(client_sock, server_sock); return; } // Set up client connection with our certificate ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path); if (!client_ctx) { X509_free(server_cert); cleanup_ssl_connection(server_ssl, server_ctx); handle_regular_https_tunnel(client_sock, server_sock); return; } // Send 200 Connection Established to client const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; if (send(client_sock, success_response, strlen(success_response), 0) < 0) { X509_free(server_cert); cleanup_ssl_connection(server_ssl, server_ctx); cleanup_ssl_connection(client_ctx->ssl, NULL); SSL_CTX_free(client_ctx->ctx); free(client_ctx); return; } // Setup client SSL client_ctx->ssl = SSL_new(client_ctx->ctx); SSL_set_fd(client_ctx->ssl, client_sock); if (SSL_accept(client_ctx->ssl) <= 0) { X509_free(server_cert); cleanup_ssl_connection(client_ctx->ssl, NULL); cleanup_ssl_connection(server_ssl, server_ctx); SSL_CTX_free(client_ctx->ctx); free(client_ctx); return; } // Tunnel data ssl_tunnel_data(client_ctx->ssl, server_ssl); // Add connection to pool for future reuse add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx); // Clean up X509_free(server_cert); cleanup_ssl_connection(client_ctx->ssl, NULL); SSL_CTX_free(client_ctx->ctx); free(client_ctx); } else { // No DANE records add_to_cache(hostname, 0); cleanup_ssl_connection(server_ssl, server_ctx); handle_regular_https_tunnel(client_sock, server_sock); } } SSL_SESSION* get_ssl_session(const char* hostname) { pthread_mutex_lock(&ssl_session_mutex); ssl_session_cache_entry* entry; 
HASH_FIND_STR(ssl_session_cache, hostname, entry); if (entry) { time_t now = time(NULL); // Check session age - 10 minute validity if (now - entry->timestamp > 600) { // Session too old, remove it SSL_SESSION_free(entry->session); HASH_DEL(ssl_session_cache, entry); free(entry); pthread_mutex_unlock(&ssl_session_mutex); return NULL; } // Return a reference to the cached session SSL_SESSION* session = entry->session; pthread_mutex_unlock(&ssl_session_mutex); return session; } pthread_mutex_unlock(&ssl_session_mutex); return NULL; } void cache_ssl_session(const char* hostname, SSL_SESSION* session) { if (!session) return; pthread_mutex_lock(&ssl_session_mutex); ssl_session_cache_entry* entry; HASH_FIND_STR(ssl_session_cache, hostname, entry); if (entry) { SSL_SESSION_free(entry->session); entry->session = session; entry->timestamp = time(NULL); } else { entry = malloc(sizeof(ssl_session_cache_entry)); if (entry) { strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); entry->hostname[MAX_URL_LENGTH-1] = '\0'; entry->session = session; entry->timestamp = time(NULL); HASH_ADD_STR(ssl_session_cache, hostname, entry); } } pthread_mutex_unlock(&ssl_session_mutex); } void cleanup_ssl_sessions() { pthread_mutex_lock(&ssl_session_mutex); ssl_session_cache_entry* entry, *tmp; time_t now = time(NULL); HASH_ITER(hh, ssl_session_cache, entry, tmp) { if (now - entry->timestamp > 600) { // 10 minute timeout SSL_SESSION_free(entry->session); HASH_DEL(ssl_session_cache, entry); free(entry); } } pthread_mutex_unlock(&ssl_session_mutex); } int get_cached_cert(const char* hostname, ssl_context_t** ctx) { pthread_mutex_lock(&cert_cache_mutex); cert_cache_entry* entry; HASH_FIND_STR(cert_memory_cache, hostname, entry); if (!entry) { pthread_mutex_unlock(&cert_cache_mutex); return 0; } // Check if certificate is still valid time_t now = time(NULL); if (now - entry->timestamp > CERT_RENEWAL_TIME) { pthread_mutex_unlock(&cert_cache_mutex); return 0; } // Create SSL context from memory *ctx = 
malloc(sizeof(ssl_context_t)); if (!*ctx) { pthread_mutex_unlock(&cert_cache_mutex); return 0; } (*ctx)->ctx = SSL_CTX_new(TLS_server_method()); if (!(*ctx)->ctx) { free(*ctx); *ctx = NULL; pthread_mutex_unlock(&cert_cache_mutex); return 0; } // Configure SSL context SSL_CTX_set_options((*ctx)->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION); // Load cert and key from memory BIO* cbio = BIO_new_mem_buf(entry->cert_data, entry->cert_len); X509* cert = PEM_read_bio_X509(cbio, NULL, NULL, NULL); BIO_free(cbio); BIO* kbio = BIO_new_mem_buf(entry->key_data, entry->key_len); EVP_PKEY* key = PEM_read_bio_PrivateKey(kbio, NULL, NULL, NULL); BIO_free(kbio); if (!cert || !key || SSL_CTX_use_certificate((*ctx)->ctx, cert) <= 0 || SSL_CTX_use_PrivateKey((*ctx)->ctx, key) <= 0) { if (cert) X509_free(cert); if (key) EVP_PKEY_free(key); SSL_CTX_free((*ctx)->ctx); free(*ctx); *ctx = NULL; pthread_mutex_unlock(&cert_cache_mutex); return 0; } X509_free(cert); EVP_PKEY_free(key); pthread_mutex_unlock(&cert_cache_mutex); return 1; } void cache_cert_in_memory(const char* hostname, const char* cert_path, const char* key_path) { // Read certificate file FILE* cert_file = fopen(cert_path, "rb"); if (!cert_file) return; fseek(cert_file, 0, SEEK_END); size_t cert_size = ftell(cert_file); fseek(cert_file, 0, SEEK_SET); unsigned char* cert_data = malloc(cert_size); if (!cert_data) { fclose(cert_file); return; } if (fread(cert_data, 1, cert_size, cert_file) != cert_size) { fclose(cert_file); free(cert_data); return; } fclose(cert_file); // Read key file FILE* key_file = fopen(key_path, "rb"); if (!key_file) { free(cert_data); return; } fseek(key_file, 0, SEEK_END); size_t key_size = ftell(key_file); fseek(key_file, 0, SEEK_SET); unsigned char* key_data = malloc(key_size); if (!key_data) { fclose(key_file); free(cert_data); return; } if (fread(key_data, 1, key_size, key_file) != key_size) { fclose(key_file); free(cert_data); free(key_data); return; } fclose(key_file); // Add to 
cache pthread_mutex_lock(&cert_cache_mutex); cert_cache_entry* entry; HASH_FIND_STR(cert_memory_cache, hostname, entry); if (entry) { // Update existing entry free(entry->cert_data); free(entry->key_data); entry->cert_data = cert_data; entry->cert_len = cert_size; entry->key_data = key_data; entry->key_len = key_size; entry->timestamp = time(NULL); } else { // Create new entry entry = malloc(sizeof(cert_cache_entry)); if (entry) { strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); entry->hostname[MAX_URL_LENGTH-1] = '\0'; entry->cert_data = cert_data; entry->cert_len = cert_size; entry->key_data = key_data; entry->key_len = key_size; entry->timestamp = time(NULL); HASH_ADD_STR(cert_memory_cache, hostname, entry); } else { free(cert_data); free(key_data); } } pthread_mutex_unlock(&cert_cache_mutex); }