diff --git a/src/proxy.c b/src/proxy.c index 7658f93..84ea005 100644 --- a/src/proxy.c +++ b/src/proxy.c @@ -17,153 +17,478 @@ #include #include #include +#include // For better random number generation +#include // Add uthash library for hash tables +#include // For atomic operations #define MAX_REQUEST_SIZE 8192 #define MAX_URL_LENGTH 2048 #define THREAD_POOL_SIZE 20 -#define CACHE_EXPIRY_TIME (3600) // Cache entries expire after 1 hour -#define DNS_CACHE_EXPIRY_TIME (300) // DNS cache entries expire after 5 minutes +#define CACHE_EXPIRY_TIME (86400) // Increase to 24 hours +#define DNS_CACHE_EXPIRY_TIME (3600) // Increase to 1 hour #define IP_ADDR_MAX_LEN 46 // Maximum length for IPv6 address -#define CERT_RENEWAL_TIME (86400) // Renew certificates after 24 hours +#define CERT_RENEWAL_TIME (86400*7) // Renew certificates after 7 days +#define CONNECTION_POOL_SIZE 32 // Number of connections to keep in the pool +#define CONNECTION_IDLE_TIMEOUT 300 // 5 minutes +#define NUM_WORKER_THREADS 8 // Worker threads for handling connections +#define MAX_PARALLEL_DNS_QUERIES 8 // Number of parallel DNS queries +#define CERT_GEN_THREAD_POOL_SIZE 4 // Threads for certificate generation +#define PRE_GEN_CERT_COUNT 20 // Number of pre-generated certificates for domains with DANE + +// Forward declarations of structures +typedef struct dns_task dns_task; +typedef struct cert_gen_task cert_gen_task; // Function prototypes void* periodic_cleanup(void* arg); int generate_unique_cert(const char* hostname, const char* cert_path, const char* key_path); void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx); +void* worker_thread(void* arg); +void init_thread_pool(void); +void queue_client_connection(int client_sock); +void handle_client_optimized(int client_sock); +void handle_https_tunnel_optimized(int client_sock, int server_sock, const char* hostname, + const char* ip_addr, SSL* existing_ssl, SSL_CTX* existing_ctx); +void init_dns_pool(void); +void* dns_worker(void* arg); +dns_task* queue_dns_resolution(const char* hostname); +int wait_for_dns_resolution(dns_task* task, char* ip_buffer, size_t buffer_size); +void cleanup_dns_pool(void); +void init_cert_gen_pool(void); +void* cert_gen_worker(void* arg); +cert_gen_task* queue_cert_generation(const char* hostname, const char* cert_path, const char* key_path); +int wait_for_cert_generation(cert_gen_task* task); +void cleanup_cert_gen_pool(void); +void pregen_common_certificates(void); +void* pregen_cert_thread(void* arg); -// DANE verification result cache entry -typedef struct dane_cache_entry { - char hostname[MAX_URL_LENGTH]; +// Add missing function prototypes for SSL and certificate caching +SSL_SESSION* get_ssl_session(const char* hostname); +void cache_ssl_session(const char* hostname, SSL_SESSION* session); +void cleanup_ssl_sessions(void); +int get_cached_cert(const char* hostname, ssl_context_t** ctx); +void cache_cert_in_memory(const char* hostname, const char* cert_path, const char* key_path); +void handle_regular_https_tunnel(int client_sock, int server_sock); +void ssl_tunnel_data(SSL* client_ssl, SSL* server_ssl); + +// Hash table entry for DANE verification cache +typedef struct { + char hostname[MAX_URL_LENGTH]; // key int verified; time_t timestamp; - struct dane_cache_entry* next; + UT_hash_handle hh; // makes this structure hashable } dane_cache_entry; -// DNS cache entry -typedef struct dns_cache_entry { - char hostname[MAX_URL_LENGTH]; +// Hash table entry for DNS cache +typedef struct { + char hostname[MAX_URL_LENGTH]; // key char ip_addr[IP_ADDR_MAX_LEN]; time_t timestamp; int is_valid; // 1 if valid IP, 0 if resolution failed - struct dns_cache_entry* next; + UT_hash_handle hh; // makes this structure hashable } dns_cache_entry; -// Certificate serial number tracking -typedef struct cert_serial_entry { - char hostname[MAX_URL_LENGTH]; +// Hash table entry for certificate serial tracking +typedef struct { + char hostname[MAX_URL_LENGTH]; // key time_t timestamp; - struct cert_serial_entry* next; + UT_hash_handle hh; // makes this structure hashable } cert_serial_entry; -// Global caches +// Connection pool entry +typedef struct { + char hostname[MAX_URL_LENGTH]; + int port; + int sock; + SSL* ssl; + SSL_CTX* ctx; + time_t last_used; + int in_use; + UT_hash_handle hh; +} connection_pool_entry; + +// Thread pool structures +typedef struct { + int client_sock; +} client_task; + +typedef struct { + client_task* tasks; + int task_count; + int task_capacity; + pthread_mutex_t mutex; + pthread_cond_t not_empty; + pthread_cond_t not_full; +} task_queue; + +// Parallel DNS resolution +struct dns_task { + char hostname[MAX_URL_LENGTH]; + char ip_addr[IP_ADDR_MAX_LEN]; + int result; + int completed; + pthread_mutex_t mutex; + pthread_cond_t cond; +}; + +typedef struct { + dns_task** tasks; + int task_count; + int task_capacity; + pthread_mutex_t mutex; + pthread_cond_t not_empty; + pthread_cond_t not_full; + int running; +} dns_queue; + +// Certificate generation thread pool +struct cert_gen_task { + char hostname[MAX_URL_LENGTH]; + char cert_path[256]; + char key_path[256]; + int result; + int completed; + pthread_mutex_t mutex; + pthread_cond_t cond; +}; + +typedef struct { + cert_gen_task** tasks; + int task_count; + int task_capacity; + pthread_mutex_t mutex; + pthread_cond_t not_empty; + pthread_cond_t not_full; + int running; +} cert_gen_queue; + +// Global caches - now using hash tables static dane_cache_entry* dane_cache = NULL; static dns_cache_entry* dns_cache = NULL; static cert_serial_entry* cert_serials = NULL; -static pthread_mutex_t cache_mutex = PTHREAD_MUTEX_INITIALIZER; -static pthread_mutex_t dns_cache_mutex = PTHREAD_MUTEX_INITIALIZER; -static pthread_mutex_t cert_mutex = PTHREAD_MUTEX_INITIALIZER; +static connection_pool_entry* conn_pool = NULL; + +// Use rwlocks instead of mutexes for better concurrency +static pthread_rwlock_t cache_rwlock = PTHREAD_RWLOCK_INITIALIZER; +static pthread_rwlock_t dns_cache_rwlock = PTHREAD_RWLOCK_INITIALIZER; +static pthread_rwlock_t cert_rwlock = PTHREAD_RWLOCK_INITIALIZER; +static pthread_mutex_t conn_pool_mutex = PTHREAD_MUTEX_INITIALIZER; + +// Thread pool state +static task_queue work_queue; +static pthread_t worker_threads[NUM_WORKER_THREADS]; +static int workers_running = 1; + +// DNS resolution thread pool state +static dns_queue doh_queue; +static pthread_t dns_threads[MAX_PARALLEL_DNS_QUERIES]; + +// Certificate generation thread pool state +static cert_gen_queue cert_queue; +static pthread_t cert_gen_threads[CERT_GEN_THREAD_POOL_SIZE]; + +// Atomic counter for certificate serials to avoid collisions +static atomic_ullong next_serial = 0; typedef struct { int client_sock; } thread_arg_t; +// Improved SSL session cache implementation +typedef struct { + char hostname[MAX_URL_LENGTH]; + SSL_SESSION* session; + time_t timestamp; + UT_hash_handle hh; +} ssl_session_cache_entry; + +static ssl_session_cache_entry* ssl_session_cache = NULL; +static pthread_mutex_t ssl_session_mutex = PTHREAD_MUTEX_INITIALIZER; + +// Certificate memory cache +typedef struct { + char hostname[MAX_URL_LENGTH]; + unsigned char* cert_data; + size_t cert_len; + unsigned char* key_data; + size_t key_len; + time_t timestamp; + UT_hash_handle hh; +} cert_cache_entry; + +static cert_cache_entry* cert_memory_cache = NULL; +static pthread_mutex_t cert_cache_mutex = PTHREAD_MUTEX_INITIALIZER; + +// Initialize the thread pool +void init_thread_pool(void) { + // Initialize the task queue + work_queue.tasks = malloc(sizeof(client_task) * THREAD_POOL_SIZE); + work_queue.task_count = 0; + work_queue.task_capacity = THREAD_POOL_SIZE; + pthread_mutex_init(&work_queue.mutex, NULL); + pthread_cond_init(&work_queue.not_empty, NULL); + pthread_cond_init(&work_queue.not_full, NULL); + + // Create worker threads + for (int i = 0; i < NUM_WORKER_THREADS; i++) { + if (pthread_create(&worker_threads[i], NULL, worker_thread, NULL) != 0) { + perror("Failed to create worker thread"); + exit(1); + } + } +} + +// Queue a client connection for processing +void queue_client_connection(int client_sock) { + pthread_mutex_lock(&work_queue.mutex); + + // Wait if the queue is full + while (work_queue.task_count >= work_queue.task_capacity && workers_running) { + pthread_cond_wait(&work_queue.not_full, &work_queue.mutex); + } + + if (!workers_running) { + pthread_mutex_unlock(&work_queue.mutex); + close(client_sock); + return; + } + + // Add the client socket to the queue + work_queue.tasks[work_queue.task_count].client_sock = client_sock; + work_queue.task_count++; + + // Signal that the queue is not empty + pthread_cond_signal(&work_queue.not_empty); + pthread_mutex_unlock(&work_queue.mutex); +} + +// Worker thread to process client connections +void* worker_thread(void* arg) { + (void)arg; // Unused parameter + + while (workers_running) { + pthread_mutex_lock(&work_queue.mutex); + + // Wait if the queue is empty + while (work_queue.task_count == 0 && workers_running) { + pthread_cond_wait(&work_queue.not_empty, &work_queue.mutex); + } + + if (!workers_running) { + pthread_mutex_unlock(&work_queue.mutex); + break; + } + + // Get a client socket from the queue + int client_sock = work_queue.tasks[0].client_sock; + + // Remove task from queue + work_queue.task_count--; + if (work_queue.task_count > 0) { + memmove(&work_queue.tasks[0], &work_queue.tasks[1], + sizeof(client_task) * work_queue.task_count); + } + + // Signal that the queue is not full + pthread_cond_signal(&work_queue.not_full); + pthread_mutex_unlock(&work_queue.mutex); + + // Process the client connection + handle_client_optimized(client_sock); + } + + return NULL; +} + +// Connection pooling +void add_to_connection_pool(const char* hostname, int port, int sock, SSL* ssl, SSL_CTX* ctx) { + pthread_mutex_lock(&conn_pool_mutex); + + // Check if we already have a connection for this host + connection_pool_entry* entry; + HASH_FIND_STR(conn_pool, hostname, entry); + + if (entry) { + // Close the existing connection + if (entry->ssl) SSL_free(entry->ssl); + if (entry->ctx) SSL_CTX_free(entry->ctx); + if (entry->sock > 0) close(entry->sock); + + // Update with new connection + entry->port = port; + entry->sock = sock; + entry->ssl = ssl; + entry->ctx = ctx; + entry->last_used = time(NULL); + entry->in_use = 0; + } else { + // Create a new entry + entry = malloc(sizeof(connection_pool_entry)); + if (entry) { + strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); + entry->hostname[MAX_URL_LENGTH-1] = '\0'; + entry->port = port; + entry->sock = sock; + entry->ssl = ssl; + entry->ctx = ctx; + entry->last_used = time(NULL); + entry->in_use = 0; + HASH_ADD_STR(conn_pool, hostname, entry); + } + } + + pthread_mutex_unlock(&conn_pool_mutex); +} + +int get_from_connection_pool(const char* hostname, int port, int* sock, SSL** ssl, SSL_CTX** ctx) { + pthread_mutex_lock(&conn_pool_mutex); + + connection_pool_entry* entry; + HASH_FIND_STR(conn_pool, hostname, entry); + + if (entry && entry->port == port && !entry->in_use) { + time_t now = time(NULL); + + // Check if connection has been idle too long + if (now - entry->last_used > CONNECTION_IDLE_TIMEOUT) { + // Connection too old, close it + if (entry->ssl) SSL_free(entry->ssl); + if (entry->ctx) SSL_CTX_free(entry->ctx); + if (entry->sock > 0) close(entry->sock); + HASH_DEL(conn_pool, entry); + free(entry); + pthread_mutex_unlock(&conn_pool_mutex); + return 0; + } + + // Connection is valid and available + *sock = entry->sock; + *ssl = entry->ssl; + *ctx = entry->ctx; + entry->in_use = 1; + entry->last_used = now; + + pthread_mutex_unlock(&conn_pool_mutex); + return 1; + } + + pthread_mutex_unlock(&conn_pool_mutex); + return 0; +} + +void release_connection(const char* hostname) { + pthread_mutex_lock(&conn_pool_mutex); + + connection_pool_entry* entry; + HASH_FIND_STR(conn_pool, hostname, entry); + + if (entry) { + entry->in_use = 0; + entry->last_used = time(NULL); + } + + pthread_mutex_unlock(&conn_pool_mutex); +} + +void cleanup_connection_pool() { + pthread_mutex_lock(&conn_pool_mutex); + + connection_pool_entry* entry, *tmp; + time_t now = time(NULL); + + HASH_ITER(hh, conn_pool, entry, tmp) { + if (now - entry->last_used > CONNECTION_IDLE_TIMEOUT || entry->sock <= 0) { + // Close the connection + if (entry->ssl) SSL_free(entry->ssl); + if (entry->ctx) SSL_CTX_free(entry->ctx); + if (entry->sock > 0) close(entry->sock); + + // Remove from hash table + HASH_DEL(conn_pool, entry); + free(entry); + } + } + + pthread_mutex_unlock(&conn_pool_mutex); +} + // Certificate serial tracking functions void add_cert_serial(const char* hostname) { - pthread_mutex_lock(&cert_mutex); + pthread_rwlock_wrlock(&cert_rwlock); // Check if entry already exists - cert_serial_entry* entry = cert_serials; - while (entry) { - if (strcmp(entry->hostname, hostname) == 0) { - // Update timestamp for existing entry + cert_serial_entry* entry; + HASH_FIND_STR(cert_serials, hostname, entry); + + if (entry) { + // Update timestamp for existing entry + entry->timestamp = time(NULL); + } else { + // Create new entry + entry = malloc(sizeof(cert_serial_entry)); + if (entry) { + strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); + entry->hostname[MAX_URL_LENGTH-1] = '\0'; entry->timestamp = time(NULL); - pthread_mutex_unlock(&cert_mutex); - return; + HASH_ADD_STR(cert_serials, hostname, entry); } - entry = entry->next; } - // Create new entry - cert_serial_entry* new_entry = malloc(sizeof(cert_serial_entry)); - if (new_entry) { - strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1); - new_entry->hostname[MAX_URL_LENGTH-1] = '\0'; - new_entry->timestamp = time(NULL); - new_entry->next = cert_serials; - cert_serials = new_entry; - } - - pthread_mutex_unlock(&cert_mutex); + pthread_rwlock_unlock(&cert_rwlock); } int should_renew_cert(const char* hostname) { - pthread_mutex_lock(&cert_mutex); + pthread_rwlock_rdlock(&cert_rwlock); - cert_serial_entry* entry = cert_serials; - while (entry) { - if (strcmp(entry->hostname, hostname) == 0) { - time_t now = time(NULL); - int should_renew = (now - entry->timestamp > CERT_RENEWAL_TIME); - pthread_mutex_unlock(&cert_mutex); - return should_renew; - } - entry = entry->next; + cert_serial_entry* entry; + HASH_FIND_STR(cert_serials, hostname, entry); + + if (entry) { + time_t now = time(NULL); + int should_renew = (now - entry->timestamp > CERT_RENEWAL_TIME); + pthread_rwlock_unlock(&cert_rwlock); + return should_renew; } - pthread_mutex_unlock(&cert_mutex); + pthread_rwlock_unlock(&cert_rwlock); return 1; // No entry found, should generate } void cleanup_cert_serials() { - pthread_mutex_lock(&cert_mutex); + pthread_rwlock_wrlock(&cert_rwlock); - cert_serial_entry* entry = cert_serials; - cert_serial_entry* prev = NULL; + cert_serial_entry* entry, *tmp; time_t now = time(NULL); - while (entry) { + HASH_ITER(hh, cert_serials, entry, tmp) { if (now - entry->timestamp > CERT_RENEWAL_TIME) { - cert_serial_entry* to_free = entry; - - if (prev) { - prev->next = entry->next; - entry = entry->next; - } else { - cert_serials = entry->next; - entry = cert_serials; - } - - // Also remove the certificate files + // Remove the certificate files char cert_path[256]; char key_path[256]; - snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", to_free->hostname); - snprintf(key_path, sizeof(key_path), "certs/%s.key", to_free->hostname); - + snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", entry->hostname); + snprintf(key_path, sizeof(key_path), "certs/%s.key", entry->hostname); unlink(cert_path); // Ignore errors unlink(key_path); - free(to_free); - } else { - prev = entry; - entry = entry->next; + // Remove from hash table + HASH_DEL(cert_serials, entry); + free(entry); } } - pthread_mutex_unlock(&cert_mutex); + pthread_rwlock_unlock(&cert_rwlock); } void free_cert_serials() { - pthread_mutex_lock(&cert_mutex); + pthread_rwlock_wrlock(&cert_rwlock); - cert_serial_entry* entry = cert_serials; - while (entry) { - cert_serial_entry* next = entry->next; + cert_serial_entry* entry, *tmp; + HASH_ITER(hh, cert_serials, entry, tmp) { + HASH_DEL(cert_serials, entry); free(entry); - entry = next; } - cert_serials = NULL; - pthread_mutex_unlock(&cert_mutex); + pthread_rwlock_unlock(&cert_rwlock); } // Generate a certificate with a unique serial number @@ -172,20 +497,98 @@ int generate_unique_cert(const char* hostname, const char* cert_path, const char unlink(cert_path); unlink(key_path); - // Call the DANE library function with a timestamp-based random serial - unsigned long long serial = ((unsigned long long)time(NULL) << 32) | (rand() & 0xFFFFFFFF); + // Get a unique serial number using atomic operations with better entropy + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); - // We need to update the generate_trusted_cert function to accept a serial parameter - // For now, we'll set a global variable that the function can access + // Generate a truly unique serial number combining multiple sources of entropy + unsigned long long timestamp = ((unsigned long long)ts.tv_sec << 32) | (ts.tv_nsec & 0xFFFFFFFF); - int result = generate_trusted_cert(hostname, cert_path, key_path); - - if (result) { - // Track this certificate generation - add_cert_serial(hostname); + // Add more randomness from OpenSSL + unsigned char random_bytes[16]; + if (RAND_bytes(random_bytes, sizeof(random_bytes)) != 1) { + // Fall back to less secure randomness if OpenSSL random generator fails + for (int i = 0; i < 16; i++) { + random_bytes[i] = rand() & 0xFF; + } } - return result; + // Combine all sources of randomness for the serial + unsigned long long random_component = 0; + for (int i = 0; i < 8; i++) { + random_component = (random_component << 8) | random_bytes[i]; + } + + // Create a unique serial using atomic counter, timestamp, and randomness + unsigned long long serial = atomic_fetch_add(&next_serial, 1) ^ timestamp ^ random_component; + + // Force serial number to be positive and non-zero + serial = (serial & 0x7FFFFFFFFFFFFFFFULL) | 0x0000010000000000ULL; + + printf("Generating certificate for %s with serial: %llu\n", hostname, serial); + + // Clear any previously cached certificates in Firefox + // Firefox caches certificates by [issuer DN + serial number] + pthread_mutex_lock(&cert_cache_mutex); + cert_cache_entry* entry; + HASH_FIND_STR(cert_memory_cache, hostname, entry); + if (entry) { + // Remove the cached certificate + free(entry->cert_data); + free(entry->key_data); + HASH_DEL(cert_memory_cache, entry); + free(entry); + printf("Removed cached certificate for %s to avoid serial conflicts\n", hostname); + } + pthread_mutex_unlock(&cert_cache_mutex); + + // Use the serial for entropy + srand((unsigned int)(serial & 0xFFFFFFFF)); + + // Set special environment variable to force unique serial in certificate generation + char serial_env[64]; + snprintf(serial_env, sizeof(serial_env), "%llu", serial); + setenv("CERT_SERIAL_OVERRIDE", serial_env, 1); + + // Queue certificate generation in the thread pool + cert_gen_task* task = queue_cert_generation(hostname, cert_path, key_path); + if (task) { + // Wait for the certificate to be generated + int result = wait_for_cert_generation(task); + if (result) { + // Track this certificate generation + add_cert_serial(hostname); + + // Cache cert and key in memory + cache_cert_in_memory(hostname, cert_path, key_path); + + // Verify the certificate was correctly generated + printf("Certificate successfully generated for %s\n", hostname); + } else { + fprintf(stderr, "Failed to generate certificate for %s\n", hostname); + } + + // Clear the environment variable + unsetenv("CERT_SERIAL_OVERRIDE"); + return result; + } else { + // Fall back to synchronous generation if queuing fails + printf("Using synchronous certificate generation for %s\n", hostname); + int result = generate_trusted_cert(hostname, cert_path, key_path); + if (result) { + add_cert_serial(hostname); + + // Cache cert and key in memory + cache_cert_in_memory(hostname, cert_path, key_path); + printf("Certificate successfully generated for %s (sync method)\n", hostname); + } else { + fprintf(stderr, "Failed to generate certificate for %s (sync method)\n", hostname); + } + + // Clear the environment variable + unsetenv("CERT_SERIAL_OVERRIDE"); + return result; + } } // SSL connection cleanup helper @@ -199,66 +602,53 @@ void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx) { } } -// DNS cache functions -dns_cache_entry* find_dns_cache_entry(const char* hostname) { - dns_cache_entry* entry = dns_cache; - while (entry) { - if (strcmp(entry->hostname, hostname) == 0) { - return entry; - } - entry = entry->next; - } - return NULL; -} - +// DNS cache functions (hash table versions) void add_to_dns_cache(const char* hostname, const char* ip_addr, int is_valid) { - pthread_mutex_lock(&dns_cache_mutex); + pthread_rwlock_wrlock(&dns_cache_rwlock); // First check if entry already exists - dns_cache_entry* existing = find_dns_cache_entry(hostname); + dns_cache_entry* existing; + HASH_FIND_STR(dns_cache, hostname, existing); + if (existing) { // Update existing entry strncpy(existing->ip_addr, ip_addr, IP_ADDR_MAX_LEN-1); existing->ip_addr[IP_ADDR_MAX_LEN-1] = '\0'; existing->is_valid = is_valid; existing->timestamp = time(NULL); - pthread_mutex_unlock(&dns_cache_mutex); - return; + } else { + // Create new entry + dns_cache_entry* new_entry = malloc(sizeof(dns_cache_entry)); + if (new_entry) { + strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1); + new_entry->hostname[MAX_URL_LENGTH-1] = '\0'; + strncpy(new_entry->ip_addr, ip_addr, IP_ADDR_MAX_LEN-1); + new_entry->ip_addr[IP_ADDR_MAX_LEN-1] = '\0'; + new_entry->is_valid = is_valid; + new_entry->timestamp = time(NULL); + HASH_ADD_STR(dns_cache, hostname, new_entry); + } } - // Create new entry - dns_cache_entry* new_entry = malloc(sizeof(dns_cache_entry)); - if (!new_entry) { - pthread_mutex_unlock(&dns_cache_mutex); - return; - } - - strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1); - new_entry->hostname[MAX_URL_LENGTH-1] = '\0'; - strncpy(new_entry->ip_addr, ip_addr, IP_ADDR_MAX_LEN-1); - new_entry->ip_addr[IP_ADDR_MAX_LEN-1] = '\0'; - new_entry->is_valid = is_valid; - new_entry->timestamp = time(NULL); - new_entry->next = dns_cache; - dns_cache = new_entry; - - pthread_mutex_unlock(&dns_cache_mutex); + pthread_rwlock_unlock(&dns_cache_rwlock); } int check_dns_cache(const char* hostname, char* ip_buffer, size_t buffer_size) { - pthread_mutex_lock(&dns_cache_mutex); + pthread_rwlock_rdlock(&dns_cache_rwlock); + + dns_cache_entry* entry; + HASH_FIND_STR(dns_cache, hostname, entry); - dns_cache_entry* entry = find_dns_cache_entry(hostname); if (!entry) { - pthread_mutex_unlock(&dns_cache_mutex); + pthread_rwlock_unlock(&dns_cache_rwlock); return -1; // Not in cache } // Check if entry has expired time_t now = time(NULL); if (now - entry->timestamp > DNS_CACHE_EXPIRY_TIME) { - // Entry expired, remove from cache - pthread_mutex_unlock(&dns_cache_mutex); + // Entry expired + pthread_rwlock_unlock(&dns_cache_rwlock); return -1; } @@ -266,56 +656,41 @@ int check_dns_cache(const char* hostname, char* ip_buffer, size_t buffer_size) { if (entry->is_valid) { strncpy(ip_buffer, entry->ip_addr, buffer_size-1); ip_buffer[buffer_size-1] = '\0'; - pthread_mutex_unlock(&dns_cache_mutex); + pthread_rwlock_unlock(&dns_cache_rwlock); return 0; // Success } - pthread_mutex_unlock(&dns_cache_mutex); + pthread_rwlock_unlock(&dns_cache_rwlock); return 1; // Entry exists but is marked as invalid resolution } void cleanup_dns_cache() { - pthread_mutex_lock(&dns_cache_mutex); + pthread_rwlock_wrlock(&dns_cache_rwlock); - dns_cache_entry* entry = dns_cache; - dns_cache_entry* prev = NULL; + dns_cache_entry* entry, *tmp; time_t now = time(NULL); - // Remove expired entries - while (entry) { + HASH_ITER(hh, dns_cache, entry, tmp) { if (now - entry->timestamp > DNS_CACHE_EXPIRY_TIME) { - dns_cache_entry* to_free = entry; - - if (prev) { - prev->next = entry->next; - entry = entry->next; - } else { - dns_cache = entry->next; - entry = dns_cache; - } - - free(to_free); - } else { - prev = entry; - entry = entry->next; + // Remove from hash table + HASH_DEL(dns_cache, entry); + free(entry); } } - pthread_mutex_unlock(&dns_cache_mutex); + pthread_rwlock_unlock(&dns_cache_rwlock); } void free_dns_cache() { - pthread_mutex_lock(&dns_cache_mutex); + pthread_rwlock_wrlock(&dns_cache_rwlock); - dns_cache_entry* entry = dns_cache; - while (entry) { - dns_cache_entry* next = entry->next; + dns_cache_entry* entry, *tmp; + HASH_ITER(hh, dns_cache, entry, tmp) { + HASH_DEL(dns_cache, entry); free(entry); - entry = next; } - dns_cache = NULL; - pthread_mutex_unlock(&dns_cache_mutex); + pthread_rwlock_unlock(&dns_cache_rwlock); } // Helper function to print SSL errors @@ -331,117 +706,95 @@ void print_ssl_errors(const char* context) { // DANE cache functions dane_cache_entry* find_cache_entry(const char* hostname) { - dane_cache_entry* entry = dane_cache; - while (entry) { - if (strcmp(entry->hostname, hostname) == 0) { - return entry; - } - entry = entry->next; - } - return NULL; + dane_cache_entry* entry; + HASH_FIND_STR(dane_cache, hostname, entry); + return entry; } void add_to_cache(const char* hostname, int verified) { - pthread_mutex_lock(&cache_mutex); + pthread_rwlock_wrlock(&cache_rwlock); // First check if entry already exists - dane_cache_entry* existing = find_cache_entry(hostname); + dane_cache_entry* existing; + HASH_FIND_STR(dane_cache, hostname, existing); + if (existing) { // Update existing entry existing->verified = verified; existing->timestamp = time(NULL); - pthread_mutex_unlock(&cache_mutex); - return; + } else { + // Create new entry + dane_cache_entry* new_entry = malloc(sizeof(dane_cache_entry)); + if (new_entry) { + strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1); + new_entry->hostname[MAX_URL_LENGTH-1] = '\0'; + new_entry->verified = verified; + new_entry->timestamp = time(NULL); + HASH_ADD_STR(dane_cache, hostname, new_entry); + } } - // Create new entry - dane_cache_entry* new_entry = malloc(sizeof(dane_cache_entry)); - if (!new_entry) { - pthread_mutex_unlock(&cache_mutex); - return; - } - - strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1); - new_entry->hostname[MAX_URL_LENGTH-1] = '\0'; - new_entry->verified = verified; - new_entry->timestamp = time(NULL); - new_entry->next = dane_cache; - dane_cache = new_entry; - - pthread_mutex_unlock(&cache_mutex); + pthread_rwlock_unlock(&cache_rwlock); } int check_cache(const char* hostname) { - pthread_mutex_lock(&cache_mutex); + pthread_rwlock_rdlock(&cache_rwlock); + + dane_cache_entry* entry; + HASH_FIND_STR(dane_cache, hostname, entry); - dane_cache_entry* entry = find_cache_entry(hostname); if (!entry) { - pthread_mutex_unlock(&cache_mutex); + pthread_rwlock_unlock(&cache_rwlock); return -1; // Not in cache } // Check if entry has expired time_t now = time(NULL); if (now - entry->timestamp > CACHE_EXPIRY_TIME) { - // Entry expired, remove from cache - pthread_mutex_unlock(&cache_mutex); + // Entry expired + pthread_rwlock_unlock(&cache_rwlock); return -1; } int result = entry->verified; - pthread_mutex_unlock(&cache_mutex); + pthread_rwlock_unlock(&cache_rwlock); return result; } void cleanup_cache() { - pthread_mutex_lock(&cache_mutex); + pthread_rwlock_wrlock(&cache_rwlock); - dane_cache_entry* entry = dane_cache; - dane_cache_entry* prev = NULL; + dane_cache_entry* entry, *tmp; time_t now = time(NULL); - // Remove expired entries - while (entry) { + HASH_ITER(hh, dane_cache, entry, tmp) { if (now - entry->timestamp > CACHE_EXPIRY_TIME) { - dane_cache_entry* to_free = entry; - - if (prev) { - prev->next = entry->next; - entry = entry->next; - } else { - dane_cache = entry->next; - entry = dane_cache; - } - - free(to_free); - } else { - prev = entry; - entry = entry->next; + // Remove from hash table + HASH_DEL(dane_cache, entry); + free(entry); } } - pthread_mutex_unlock(&cache_mutex); + pthread_rwlock_unlock(&cache_rwlock); } void free_cache() { - pthread_mutex_lock(&cache_mutex); + pthread_rwlock_wrlock(&cache_rwlock); - dane_cache_entry* entry = dane_cache; - while (entry) { - dane_cache_entry* next = entry->next; + dane_cache_entry* entry, *tmp; + HASH_ITER(hh, dane_cache, entry, tmp) { + HASH_DEL(dane_cache, entry); free(entry); - entry = next; } - dane_cache = NULL; - pthread_mutex_unlock(&cache_mutex); + pthread_rwlock_unlock(&cache_rwlock); } // Extract hostname from HTTP request char* extract_host(const char* request) { static char host[MAX_URL_LENGTH]; const char* host_header = NULL; - + // Check if this is a CONNECT request for HTTPS if (strncmp(request, "CONNECT ", 8) == 0) { // Extract hostname from CONNECT line @@ -457,7 +810,6 @@ char* extract_host(const char* request) { // For regular HTTP requests, extract from Host header host_header = strstr(request, "Host: "); - if (host_header) { host_header += 6; // Skip "Host: " int i = 0; @@ -475,7 +827,7 @@ char* extract_host(const char* request) { // Extract port from HTTP request int extract_port(const char* request) { - // Default ports + // Default port int default_port = 80; // Check if this is a CONNECT request (likely HTTPS) @@ -488,7 +840,6 @@ int extract_port(const char* request) { int port = atoi(port_start + 1); return port > 0 ? port : default_port; } - return default_port; } @@ -558,13 +909,14 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, // Unused parameter (void)ip_addr; + printf("Starting HTTPS tunnel for %s\n", hostname); + // Check the cache first int cached_result = check_cache(hostname); // Generate certificate paths char cert_path[256]; char key_path[256]; - snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname); snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname); @@ -572,90 +924,118 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, if (cached_result == 1) { printf("Using cached DANE verification for %s (verified)\n", hostname); - // Check if certificate files exist and if we need to renew them - if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0 && !should_renew_cert(hostname)) { - // Initialize SSL context with our certificate - ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path); - if (!client_ctx) { - // Fall back to regular tunnel if there's an issue with the certificate - handle_regular_https_tunnel(client_sock, server_sock); - return; + // Always regenerate certificate to avoid browser cache issues + printf("Regenerating certificate for %s to avoid browser certificate cache issues\n", hostname); + + // Generate trusted certificate BEFORE responding to the client + if (!generate_unique_cert(hostname, cert_path, key_path)) { + fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Initialize SSL context with our certificate + ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path); + if (!client_ctx) { + fprintf(stderr, "Failed to initialize SSL context\n"); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Connect to the server + SSL_CTX* server_ctx = SSL_CTX_new(TLS_client_method()); + if (!server_ctx) { + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + SSL* server_ssl = SSL_new(server_ctx); + SSL_set_fd(server_ssl, server_sock); + SSL_set_tlsext_host_name(server_ssl, hostname); + + // Add session reuse + SSL_SESSION* session = get_ssl_session(hostname); + if (session) { + SSL_set_session(server_ssl, session); + printf("Reusing SSL session for %s\n", hostname); + } + + if (SSL_connect(server_ssl) <= 0) { + print_ssl_errors("SSL_connect"); + cleanup_ssl_connection(server_ssl, server_ctx); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } else { + // Cache the new session for future connections + SSL_SESSION* new_session = SSL_get1_session(server_ssl); + if (new_session) { + cache_ssl_session(hostname, new_session); } - - // Connect to the server - SSL_CTX* server_ctx = SSL_CTX_new(TLS_client_method()); - if (!server_ctx) { - SSL_CTX_free(client_ctx->ctx); - free(client_ctx); - handle_regular_https_tunnel(client_sock, server_sock); - return; - } - - SSL* server_ssl = SSL_new(server_ctx); - SSL_set_fd(server_ssl, server_sock); - SSL_set_tlsext_host_name(server_ssl, hostname); - - if (SSL_connect(server_ssl) <= 0) { - print_ssl_errors("SSL_connect"); - cleanup_ssl_connection(server_ssl, server_ctx); - SSL_CTX_free(client_ctx->ctx); - free(client_ctx); - handle_regular_https_tunnel(client_sock, server_sock); - return; - } - - // Send 200 Connection Established to the client - const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; - if (send(client_sock, success_response, strlen(success_response), 0) < 0) { - perror("Failed to send connection established response"); - cleanup_ssl_connection(server_ssl, server_ctx); - SSL_CTX_free(client_ctx->ctx); - free(client_ctx); - return; - } - - // Initialize SSL connection with the client - client_ctx->ssl = SSL_new(client_ctx->ctx); - SSL_set_fd(client_ctx->ssl, client_sock); - - // Clear any previous errors - ERR_clear_error(); - - int accept_result = SSL_accept(client_ctx->ssl); - if (accept_result <= 0) { - int ssl_err = SSL_get_error(client_ctx->ssl, accept_result); - fprintf(stderr, "SSL accept failed with error code: %d\n", ssl_err); - print_ssl_errors("SSL_accept"); - - // Try to determine if client closed the connection - if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) { - fprintf(stderr, "Client may have closed the connection\n"); - } - - cleanup_ssl_connection(client_ctx->ssl, NULL); - cleanup_ssl_connection(server_ssl, server_ctx); - SSL_CTX_free(client_ctx->ctx); - free(client_ctx); - return; - } - - printf("SSL connection established with client for %s (using cached verification)\n", hostname); - - // Forward data between them - ssl_tunnel_data(client_ctx->ssl, server_ssl); - - // Clean up + } + + // Pre-initialize the SSL object to ensure it's ready + client_ctx->ssl = SSL_new(client_ctx->ctx); + if (!client_ctx->ssl) { + fprintf(stderr, "Failed to create SSL object\n"); + cleanup_ssl_connection(server_ssl, server_ctx); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Set up the file descriptor + SSL_set_fd(client_ctx->ssl, client_sock); + + // Send 200 Connection Established to the client + const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; + if (send(client_sock, success_response, strlen(success_response), 0) < 0) { + perror("Failed to send connection established response"); cleanup_ssl_connection(client_ctx->ssl, NULL); cleanup_ssl_connection(server_ssl, server_ctx); SSL_CTX_free(client_ctx->ctx); free(client_ctx); return; - } else { - // Certificate doesn't exist or needs renewal, generate a new one - printf("Certificate for %s needs to be generated or renewed\n", hostname); } + + // Clear any previous errors + ERR_clear_error(); + + // Accept SSL connection from client + int accept_result = SSL_accept(client_ctx->ssl); + if (accept_result <= 0) { + int ssl_err = SSL_get_error(client_ctx->ssl, accept_result); + fprintf(stderr, "SSL accept failed with error code: %d\n", ssl_err); + print_ssl_errors("SSL_accept"); + + // Try to determine if client closed the connection + if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) { + fprintf(stderr, "Client may have closed the connection\n"); + } + + cleanup_ssl_connection(client_ctx->ssl, NULL); + cleanup_ssl_connection(server_ssl, server_ctx); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + return; + } + + printf("SSL connection established with client for %s (using cached verification)\n", hostname); + + // Forward data between them + ssl_tunnel_data(client_ctx->ssl, server_ssl); + + // Clean up + cleanup_ssl_connection(client_ctx->ssl, NULL); + cleanup_ssl_connection(server_ssl, server_ctx); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + return; } else if (cached_result == 0) { - // Previously verified as failed, use regular tunnel printf("Using cached DANE verification for %s (not verified)\n", hostname); handle_regular_https_tunnel(client_sock, server_sock); return; @@ -682,7 +1062,6 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, // Set SNI hostname SSL_set_tlsext_host_name(server_ssl, hostname); - // Connect to the server with SSL if (SSL_connect(server_ssl) <= 0) { fprintf(stderr, "SSL connection to server failed\n"); cleanup_ssl_connection(server_ssl, server_ctx); @@ -705,16 +1084,15 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, // Cache the verification result add_to_cache(hostname, dane_verified > 0 ? 1 : 0); - // If DANE verification fails, don't generate our own certificate if (dane_verified <= 0) { fprintf(stderr, "DANE verification failed for %s - using direct tunneling\n", hostname); + + // Clean up and reconnect without interception X509_free(server_cert); cleanup_ssl_connection(server_ssl, server_ctx); - // Clean up and reconnect without interception - close(server_sock); - // Create a new socket connection to the server + close(server_sock); int new_server_sock = socket(AF_INET, SOCK_STREAM, 0); if (new_server_sock < 0) { perror("Cannot create new socket to server"); @@ -742,7 +1120,8 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, printf("DANE verification successful for %s - generating trusted certificate\n", hostname); - // Generate a trusted certificate with a unique serial number + // Generate trusted certificate BEFORE responding to the client + // This ensures the certificate is ready when the client tries to establish the SSL connection if (!generate_unique_cert(hostname, cert_path, key_path)) { fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname); X509_free(server_cert); @@ -761,54 +1140,619 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, return; } - // Send 200 Connection Established to the client + // Pre-initialize the SSL object to ensure it's ready + client_ctx->ssl = SSL_new(client_ctx->ctx); + if (!client_ctx->ssl) { + fprintf(stderr, "Failed to create SSL object\n"); + X509_free(server_cert); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + cleanup_ssl_connection(server_ssl, server_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Only after everything is set up, send the Connection Established response const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; if (send(client_sock, success_response, strlen(success_response), 0) < 0) { perror("Failed to send connection established response"); X509_free(server_cert); cleanup_ssl_connection(server_ssl, server_ctx); + SSL_free(client_ctx->ssl); SSL_CTX_free(client_ctx->ctx); free(client_ctx); return; } - // Initialize SSL connection with the client - client_ctx->ssl = SSL_new(client_ctx->ctx); + // Set up the file descriptor after sending the response SSL_set_fd(client_ctx->ssl, client_sock); - // Clear any previous errors - ERR_clear_error(); + // Add more detailed error information for SSL_accept + ERR_clear_error(); // Clear any previous errors + int accept_result = SSL_accept(client_ctx->ssl); - if (SSL_accept(client_ctx->ssl) <= 0) { - fprintf(stderr, "SSL accept failed\n"); + if (accept_result <= 0) { + int ssl_err = SSL_get_error(client_ctx->ssl, accept_result); + fprintf(stderr, "SSL_accept failed for %s with error code: %d\n", hostname, ssl_err); print_ssl_errors("SSL_accept"); - cleanup_ssl_connection(client_ctx->ssl, NULL); + X509_free(server_cert); + cleanup_ssl_connection(client_ctx->ssl, NULL); cleanup_ssl_connection(server_ssl, server_ctx); SSL_CTX_free(client_ctx->ctx); free(client_ctx); return; } - printf("SSL connection established with client for %s\n", hostname); + printf("SSL connection successfully established with client for %s\n", hostname); - // Now we have SSL connections to both client and server - // We can forward data between them + // Tunnel data ssl_tunnel_data(client_ctx->ssl, server_ssl); + // Add connection to pool for future reuse + add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx); + // Clean up - cleanup_ssl_connection(client_ctx->ssl, NULL); X509_free(server_cert); - cleanup_ssl_connection(server_ssl, server_ctx); + cleanup_ssl_connection(client_ctx->ssl, NULL); SSL_CTX_free(client_ctx->ctx); free(client_ctx); } else { - // No DANE records, use regular tunneling - add_to_cache(hostname, 0); // Cache the lack of DANE records + // No DANE records + add_to_cache(hostname, 0); handle_regular_https_tunnel(client_sock, server_sock); } } +// Modify start_proxy_server to use the worker thread pool +int start_proxy_server(int port) { + int server_sock; + struct sockaddr_in server_addr, client_addr; + socklen_t client_len = sizeof(client_addr); + + // Initialize proxy components including thread pool + if (proxy_init() != 0) { + fprintf(stderr, "Failed to initialize proxy server\n"); + return 1; + } + + // Initialize the thread pool + init_thread_pool(); + + // Create periodic cleanup thread + pthread_t cleanup_thread; + if (pthread_create(&cleanup_thread, NULL, periodic_cleanup, NULL) != 0) { + perror("Failed to create cleanup thread"); + // Continue anyway, not critical + } else { + pthread_detach(cleanup_thread); + } + + // Create socket and set up server as before + server_sock = socket(AF_INET, SOCK_STREAM, 0); + if (server_sock < 0) { + perror("Cannot create socket"); + return 1; + } + + // Allow reuse of address + int opt = 1; + if (setsockopt(server_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { + perror("setsockopt failed"); + close(server_sock); + return 1; + } + + // Initialize server address + memset(&server_addr, 0, sizeof(server_addr)); + server_addr.sin_family = AF_INET; + server_addr.sin_addr.s_addr = INADDR_ANY; + server_addr.sin_port = htons(port); + + // Bind socket + if (bind(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) { + perror("Bind failed"); + close(server_sock); + return 1; + } + + // Listen for connections + if (listen(server_sock, 10) < 0) { + perror("Listen failed"); + close(server_sock); + return 1; + } + + // Set up signal handler for graceful termination + signal(SIGINT, handle_signal); + + printf("Proxy server listening on port %d\n", port); + printf("Press Ctrl+C to stop the server\n"); + + // Accept connections and queue them for processing + while (1) { + int client_sock = accept(server_sock, (struct sockaddr*)&client_addr, &client_len); + if (client_sock < 0) { + perror("Accept failed"); + continue; + } + + printf("New connection from %s:%d\n", + inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port)); + + // Queue client connection for processing by a worker thread + queue_client_connection(client_sock); + } + + // Cleanup + close(server_sock); + return 0; +} + +// Initialize the proxy server +int proxy_init() { + // Create the certs directory if it doesn't exist + mkdir("certs", 0755); + + // Initialize OpenSSL error strings + SSL_load_error_strings(); + ERR_load_crypto_strings(); + + // Initialize DANE support + if (!dane_init()) { + fprintf(stderr, "Failed to initialize DANE support\n"); + return 1; + } + + // Initialize DNS thread pool + init_dns_pool(); + + // Initialize certificate generation thread pool + init_cert_gen_pool(); + + // Generate a truly random initial seed for certificate serials + unsigned char random_bytes[8]; + if (RAND_bytes(random_bytes, sizeof(random_bytes)) == 1) { + unsigned long long random_seed = 0; + for (int i = 0; i < 8; i++) { + random_seed = (random_seed << 8) | random_bytes[i]; + } + atomic_init(&next_serial, random_seed); + } else { + // Fallback to time-based initialization with more entropy + struct timespec ts; + clock_gettime(CLOCK_REALTIME, &ts); + atomic_init(&next_serial, ((unsigned long long)ts.tv_sec << 32) | ts.tv_nsec); + } + + // Don't pre-generate certificates, only generate for domains with DANE + // pregen_common_certificates(); + + return 0; +} + +// Clean up proxy resources +void proxy_cleanup() { + // Clean up caches + free_cache(); + free_dns_cache(); + free_cert_serials(); + + // Clean up certificate generation thread pool + cleanup_cert_gen_pool(); + + // Clean up DANE resources + dane_cleanup(); +} + +// Signal handler for graceful termination +void handle_signal(int sig) { + if (sig == SIGINT) { + printf("\nShutting down proxy server...\n"); + printf("Cleaning up cache and temporary certificates...\n"); + + // Clean up resources + proxy_cleanup(); + + exit(0); + } +} + +// Update periodic_cleanup to also clean SSL sessions +void* periodic_cleanup(void* arg) { + (void)arg; // Unused parameter + + while (1) { + // Sleep for 5 minutes + sleep(300); + + // Clean up expired cache entries + cleanup_cache(); + cleanup_dns_cache(); + cleanup_cert_serials(); + cleanup_connection_pool(); + cleanup_ssl_sessions(); // Add this call + + printf("Performed periodic cache cleanup\n"); + } + + return NULL; +} + +// Initialize DNS resolution thread pool +void init_dns_pool(void) { + // Initialize the task queue + doh_queue.tasks = malloc(sizeof(dns_task*) * MAX_PARALLEL_DNS_QUERIES * 2); + doh_queue.task_count = 0; + doh_queue.task_capacity = MAX_PARALLEL_DNS_QUERIES * 2; + doh_queue.running = 1; + pthread_mutex_init(&doh_queue.mutex, NULL); + pthread_cond_init(&doh_queue.not_empty, NULL); + pthread_cond_init(&doh_queue.not_full, NULL); + + // Create worker threads for DNS resolution + for (int i = 0; i < MAX_PARALLEL_DNS_QUERIES; i++) { + if (pthread_create(&dns_threads[i], NULL, dns_worker, NULL) != 0) { + perror("Failed to create DNS resolution thread"); + // Continue anyway + } + } +} + +// DNS resolution worker thread +void* dns_worker(void* arg) { + (void)arg; + + while (doh_queue.running) { + dns_task* task = NULL; + + pthread_mutex_lock(&doh_queue.mutex); + while (doh_queue.task_count == 0 && doh_queue.running) { + pthread_cond_wait(&doh_queue.not_empty, &doh_queue.mutex); + } + + if (!doh_queue.running) { + pthread_mutex_unlock(&doh_queue.mutex); + break; + } + + // Get the task + task = doh_queue.tasks[0]; + + // Remove task from queue + doh_queue.task_count--; + if (doh_queue.task_count > 0) { + memmove(&doh_queue.tasks[0], &doh_queue.tasks[1], + sizeof(dns_task*) * doh_queue.task_count); + } + + pthread_cond_signal(&doh_queue.not_full); + pthread_mutex_unlock(&doh_queue.mutex); + + // Perform DNS resolution + task->result = resolve_doh(task->hostname, task->ip_addr, IP_ADDR_MAX_LEN); + + // Mark task as completed and signal waiting threads + pthread_mutex_lock(&task->mutex); + task->completed = 1; + pthread_cond_signal(&task->cond); + pthread_mutex_unlock(&task->mutex); + } + + return NULL; +} + +// Queue a DNS resolution task +dns_task* queue_dns_resolution(const char* hostname) { + // Create a new task + dns_task* task = malloc(sizeof(dns_task)); + if (!task) return NULL; + + strncpy(task->hostname, hostname, MAX_URL_LENGTH-1); + task->hostname[MAX_URL_LENGTH-1] = '\0'; + task->ip_addr[0] = '\0'; + task->result = 0; + task->completed = 0; + pthread_mutex_init(&task->mutex, NULL); + pthread_cond_init(&task->cond, NULL); + + // Add task to queue + pthread_mutex_lock(&doh_queue.mutex); + while (doh_queue.task_count >= doh_queue.task_capacity && doh_queue.running) { + pthread_cond_wait(&doh_queue.not_full, &doh_queue.mutex); + } + + if (!doh_queue.running) { + pthread_mutex_unlock(&doh_queue.mutex); + free(task); + return NULL; + } + + // Add the task to the queue + doh_queue.tasks[doh_queue.task_count++] = task; + + // Signal that the queue is not empty + pthread_cond_signal(&doh_queue.not_empty); + pthread_mutex_unlock(&doh_queue.mutex); + + return task; +} + +// Wait for DNS resolution to complete and get result +int wait_for_dns_resolution(dns_task* task, char* ip_buffer, size_t buffer_size) { + if (!task) return -1; + + pthread_mutex_lock(&task->mutex); + while (!task->completed) { + pthread_cond_wait(&task->cond, &task->mutex); + } + int result = task->result; + pthread_mutex_unlock(&task->mutex); + + if (result == 0 && ip_buffer) { + strncpy(ip_buffer, task->ip_addr, buffer_size-1); + ip_buffer[buffer_size-1] = '\0'; + } + + // Clean up task resources + pthread_mutex_destroy(&task->mutex); + pthread_cond_destroy(&task->cond); + free(task); + + return result; +} + +// Cleanup DNS resolution pool +void cleanup_dns_pool() { + // Signal shutdown + pthread_mutex_lock(&doh_queue.mutex); + doh_queue.running = 0; + pthread_cond_broadcast(&doh_queue.not_empty); + pthread_cond_broadcast(&doh_queue.not_full); + pthread_mutex_unlock(&doh_queue.mutex); + + // Wait for threads to exit + for (int i = 0; i < MAX_PARALLEL_DNS_QUERIES; i++) { + pthread_join(dns_threads[i], NULL); + } + + // Free any remaining tasks + pthread_mutex_lock(&doh_queue.mutex); + for (int i = 0; i < doh_queue.task_count; i++) { + free(doh_queue.tasks[i]); + } + free(doh_queue.tasks); + pthread_mutex_unlock(&doh_queue.mutex); + + // Destroy sync primitives + pthread_mutex_destroy(&doh_queue.mutex); + pthread_cond_destroy(&doh_queue.not_empty); + pthread_cond_destroy(&doh_queue.not_full); +} + +// Initialize certificate generation thread pool +void init_cert_gen_pool(void) { + // Initialize the task queue + cert_queue.tasks = malloc(sizeof(cert_gen_task*) * THREAD_POOL_SIZE); + cert_queue.task_count = 0; + cert_queue.task_capacity = THREAD_POOL_SIZE; + cert_queue.running = 1; + pthread_mutex_init(&cert_queue.mutex, NULL); + pthread_cond_init(&cert_queue.not_empty, NULL); + pthread_cond_init(&cert_queue.not_full, NULL); + + // Create worker threads for certificate generation + for (int i = 0; i < CERT_GEN_THREAD_POOL_SIZE; i++) { + if (pthread_create(&cert_gen_threads[i], NULL, cert_gen_worker, NULL) != 0) { + perror("Failed to create certificate generation thread"); + // Continue anyway + } + } + + // Initialize the atomic serial counter + atomic_init(&next_serial, ((unsigned long long)time(NULL) << 32)); +} + +// Certificate generation worker thread +void* cert_gen_worker(void* arg) { + (void)arg; + + while (cert_queue.running) { + cert_gen_task* task = NULL; + + pthread_mutex_lock(&cert_queue.mutex); + while (cert_queue.task_count == 0 && cert_queue.running) { + pthread_cond_wait(&cert_queue.not_empty, &cert_queue.mutex); + } + + if (!cert_queue.running) { + pthread_mutex_unlock(&cert_queue.mutex); + break; + } + + // Get the task + task = cert_queue.tasks[0]; + + // Remove task from queue + cert_queue.task_count--; + if (cert_queue.task_count > 0) { + memmove(&cert_queue.tasks[0], &cert_queue.tasks[1], + sizeof(cert_gen_task*) * cert_queue.task_count); + } + + pthread_cond_signal(&cert_queue.not_full); + pthread_mutex_unlock(&cert_queue.mutex); + + // Generate the certificate + unsigned long long serial = atomic_fetch_add(&next_serial, 1); + + // Make sure cert and key paths exist + unlink(task->cert_path); + unlink(task->key_path); + + // Use the serial for entropy + srand((unsigned int)(serial & 0xFFFFFFFF)); + + // Generate the cert with a unique serial + task->result = generate_trusted_cert(task->hostname, task->cert_path, task->key_path); + + // Mark task as completed and signal waiting threads + pthread_mutex_lock(&task->mutex); + task->completed = 1; + pthread_cond_signal(&task->cond); + pthread_mutex_unlock(&task->mutex); + } + + return NULL; +} + +// Queue a certificate generation task +cert_gen_task* queue_cert_generation(const char* hostname, const char* cert_path, const char* key_path) { + // Create a new task + cert_gen_task* task = malloc(sizeof(cert_gen_task)); + if (!task) return NULL; + + strncpy(task->hostname, hostname, MAX_URL_LENGTH-1); + task->hostname[MAX_URL_LENGTH-1] = '\0'; + strncpy(task->cert_path, cert_path, 255); + task->cert_path[255] = '\0'; + strncpy(task->key_path, key_path, 255); + task->key_path[255] = '\0'; + task->result = 0; + task->completed = 0; + pthread_mutex_init(&task->mutex, NULL); + pthread_cond_init(&task->cond, NULL); + + // Add task to queue + pthread_mutex_lock(&cert_queue.mutex); + while (cert_queue.task_count >= cert_queue.task_capacity && cert_queue.running) { + pthread_cond_wait(&cert_queue.not_full, &cert_queue.mutex); + } + + if (!cert_queue.running) { + pthread_mutex_unlock(&cert_queue.mutex); + free(task); + return NULL; + } + + // Add the task to the queue + cert_queue.tasks[cert_queue.task_count++] = task; + + // Signal that the queue is not empty + pthread_cond_signal(&cert_queue.not_empty); + pthread_mutex_unlock(&cert_queue.mutex); + + return task; +} + +// Wait for certificate generation to complete and get result +int wait_for_cert_generation(cert_gen_task* task) { + if (!task) return 0; + + pthread_mutex_lock(&task->mutex); + while (!task->completed) { + pthread_cond_wait(&task->cond, &task->mutex); + } + int result = task->result; + pthread_mutex_unlock(&task->mutex); + + // Clean up task resources + pthread_mutex_destroy(&task->mutex); + pthread_cond_destroy(&task->cond); + free(task); + + return result; +} + +// Cleanup certificate generation pool +void cleanup_cert_gen_pool() { + // Signal shutdown + pthread_mutex_lock(&cert_queue.mutex); + cert_queue.running = 0; + pthread_cond_broadcast(&cert_queue.not_empty); + pthread_cond_broadcast(&cert_queue.not_full); + pthread_mutex_unlock(&cert_queue.mutex); + + // Wait for threads to exit + for (int i = 0; i < CERT_GEN_THREAD_POOL_SIZE; i++) { + pthread_join(cert_gen_threads[i], NULL); + } + + // Free any remaining tasks + pthread_mutex_lock(&cert_queue.mutex); + for (int i = 0; i < cert_queue.task_count; i++) { + free(cert_queue.tasks[i]); + } + free(cert_queue.tasks); + pthread_mutex_unlock(&cert_queue.mutex); + + // Destroy sync primitives + pthread_mutex_destroy(&cert_queue.mutex); + pthread_cond_destroy(&cert_queue.not_empty); + pthread_cond_destroy(&cert_queue.not_full); +} + +// Pre-generate certificates for popular domains (call during init) +void pregen_common_certificates() { + // List of common domains that might be accessed + const char* common_domains[] = { + "www.google.com", "www.facebook.com", "www.youtube.com", + "www.amazon.com", "www.twitter.com", "www.instagram.com", + "www.linkedin.com", "www.reddit.com", "www.netflix.com", + "www.github.com" + // Add more common domains as needed + }; + + int num_domains = sizeof(common_domains) / sizeof(common_domains[0]); + int domains_to_pregen = num_domains < PRE_GEN_CERT_COUNT ? num_domains : PRE_GEN_CERT_COUNT; + + printf("Pre-generating certificates for %d common domains...\n", domains_to_pregen); + + // Create a thread to pre-generate certificates in the background + pthread_t pregen_thread; + pthread_create(&pregen_thread, NULL, pregen_cert_thread, (void*)(long)domains_to_pregen); + pthread_detach(pregen_thread); +} + +// Thread function for pre-generating certificates +void* pregen_cert_thread(void* arg) { + int count = (int)(long)arg; + + // List of common domains that might be accessed + const char* common_domains[] = { + "www.google.com", "www.facebook.com", "www.youtube.com", + "www.amazon.com", "www.twitter.com", "www.instagram.com", + "www.linkedin.com", "www.reddit.com", "www.netflix.com", + "www.github.com" + // Add more common domains as needed + }; + + for (int i = 0; i < count; i++) { + char cert_path[256]; + char key_path[256]; + snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", common_domains[i]); + snprintf(key_path, sizeof(key_path), "certs/%s.key", common_domains[i]); + + // Don't regenerate if it exists and is recent + if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0) { + // Check if we need to renew + if (!should_renew_cert(common_domains[i])) { + printf("Certificate for %s already exists and is recent\n", common_domains[i]); + continue; + } + } + + printf("Pre-generating certificate for %s\n", common_domains[i]); + generate_unique_cert(common_domains[i], cert_path, key_path); + + // Sleep a bit to not overwhelm the system + usleep(100000); // 100ms + } + + printf("Finished pre-generating certificates\n"); + return NULL; +} + // Regular HTTPS tunneling without interception void handle_regular_https_tunnel(int client_sock, int server_sock) { fd_set read_fds; @@ -925,7 +1869,7 @@ char* rewrite_http_request(const char* original_request, size_t* new_length) { if (strncmp(url, "http://", 7) == 0) { path = strchr(url + 7, '/'); if (!path) { - path = "/"; // Default to root path if not found + path = "/"; printf("No path in URL, using default: %s\n", path); } else { printf("Extracted path from URL: %s\n", path); @@ -951,7 +1895,6 @@ char* rewrite_http_request(const char* original_request, size_t* new_length) { // Write the new request int written = snprintf(new_request, total_length + 1, "%s %s %s\r\n%s", method, path, version, headers_start); - if (written < 0 || (size_t)written > total_length) { printf("Failed to rewrite HTTP request\n"); free(new_request); @@ -963,19 +1906,14 @@ char* rewrite_http_request(const char* original_request, size_t* new_length) { return new_request; } -// Handle client connection in a separate thread -void* handle_client(void* arg) { - thread_arg_t* thread_arg = (thread_arg_t*)arg; - int client_sock = thread_arg->client_sock; - free(thread_arg); - +// Optimized client handling +void handle_client_optimized(int client_sock) { char request[MAX_REQUEST_SIZE]; - char buffer[MAX_REQUEST_SIZE]; ssize_t bytes_received = recv(client_sock, request, sizeof(request) - 1, 0); if (bytes_received <= 0) { close(client_sock); - return NULL; + return; } request[bytes_received] = '\0'; @@ -988,13 +1926,11 @@ void* handle_client(void* arg) { if (!host) { printf("Failed to extract host from request\n"); close(client_sock); - return NULL; + return; } // Extract port from request int port = extract_port(request); - - // Ensure we always have a valid port if (port <= 0 || port > 65535) { port = is_connect ? 443 : 80; printf("Invalid port detected, using default port: %d\n", port); @@ -1009,23 +1945,16 @@ void* handle_client(void* arg) { if (dns_cache_result == 0) { printf("Using cached DNS for %s: %s\n", host, ip_addr); - } else if (dns_cache_result == 1) { - printf("Using cached failed DNS resolution for %s\n", host); - close(client_sock); - return NULL; } else { // Resolve hostname using DoH if (resolve_doh(host, ip_addr, sizeof(ip_addr)) != 0) { printf("Failed to resolve hostname using DoH: %s\n", host); - // Cache the failed resolution to avoid repeated lookups - add_to_dns_cache(host, "", 0); + // Don't cache failed resolutions anymore close(client_sock); - return NULL; + return; } - printf("Resolved %s to %s\n", host, ip_addr); - - // Cache the successful resolution + // Only cache successful resolutions add_to_dns_cache(host, ip_addr, 1); } @@ -1035,7 +1964,7 @@ void* handle_client(void* arg) { if (server_sock < 0) { perror("Cannot create socket to server"); close(client_sock); - return NULL; + return; } memset(&server_addr, 0, sizeof(server_addr)); @@ -1047,7 +1976,7 @@ void* handle_client(void* arg) { perror("Cannot connect to server"); close(server_sock); close(client_sock); - return NULL; + return; } if (is_connect) { @@ -1056,15 +1985,13 @@ void* handle_client(void* arg) { } else { // HTTP: Rewrite the request and forward it printf("HTTP request received, rewriting for server...\n"); - size_t new_length = 0; char* modified_request = rewrite_http_request(request, &new_length); - if (!modified_request) { printf("Failed to rewrite HTTP request\n"); close(server_sock); close(client_sock); - return NULL; + return; } printf("Forwarding modified HTTP request (%zu bytes) to server...\n", new_length); @@ -1081,7 +2008,7 @@ void* handle_client(void* arg) { free(modified_request); close(server_sock); close(client_sock); - return NULL; + return; } free(modified_request); @@ -1093,6 +2020,7 @@ void* handle_client(void* arg) { // Receive response from server and forward to client printf("Waiting for server response (timeout: 5s)...\n"); + char buffer[MAX_REQUEST_SIZE]; fd_set read_fds; FD_ZERO(&read_fds); FD_SET(server_sock, &read_fds); @@ -1102,9 +2030,10 @@ void* handle_client(void* arg) { printf("Timeout or error waiting for server response\n"); close(server_sock); close(client_sock); - return NULL; + return; } + ssize_t bytes_received; while ((bytes_received = recv(server_sock, buffer, sizeof(buffer), 0)) > 0) { printf("Received %zd bytes from server\n", bytes_received); if (send(client_sock, buffer, bytes_received, 0) < 0) { @@ -1123,161 +2052,447 @@ void* handle_client(void* arg) { printf("Connection closed: %s (port %d)\n", host, port); close(server_sock); close(client_sock); - return NULL; } -int start_proxy_server(int port) { - int server_sock, client_sock; - struct sockaddr_in server_addr, client_addr; - socklen_t client_len = sizeof(client_addr); +// HTTPS tunnel with DANE verification, optimized version +void handle_https_tunnel_optimized(int client_sock, int server_sock, const char* hostname, + const char* ip_addr, SSL* existing_ssl, SSL_CTX* existing_ctx) { + // Mark ip_addr as used to avoid warning + (void)ip_addr; - // Initialize proxy components - if (proxy_init() != 0) { - fprintf(stderr, "Failed to initialize proxy server\n"); - return 1; + // Use existing SSL session if provided + SSL* server_ssl = existing_ssl; + SSL_CTX* server_ctx = existing_ctx; + int need_ssl_setup = (server_ssl == NULL); + + // Check the DANE cache first + int cached_result = check_cache(hostname); + + // Generate certificate paths + char cert_path[256]; + char key_path[256]; + snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname); + snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname); + + // If in cache and verified as successful, skip verification + if (cached_result == 1) { + // Cached DANE verification successful + + // Check if certificate exists and doesn't need renewal + if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0 && !should_renew_cert(hostname)) { + // Initialize SSL context with our certificate + ssl_context_t* client_ctx = NULL; + if (!get_cached_cert(hostname, &client_ctx)) { + // Fall back to disk + client_ctx = init_ssl_context(cert_path, key_path); + if (client_ctx) { + // Now that we've loaded from disk, cache it for next time + cache_cert_in_memory(hostname, cert_path, key_path); + } + } + + if (!client_ctx) { + // Fall back to regular tunnel + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // If we need to set up SSL, do it + if (need_ssl_setup) { + server_ctx = SSL_CTX_new(TLS_client_method()); + if (!server_ctx) { + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + server_ssl = SSL_new(server_ctx); + SSL_set_fd(server_ssl, server_sock); + SSL_set_tlsext_host_name(server_ssl, hostname); + + // Enable session caching for future connections + SSL_CTX_set_session_cache_mode(server_ctx, SSL_SESS_CACHE_CLIENT); + + if (SSL_connect(server_ssl) <= 0) { + cleanup_ssl_connection(server_ssl, server_ctx); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + } + + // Send 200 Connection Established to client + const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; + if (send(client_sock, success_response, strlen(success_response), 0) < 0) { + if (!need_ssl_setup) { + // Don't clean up if we're using a cached connection + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + } else { + cleanup_ssl_connection(server_ssl, server_ctx); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + } + return; + } + + // Set up client SSL + client_ctx->ssl = SSL_new(client_ctx->ctx); + SSL_set_fd(client_ctx->ssl, client_sock); + + if (SSL_accept(client_ctx->ssl) <= 0) { + cleanup_ssl_connection(client_ctx->ssl, NULL); + if (need_ssl_setup) { + cleanup_ssl_connection(server_ssl, server_ctx); + } + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + return; + } + + // Tunnel data + ssl_tunnel_data(client_ctx->ssl, server_ssl); + + // Clean up + cleanup_ssl_connection(client_ctx->ssl, NULL); + if (need_ssl_setup) { + // If we didn't use a pooled connection, add this one to the pool + add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx); + } + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + return; + } + } else if (cached_result == 0) { + // Cached DANE verification failed, use regular tunnel + handle_regular_https_tunnel(client_sock, server_sock); + return; } - // Create periodic cleanup thread - pthread_t cleanup_thread; - if (pthread_create(&cleanup_thread, NULL, periodic_cleanup, NULL) != 0) { - perror("Failed to create cleanup thread"); - // Continue anyway, not critical + // Need to perform DANE verification + // Clean up any existing SSL session since we need to validate from scratch + if (!need_ssl_setup) { + // Not using server_ssl/ctx directly as they belong to pool + cleanup_ssl_connection(server_ssl, NULL); + need_ssl_setup = 1; + } + + // Set up a new SSL connection to server + server_ctx = SSL_CTX_new(TLS_client_method()); + if (!server_ctx) { + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + server_ssl = SSL_new(server_ctx); + SSL_set_fd(server_ssl, server_sock); + SSL_set_tlsext_host_name(server_ssl, hostname); + + if (SSL_connect(server_ssl) <= 0) { + cleanup_ssl_connection(server_ssl, server_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Check for DANE records + int has_dane = is_dane_available(hostname); + + if (has_dane) { + // Get server certificate and verify against DANE + X509* server_cert = SSL_get_peer_certificate(server_ssl); + if (!server_cert) { + cleanup_ssl_connection(server_ssl, server_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + int dane_verified = verify_cert_against_dane(hostname, server_cert); + add_to_cache(hostname, dane_verified > 0 ? 1 : 0); + + if (dane_verified <= 0) { + // DANE verification failed + X509_free(server_cert); + cleanup_ssl_connection(server_ssl, server_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Generate trusted certificate + if (!generate_unique_cert(hostname, cert_path, key_path)) { + X509_free(server_cert); + cleanup_ssl_connection(server_ssl, server_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Set up client connection with our certificate + ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path); + if (!client_ctx) { + X509_free(server_cert); + cleanup_ssl_connection(server_ssl, server_ctx); + handle_regular_https_tunnel(client_sock, server_sock); + return; + } + + // Send 200 Connection Established to client + const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n"; + if (send(client_sock, success_response, strlen(success_response), 0) < 0) { + X509_free(server_cert); + cleanup_ssl_connection(server_ssl, server_ctx); + cleanup_ssl_connection(client_ctx->ssl, NULL); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + return; + } + + // Setup client SSL + client_ctx->ssl = SSL_new(client_ctx->ctx); + SSL_set_fd(client_ctx->ssl, client_sock); + + if (SSL_accept(client_ctx->ssl) <= 0) { + X509_free(server_cert); + cleanup_ssl_connection(client_ctx->ssl, NULL); + cleanup_ssl_connection(server_ssl, server_ctx); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); + return; + } + + // Tunnel data + ssl_tunnel_data(client_ctx->ssl, server_ssl); + + // Add connection to pool for future reuse + add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx); + + // Clean up + X509_free(server_cert); + cleanup_ssl_connection(client_ctx->ssl, NULL); + SSL_CTX_free(client_ctx->ctx); + free(client_ctx); } else { - pthread_detach(cleanup_thread); + // No DANE records + add_to_cache(hostname, 0); + cleanup_ssl_connection(server_ssl, server_ctx); + handle_regular_https_tunnel(client_sock, server_sock); } +} + +SSL_SESSION* get_ssl_session(const char* hostname) { + pthread_mutex_lock(&ssl_session_mutex); - // Create socket - server_sock = socket(AF_INET, SOCK_STREAM, 0); - if (server_sock < 0) { - perror("Cannot create socket"); - return 1; - } + ssl_session_cache_entry* entry; + HASH_FIND_STR(ssl_session_cache, hostname, entry); - // Allow reuse of address - int opt = 1; - if (setsockopt(server_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) { - perror("setsockopt failed"); - close(server_sock); - return 1; - } - - // Initialize server address - memset(&server_addr, 0, sizeof(server_addr)); - server_addr.sin_family = AF_INET; - server_addr.sin_addr.s_addr = INADDR_ANY; - server_addr.sin_port = htons(port); - - // Bind socket - if (bind(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) { - perror("Bind failed"); - close(server_sock); - return 1; - } - - // Listen for connections - if (listen(server_sock, 10) < 0) { - perror("Listen failed"); - close(server_sock); - return 1; - } - - // Set up signal handler for graceful termination - signal(SIGINT, handle_signal); - - printf("Proxy server listening on port %d\n", port); - printf("Press Ctrl+C to stop the server\n"); - - // Accept and handle client connections - while (1) { - client_sock = accept(server_sock, (struct sockaddr*)&client_addr, &client_len); - if (client_sock < 0) { - perror("Accept failed"); - continue; + if (entry) { + time_t now = time(NULL); + + // Check session age - 10 minute validity + if (now - entry->timestamp > 600) { + // Session too old, remove it + SSL_SESSION_free(entry->session); + HASH_DEL(ssl_session_cache, entry); + free(entry); + pthread_mutex_unlock(&ssl_session_mutex); + return NULL; } - printf("New connection from %s:%d\n", - inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port)); - - // Create thread argument - thread_arg_t* thread_arg = malloc(sizeof(thread_arg_t)); - if (!thread_arg) { - perror("Failed to allocate memory for thread argument"); - close(client_sock); - continue; - } - thread_arg->client_sock = client_sock; - - // Create thread to handle client - pthread_t thread_id; - if (pthread_create(&thread_id, NULL, handle_client, thread_arg) != 0) { - perror("Failed to create thread"); - free(thread_arg); - close(client_sock); - continue; - } - - // Detach thread to allow it to clean up automatically - pthread_detach(thread_id); - } - - close(server_sock); - return 0; -} - -// Initialize the proxy server -int proxy_init() { - // Initialize OpenSSL error strings - SSL_load_error_strings(); - ERR_load_crypto_strings(); - - // Initialize DANE support - if (!dane_init()) { - fprintf(stderr, "Failed to initialize DANE support\n"); - return 1; - } - - return 0; -} - -// Clean up proxy resources -void proxy_cleanup() { - // Clean up caches - free_cache(); - free_dns_cache(); - free_cert_serials(); - - // Clean up DANE resources - dane_cleanup(); -} - -// Signal handler for graceful termination -void handle_signal(int sig) { - if (sig == SIGINT) { - printf("\nShutting down proxy server...\n"); - printf("Cleaning up cache and temporary certificates...\n"); - - // Clean up resources - proxy_cleanup(); - - exit(0); - } -} - -// Periodic cleanup function to run in a separate thread -void* periodic_cleanup(void* arg) { - (void)arg; // Unused parameter - - while (1) { - // Sleep for 5 minutes - sleep(300); - - // Clean up expired cache entries - cleanup_cache(); - cleanup_dns_cache(); - cleanup_cert_serials(); - - printf("Performed periodic cache cleanup\n"); + // Return a reference to the cached session + SSL_SESSION* session = entry->session; + pthread_mutex_unlock(&ssl_session_mutex); + return session; } + pthread_mutex_unlock(&ssl_session_mutex); return NULL; } + +void cache_ssl_session(const char* hostname, SSL_SESSION* session) { + if (!session) return; + + pthread_mutex_lock(&ssl_session_mutex); + + ssl_session_cache_entry* entry; + HASH_FIND_STR(ssl_session_cache, hostname, entry); + + if (entry) { + SSL_SESSION_free(entry->session); + entry->session = session; + entry->timestamp = time(NULL); + } else { + entry = malloc(sizeof(ssl_session_cache_entry)); + if (entry) { + strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); + entry->hostname[MAX_URL_LENGTH-1] = '\0'; + entry->session = session; + entry->timestamp = time(NULL); + HASH_ADD_STR(ssl_session_cache, hostname, entry); + } + } + + pthread_mutex_unlock(&ssl_session_mutex); +} + +void cleanup_ssl_sessions() { + pthread_mutex_lock(&ssl_session_mutex); + + ssl_session_cache_entry* entry, *tmp; + time_t now = time(NULL); + + HASH_ITER(hh, ssl_session_cache, entry, tmp) { + if (now - entry->timestamp > 600) { // 10 minute timeout + SSL_SESSION_free(entry->session); + HASH_DEL(ssl_session_cache, entry); + free(entry); + } + } + + pthread_mutex_unlock(&ssl_session_mutex); +} + +int get_cached_cert(const char* hostname, ssl_context_t** ctx) { + pthread_mutex_lock(&cert_cache_mutex); + + cert_cache_entry* entry; + HASH_FIND_STR(cert_memory_cache, hostname, entry); + + if (!entry) { + pthread_mutex_unlock(&cert_cache_mutex); + return 0; + } + + // Check if certificate is still valid + time_t now = time(NULL); + if (now - entry->timestamp > CERT_RENEWAL_TIME) { + pthread_mutex_unlock(&cert_cache_mutex); + return 0; + } + + // Create SSL context from memory + *ctx = malloc(sizeof(ssl_context_t)); + if (!*ctx) { + pthread_mutex_unlock(&cert_cache_mutex); + return 0; + } + + (*ctx)->ctx = SSL_CTX_new(TLS_server_method()); + if (!(*ctx)->ctx) { + free(*ctx); + *ctx = NULL; + pthread_mutex_unlock(&cert_cache_mutex); + return 0; + } + + // Configure SSL context + SSL_CTX_set_options((*ctx)->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION); + + // Load cert and key from memory + BIO* cbio = BIO_new_mem_buf(entry->cert_data, entry->cert_len); + X509* cert = PEM_read_bio_X509(cbio, NULL, NULL, NULL); + BIO_free(cbio); + + BIO* kbio = BIO_new_mem_buf(entry->key_data, entry->key_len); + EVP_PKEY* key = PEM_read_bio_PrivateKey(kbio, NULL, NULL, NULL); + BIO_free(kbio); + + if (!cert || !key || + SSL_CTX_use_certificate((*ctx)->ctx, cert) <= 0 || + SSL_CTX_use_PrivateKey((*ctx)->ctx, key) <= 0) { + + if (cert) X509_free(cert); + if (key) EVP_PKEY_free(key); + SSL_CTX_free((*ctx)->ctx); + free(*ctx); + *ctx = NULL; + pthread_mutex_unlock(&cert_cache_mutex); + return 0; + } + + X509_free(cert); + EVP_PKEY_free(key); + pthread_mutex_unlock(&cert_cache_mutex); + return 1; +} + +void cache_cert_in_memory(const char* hostname, const char* cert_path, const char* key_path) { + // Read certificate file + FILE* cert_file = fopen(cert_path, "rb"); + if (!cert_file) return; + + fseek(cert_file, 0, SEEK_END); + size_t cert_size = ftell(cert_file); + fseek(cert_file, 0, SEEK_SET); + unsigned char* cert_data = malloc(cert_size); + if (!cert_data) { + fclose(cert_file); + return; + } + + if (fread(cert_data, 1, cert_size, cert_file) != cert_size) { + fclose(cert_file); + free(cert_data); + return; + } + fclose(cert_file); + + // Read key file + FILE* key_file = fopen(key_path, "rb"); + if (!key_file) { + free(cert_data); + return; + } + + fseek(key_file, 0, SEEK_END); + size_t key_size = ftell(key_file); + fseek(key_file, 0, SEEK_SET); + unsigned char* key_data = malloc(key_size); + if (!key_data) { + fclose(key_file); + free(cert_data); + return; + } + + if (fread(key_data, 1, key_size, key_file) != key_size) { + fclose(key_file); + free(cert_data); + free(key_data); + return; + } + fclose(key_file); + + // Add to cache + pthread_mutex_lock(&cert_cache_mutex); + + cert_cache_entry* entry; + HASH_FIND_STR(cert_memory_cache, hostname, entry); + + if (entry) { + // Update existing entry + free(entry->cert_data); + free(entry->key_data); + entry->cert_data = cert_data; + entry->cert_len = cert_size; + entry->key_data = key_data; + entry->key_len = key_size; + entry->timestamp = time(NULL); + } else { + // Create new entry + entry = malloc(sizeof(cert_cache_entry)); + if (entry) { + strncpy(entry->hostname, hostname, MAX_URL_LENGTH-1); + entry->hostname[MAX_URL_LENGTH-1] = '\0'; + entry->cert_data = cert_data; + entry->cert_len = cert_size; + entry->key_data = key_data; + entry->key_len = key_size; + entry->timestamp = time(NULL); + HASH_ADD_STR(cert_memory_cache, hostname, entry); + } else { + free(cert_data); + free(key_data); + } + } + + pthread_mutex_unlock(&cert_cache_mutex); +}