Files
fireproxy/src/proxy.c

2499 lines
82 KiB
C

#include "proxy.h"
#include "doh.h"
#include "dane.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <netinet/in.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <pthread.h>
#include <netdb.h>
#include <ctype.h>
#include <signal.h>
#include <errno.h>
#include <sys/stat.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
#include <time.h>
#include <openssl/rand.h> // For better random number generation
#include <uthash.h> // Add uthash library for hash tables
#include <stdatomic.h> // For atomic operations
// Tunables. Time values are in seconds.
#define MAX_REQUEST_SIZE 8192
#define MAX_URL_LENGTH 2048 // Size of the hostname buffers used as hash-table keys below
#define THREAD_POOL_SIZE 20 // Capacity of the pending-client queue (work_queue)
#define CACHE_EXPIRY_TIME (86400) // DANE verification cache lifetime: 24 hours
#define DNS_CACHE_EXPIRY_TIME (3600) // DNS cache lifetime: 1 hour
#define IP_ADDR_MAX_LEN 46 // Buffer size for a textual IPv4/IPv6 address, including the NUL (same value as INET6_ADDRSTRLEN)
#define CERT_RENEWAL_TIME (86400*7) // Age after which a generated certificate is renewed / cleaned up: 7 days
#define CONNECTION_POOL_SIZE 32 // Number of connections to keep in the pool
#define CONNECTION_IDLE_TIMEOUT 300 // Pooled upstream connections idle longer than this are closed: 5 minutes
#define NUM_WORKER_THREADS 8 // Worker threads for handling connections
#define MAX_PARALLEL_DNS_QUERIES 8 // Number of DNS resolution worker threads
#define CERT_GEN_THREAD_POOL_SIZE 4 // Threads for certificate generation
#define PRE_GEN_CERT_COUNT 20 // Number of pre-generated certificates for domains with DANE
// Forward typedefs for the task structures defined further down, so the
// prototypes below can name them.
typedef struct dns_task dns_task;
typedef struct cert_gen_task cert_gen_task;
// Prototypes for the thread pools, DNS/certificate workers, and tunnel
// handlers used before their definitions.
void* periodic_cleanup(void* arg);
int generate_unique_cert(const char* hostname, const char* cert_path, const char* key_path);
void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx);
void* worker_thread(void* arg);
void init_thread_pool(void);
void queue_client_connection(int client_sock);
void handle_client_optimized(int client_sock);
void handle_https_tunnel_optimized(int client_sock, int server_sock, const char* hostname,
const char* ip_addr, SSL* existing_ssl, SSL_CTX* existing_ctx);
void init_dns_pool(void);
void* dns_worker(void* arg);
dns_task* queue_dns_resolution(const char* hostname);
int wait_for_dns_resolution(dns_task* task, char* ip_buffer, size_t buffer_size);
void cleanup_dns_pool(void);
void init_cert_gen_pool(void);
void* cert_gen_worker(void* arg);
cert_gen_task* queue_cert_generation(const char* hostname, const char* cert_path, const char* key_path);
int wait_for_cert_generation(cert_gen_task* task);
void cleanup_cert_gen_pool(void);
void pregen_common_certificates(void);
void* pregen_cert_thread(void* arg);
// Prototypes for the upstream SSL session cache, the in-memory certificate
// cache, and the plain / SSL-terminating tunnel helpers.
SSL_SESSION* get_ssl_session(const char* hostname);
void cache_ssl_session(const char* hostname, SSL_SESSION* session);
void cleanup_ssl_sessions(void);
int get_cached_cert(const char* hostname, ssl_context_t** ctx);
void cache_cert_in_memory(const char* hostname, const char* cert_path, const char* key_path);
void handle_regular_https_tunnel(int client_sock, int server_sock);
void ssl_tunnel_data(SSL* client_ssl, SSL* server_ssl);
// Hash table entry for the DANE verification cache (uthash, keyed by hostname).
typedef struct {
char hostname[MAX_URL_LENGTH]; // key
int verified; // 1 = server certificate matched DANE records, 0 = not verified / no DANE records
time_t timestamp; // when the result was stored; stale after CACHE_EXPIRY_TIME
UT_hash_handle hh; // makes this structure hashable
} dane_cache_entry;
// Hash table entry for the DNS cache (uthash, keyed by hostname).
typedef struct {
char hostname[MAX_URL_LENGTH]; // key
char ip_addr[IP_ADDR_MAX_LEN]; // textual address; only meaningful when is_valid is 1
time_t timestamp; // when the result was stored; stale after DNS_CACHE_EXPIRY_TIME
int is_valid; // 1 if valid IP, 0 if resolution failed
UT_hash_handle hh; // makes this structure hashable
} dns_cache_entry;
// Hash table entry recording when a certificate was last generated for a
// hostname. Despite the name, no serial number is stored here.
typedef struct {
char hostname[MAX_URL_LENGTH]; // key
time_t timestamp; // generation time, compared against CERT_RENEWAL_TIME
UT_hash_handle hh; // makes this structure hashable
} cert_serial_entry;
// Upstream connection pool entry (uthash, keyed by hostname; at most one per host).
typedef struct {
char hostname[MAX_URL_LENGTH];
int port; // must match for get_from_connection_pool() to return the entry
int sock;
SSL* ssl;
SSL_CTX* ctx;
time_t last_used; // refreshed on checkout/release; idle limit is CONNECTION_IDLE_TIMEOUT
int in_use; // 1 while checked out by get_from_connection_pool(), 0 after release_connection()
UT_hash_handle hh;
} connection_pool_entry;
// Thread pool structures
// One accepted client connection waiting for a worker thread.
typedef struct {
int client_sock;
} client_task;
// Bounded FIFO of pending client connections shared by the accepting thread
// and the worker threads. tasks[0] is the oldest entry; every field is
// accessed under mutex.
typedef struct {
client_task* tasks;
int task_count; // entries currently queued
int task_capacity; // allocated slots in tasks
pthread_mutex_t mutex;
pthread_cond_t not_empty; // signalled after a task is queued
pthread_cond_t not_full; // signalled after a task is dequeued
} task_queue;
// Parallel DNS resolution: one hostname lookup handed to the DNS workers.
// NOTE(review): the producer/worker protocol is not visible here; presumably
// the worker fills ip_addr/result, sets completed and signals cond under mutex — confirm.
struct dns_task {
char hostname[MAX_URL_LENGTH];
char ip_addr[IP_ADDR_MAX_LEN];
int result;
int completed;
pthread_mutex_t mutex;
pthread_cond_t cond;
};
// Queue of pending dns_task pointers consumed by the DNS worker threads.
typedef struct {
dns_task** tasks;
int task_count;
int task_capacity;
pthread_mutex_t mutex;
pthread_cond_t not_empty;
pthread_cond_t not_full;
int running; // presumably cleared to stop the DNS workers — confirm in cleanup_dns_pool
} dns_queue;
// Certificate generation thread pool: one certificate/key pair to produce.
// NOTE(review): presumably completed/result are set by cert_gen_worker and
// awaited via cond in wait_for_cert_generation — confirm.
struct cert_gen_task {
char hostname[MAX_URL_LENGTH];
char cert_path[256];
char key_path[256];
int result;
int completed;
pthread_mutex_t mutex;
pthread_cond_t cond;
};
// Queue of pending cert_gen_task pointers consumed by the generation threads.
typedef struct {
cert_gen_task** tasks;
int task_count;
int task_capacity;
pthread_mutex_t mutex;
pthread_cond_t not_empty;
pthread_cond_t not_full;
int running; // presumably cleared to stop the generation workers — confirm
} cert_gen_queue;
// Global caches: uthash tables keyed by hostname (NULL means empty).
static dane_cache_entry* dane_cache = NULL;
static dns_cache_entry* dns_cache = NULL;
static cert_serial_entry* cert_serials = NULL;
static connection_pool_entry* conn_pool = NULL;
// Locks: cache_rwlock guards dane_cache, dns_cache_rwlock guards dns_cache,
// cert_rwlock guards cert_serials, conn_pool_mutex guards conn_pool.
static pthread_rwlock_t cache_rwlock = PTHREAD_RWLOCK_INITIALIZER;
static pthread_rwlock_t dns_cache_rwlock = PTHREAD_RWLOCK_INITIALIZER;
static pthread_rwlock_t cert_rwlock = PTHREAD_RWLOCK_INITIALIZER;
static pthread_mutex_t conn_pool_mutex = PTHREAD_MUTEX_INITIALIZER;
// Thread pool state
static task_queue work_queue;
static pthread_t worker_threads[NUM_WORKER_THREADS];
// Nonzero while the worker pool accepts and processes clients.
// NOTE(review): plain int shared between threads; make sure every write
// happens under work_queue.mutex (or switch to atomic_int).
static int workers_running = 1;
// DNS resolution thread pool state
static dns_queue doh_queue;
static pthread_t dns_threads[MAX_PARALLEL_DNS_QUERIES];
// Certificate generation thread pool state
static cert_gen_queue cert_queue;
static pthread_t cert_gen_threads[CERT_GEN_THREAD_POOL_SIZE];
// Atomic counter mixed into certificate serial numbers (see generate_unique_cert).
static atomic_ullong next_serial = 0;
// Argument block carrying a client socket to a thread.
// NOTE(review): no use is visible in this part of the file.
typedef struct {
int client_sock;
} thread_arg_t;
// Upstream SSL session cache: one SSL_SESSION per hostname, looked up with
// get_ssl_session() before SSL_connect and refreshed with cache_ssl_session().
// Presumably guarded by ssl_session_mutex — confirm in those functions.
typedef struct {
char hostname[MAX_URL_LENGTH];
SSL_SESSION* session;
time_t timestamp;
UT_hash_handle hh;
} ssl_session_cache_entry;
static ssl_session_cache_entry* ssl_session_cache = NULL;
static pthread_mutex_t ssl_session_mutex = PTHREAD_MUTEX_INITIALIZER;
// In-memory copy of a generated certificate and key per hostname (heap
// buffers freed on eviction). Guarded by cert_cache_mutex.
// NOTE(review): presumably the raw bytes of the cert/key files — confirm in cache_cert_in_memory.
typedef struct {
char hostname[MAX_URL_LENGTH];
unsigned char* cert_data;
size_t cert_len;
unsigned char* key_data;
size_t key_len;
time_t timestamp;
UT_hash_handle hh;
} cert_cache_entry;
static cert_cache_entry* cert_memory_cache = NULL;
static pthread_mutex_t cert_cache_mutex = PTHREAD_MUTEX_INITIALIZER;
// Initialize the client work queue and start NUM_WORKER_THREADS worker threads.
//
// The queue is a fixed-capacity FIFO holding up to THREAD_POOL_SIZE pending
// client sockets. The process exits if the queue cannot be allocated or a
// worker cannot be started, because the proxy cannot serve clients without them.
void init_thread_pool(void) {
    work_queue.tasks = malloc(sizeof(*work_queue.tasks) * THREAD_POOL_SIZE);
    if (work_queue.tasks == NULL) {
        fprintf(stderr, "Failed to allocate worker task queue\n");
        exit(1);
    }
    work_queue.task_count = 0;
    work_queue.task_capacity = THREAD_POOL_SIZE;
    pthread_mutex_init(&work_queue.mutex, NULL);
    pthread_cond_init(&work_queue.not_empty, NULL);
    pthread_cond_init(&work_queue.not_full, NULL);

    for (int i = 0; i < NUM_WORKER_THREADS; i++) {
        // pthread_create() reports failure through its return value and does
        // not set errno, so perror() would print an unrelated/stale error.
        int rc = pthread_create(&worker_threads[i], NULL, worker_thread, NULL);
        if (rc != 0) {
            fprintf(stderr, "Failed to create worker thread: %s\n", strerror(rc));
            exit(1);
        }
    }
}
// Queue a client connection for processing
void queue_client_connection(int client_sock) {
pthread_mutex_lock(&work_queue.mutex);
// Wait if the queue is full
while (work_queue.task_count >= work_queue.task_capacity && workers_running) {
pthread_cond_wait(&work_queue.not_full, &work_queue.mutex);
}
if (!workers_running) {
pthread_mutex_unlock(&work_queue.mutex);
close(client_sock);
return;
}
// Add the client socket to the queue
work_queue.tasks[work_queue.task_count].client_sock = client_sock;
work_queue.task_count++;
// Signal that the queue is not empty
pthread_cond_signal(&work_queue.not_empty);
pthread_mutex_unlock(&work_queue.mutex);
}
// Worker loop: repeatedly take the oldest queued client socket (FIFO, head at
// tasks[0]) and process it with handle_client_optimized(), which receives the
// socket. Returns NULL once workers_running becomes 0.
//
// workers_running is read only while work_queue.mutex is held. The previous
// loop header read it without the lock, which is a data race on a non-atomic
// int. NOTE(review): the race is fully closed only if the shutdown code also
// clears workers_running under work_queue.mutex — confirm at that site.
void* worker_thread(void* arg) {
    (void)arg; // Unused parameter
    for (;;) {
        pthread_mutex_lock(&work_queue.mutex);
        while (work_queue.task_count == 0 && workers_running) {
            pthread_cond_wait(&work_queue.not_empty, &work_queue.mutex);
        }
        if (!workers_running) {
            pthread_mutex_unlock(&work_queue.mutex);
            break;
        }

        // Pop the head of the queue and shift the rest down one slot.
        int client_sock = work_queue.tasks[0].client_sock;
        work_queue.task_count--;
        if (work_queue.task_count > 0) {
            memmove(&work_queue.tasks[0], &work_queue.tasks[1],
                    sizeof(client_task) * work_queue.task_count);
        }

        pthread_cond_signal(&work_queue.not_full);
        pthread_mutex_unlock(&work_queue.mutex);

        // Process outside the lock so other workers can dequeue concurrently.
        handle_client_optimized(client_sock);
    }
    return NULL;
}
// Connection pooling

// Store an upstream connection for hostname in the pool. The pool takes
// ownership of sock, ssl and ctx and keeps at most one connection per hostname.
//  - Re-adding the connection already stored for hostname (same ssl and sock)
//    just refreshes it and marks it idle.
//  - An idle existing entry is closed and replaced by the new connection.
//  - If the existing entry is checked out (in_use) by another caller, it is
//    left alone and the incoming connection is closed instead. Freeing the
//    entry would pull its SSL object and socket out from under that thread.
//  - If a new entry cannot be allocated, the incoming connection is closed
//    rather than leaked.
void add_to_connection_pool(const char* hostname, int port, int sock, SSL* ssl, SSL_CTX* ctx) {
    int discard_incoming = 0;

    pthread_mutex_lock(&conn_pool_mutex);
    connection_pool_entry* entry = NULL;
    HASH_FIND_STR(conn_pool, hostname, entry);
    if (entry) {
        int same_connection = (entry->ssl == ssl && entry->sock == sock);
        if (entry->in_use && !same_connection) {
            discard_incoming = 1;
        } else {
            // Free only the parts being replaced: a pointer/fd identical to
            // the incoming one belongs to the new connection as well.
            if (entry->ssl && entry->ssl != ssl) SSL_free(entry->ssl);
            if (entry->ctx && entry->ctx != ctx) SSL_CTX_free(entry->ctx);
            if (entry->sock > 0 && entry->sock != sock) close(entry->sock);
            entry->port = port;
            entry->sock = sock;
            entry->ssl = ssl;
            entry->ctx = ctx;
            entry->last_used = time(NULL);
            entry->in_use = 0;
        }
    } else {
        entry = malloc(sizeof(*entry));
        if (entry) {
            strncpy(entry->hostname, hostname, MAX_URL_LENGTH - 1);
            entry->hostname[MAX_URL_LENGTH - 1] = '\0';
            entry->port = port;
            entry->sock = sock;
            entry->ssl = ssl;
            entry->ctx = ctx;
            entry->last_used = time(NULL);
            entry->in_use = 0;
            HASH_ADD_STR(conn_pool, hostname, entry);
        } else {
            discard_incoming = 1;
        }
    }
    pthread_mutex_unlock(&conn_pool_mutex);

    // Release a connection the pool could not keep (done outside the lock).
    if (discard_incoming) {
        if (ssl) SSL_free(ssl);
        if (ctx) SSL_CTX_free(ctx);
        if (sock > 0) close(sock);
    }
}
// Check out a pooled upstream connection for hostname:port.
//
// Returns 1 and fills *sock, *ssl and *ctx when an idle entry with a matching
// port exists and has been idle for at most CONNECTION_IDLE_TIMEOUT seconds.
// The entry is then marked in_use and its last_used time refreshed. An idle
// entry past the timeout is closed and removed instead. Returns 0 whenever no
// connection is handed out.
int get_from_connection_pool(const char* hostname, int port, int* sock, SSL** ssl, SSL_CTX** ctx) {
    int found = 0;
    connection_pool_entry* hit = NULL;

    pthread_mutex_lock(&conn_pool_mutex);
    HASH_FIND_STR(conn_pool, hostname, hit);
    if (hit != NULL && hit->port == port && hit->in_use == 0) {
        time_t now = time(NULL);
        if (now - hit->last_used <= CONNECTION_IDLE_TIMEOUT) {
            *sock = hit->sock;
            *ssl = hit->ssl;
            *ctx = hit->ctx;
            hit->in_use = 1;
            hit->last_used = now;
            found = 1;
        } else {
            // Stale: tear it down rather than hand out a likely-dead connection.
            if (hit->ssl) SSL_free(hit->ssl);
            if (hit->ctx) SSL_CTX_free(hit->ctx);
            if (hit->sock > 0) close(hit->sock);
            HASH_DEL(conn_pool, hit);
            free(hit);
        }
    }
    pthread_mutex_unlock(&conn_pool_mutex);
    return found;
}
// Return the pooled connection for hostname to the idle state and refresh its
// last-used time. Unknown hostnames are ignored.
void release_connection(const char* hostname) {
    connection_pool_entry* e = NULL;
    pthread_mutex_lock(&conn_pool_mutex);
    HASH_FIND_STR(conn_pool, hostname, e);
    if (e != NULL) {
        e->last_used = time(NULL);
        e->in_use = 0;
    }
    pthread_mutex_unlock(&conn_pool_mutex);
}
// Close and remove pooled upstream connections that have been idle longer
// than CONNECTION_IDLE_TIMEOUT seconds or have no valid socket.
//
// Entries that are checked out (in_use) are skipped. Freeing them would
// invalidate the SSL object and socket a worker is still using. They become
// eligible again after release_connection() marks them idle.
void cleanup_connection_pool(void) {
    pthread_mutex_lock(&conn_pool_mutex);
    connection_pool_entry* entry, *tmp;
    time_t now = time(NULL);
    HASH_ITER(hh, conn_pool, entry, tmp) {
        if (entry->in_use) {
            continue;
        }
        if (now - entry->last_used > CONNECTION_IDLE_TIMEOUT || entry->sock <= 0) {
            if (entry->ssl) SSL_free(entry->ssl);
            if (entry->ctx) SSL_CTX_free(entry->ctx);
            if (entry->sock > 0) close(entry->sock);
            // HASH_ITER keeps tmp, so deleting the current entry is safe.
            HASH_DEL(conn_pool, entry);
            free(entry);
        }
    }
    pthread_mutex_unlock(&conn_pool_mutex);
}
// Certificate serial tracking functions

// Record that a certificate was just generated for hostname: refresh the
// existing record's timestamp, or create a new record. Allocation failure
// is ignored silently.
void add_cert_serial(const char* hostname) {
    cert_serial_entry* rec = NULL;
    pthread_rwlock_wrlock(&cert_rwlock);
    HASH_FIND_STR(cert_serials, hostname, rec);
    if (rec == NULL) {
        rec = malloc(sizeof(*rec));
        if (rec != NULL) {
            strncpy(rec->hostname, hostname, MAX_URL_LENGTH - 1);
            rec->hostname[MAX_URL_LENGTH - 1] = '\0';
            rec->timestamp = time(NULL);
            HASH_ADD_STR(cert_serials, hostname, rec);
        }
    } else {
        rec->timestamp = time(NULL);
    }
    pthread_rwlock_unlock(&cert_rwlock);
}
// Return nonzero when hostname has no generation record, or its certificate
// is older than CERT_RENEWAL_TIME seconds. Return 0 while it is still fresh.
int should_renew_cert(const char* hostname) {
    int renew = 1;
    cert_serial_entry* rec = NULL;
    pthread_rwlock_rdlock(&cert_rwlock);
    HASH_FIND_STR(cert_serials, hostname, rec);
    if (rec != NULL) {
        time_t age = time(NULL) - rec->timestamp;
        renew = age > CERT_RENEWAL_TIME;
    }
    pthread_rwlock_unlock(&cert_rwlock);
    return renew;
}
// Forget certificates older than CERT_RENEWAL_TIME seconds. This deletes their
// certs/<hostname>.crt and .key files (errors ignored) and drops the records.
//
// A file is only unlinked when its full path fit in the buffer. Hostnames can
// be up to MAX_URL_LENGTH-1 characters, and a truncated path could name a
// file that belongs to a different (still fresh) hostname.
void cleanup_cert_serials(void) {
    pthread_rwlock_wrlock(&cert_rwlock);
    cert_serial_entry* entry, *tmp;
    time_t now = time(NULL);
    HASH_ITER(hh, cert_serials, entry, tmp) {
        if (now - entry->timestamp > CERT_RENEWAL_TIME) {
            char cert_path[256];
            char key_path[256];
            int cert_len = snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", entry->hostname);
            int key_len = snprintf(key_path, sizeof(key_path), "certs/%s.key", entry->hostname);
            if (cert_len > 0 && (size_t)cert_len < sizeof(cert_path)) {
                unlink(cert_path); // Ignore errors
            }
            if (key_len > 0 && (size_t)key_len < sizeof(key_path)) {
                unlink(key_path);
            }
            HASH_DEL(cert_serials, entry);
            free(entry);
        }
    }
    pthread_rwlock_unlock(&cert_rwlock);
}
// Drop every certificate generation record. Certificate files on disk are
// left in place.
void free_cert_serials() {
    pthread_rwlock_wrlock(&cert_rwlock);
    while (cert_serials != NULL) {
        cert_serial_entry* head = cert_serials;
        HASH_DEL(cert_serials, head); // advances cert_serials to the next entry
        free(head);
    }
    pthread_rwlock_unlock(&cert_rwlock);
}
// Generate a fresh certificate/key pair for hostname at cert_path/key_path.
//
// Existing files at those paths are removed first, and any copy in
// cert_memory_cache is evicted, so the old certificate is not served again.
// A serial number is derived from next_serial (atomic counter), the current
// time and RAND_bytes() output. It is passed to the generator through the
// CERT_SERIAL_OVERRIDE environment variable. NOTE(review): presumably read by
// generate_trusted_cert / the cert-gen worker — confirm.
// Generation is queued on the cert-gen pool, falling back to a synchronous
// generate_trusted_cert() call if queuing fails. On success, the generation
// is recorded (add_cert_serial) and the files are cached in memory.
//
// Returns the generator's result: nonzero on success, 0 on failure.
int generate_unique_cert(const char* hostname, const char* cert_path, const char* key_path) {
    // Serialize whole generations. The serial travels through the
    // process-wide environment, so two concurrent calls could overwrite or
    // unset each other's CERT_SERIAL_OVERRIDE, giving a wrong or repeated
    // serial. setenv()/unsetenv() are also not safe to call from several
    // threads at once. The lock also stops two threads generating the same
    // hostname from unlinking each other's files mid-write.
    static pthread_mutex_t generation_mutex = PTHREAD_MUTEX_INITIALIZER;
    pthread_mutex_lock(&generation_mutex);

    // Remove existing certificate files to prevent reuse
    unlink(cert_path);
    unlink(key_path);

    // serial = counter ^ (seconds << 32 | nanoseconds) ^ 64 random bits
    struct timespec ts;
    clock_gettime(CLOCK_REALTIME, &ts);
    unsigned long long timestamp = ((unsigned long long)ts.tv_sec << 32) |
                                   ((unsigned long long)ts.tv_nsec & 0xFFFFFFFFULL);

    unsigned char random_bytes[8];
    if (RAND_bytes(random_bytes, (int)sizeof(random_bytes)) != 1) {
        // Fall back to weaker randomness if the OpenSSL generator fails
        for (size_t i = 0; i < sizeof(random_bytes); i++) {
            random_bytes[i] = (unsigned char)(rand() & 0xFF);
        }
    }
    unsigned long long random_component = 0;
    for (size_t i = 0; i < sizeof(random_bytes); i++) {
        random_component = (random_component << 8) | random_bytes[i];
    }

    unsigned long long serial = atomic_fetch_add(&next_serial, 1) ^ timestamp ^ random_component;
    // Clear the top bit (keep it positive) and force a middle bit (non-zero).
    serial = (serial & 0x7FFFFFFFFFFFFFFFULL) | 0x0000010000000000ULL;
    printf("Generating certificate for %s with serial: %llu\n", hostname, serial);

    // Evict this host from the proxy's own in-memory certificate cache so the
    // previous certificate is not served after regeneration.
    pthread_mutex_lock(&cert_cache_mutex);
    cert_cache_entry* cached = NULL;
    HASH_FIND_STR(cert_memory_cache, hostname, cached);
    if (cached) {
        free(cached->cert_data);
        free(cached->key_data);
        HASH_DEL(cert_memory_cache, cached);
        free(cached);
        printf("Removed cached certificate for %s to avoid serial conflicts\n", hostname);
    }
    pthread_mutex_unlock(&cert_cache_mutex);

    // Reseed rand() from the serial (kept from the original behavior).
    // NOTE(review): purpose unclear — possibly relied on by the generator.
    srand((unsigned int)(serial & 0xFFFFFFFF));

    char serial_env[64];
    snprintf(serial_env, sizeof(serial_env), "%llu", serial);
    setenv("CERT_SERIAL_OVERRIDE", serial_env, 1);

    int result;
    int used_sync = 0;
    cert_gen_task* task = queue_cert_generation(hostname, cert_path, key_path);
    if (task) {
        result = wait_for_cert_generation(task);
    } else {
        used_sync = 1;
        printf("Using synchronous certificate generation for %s\n", hostname);
        result = generate_trusted_cert(hostname, cert_path, key_path);
    }

    if (result) {
        add_cert_serial(hostname);
        cache_cert_in_memory(hostname, cert_path, key_path);
        if (used_sync) {
            printf("Certificate successfully generated for %s (sync method)\n", hostname);
        } else {
            printf("Certificate successfully generated for %s\n", hostname);
        }
    } else if (used_sync) {
        fprintf(stderr, "Failed to generate certificate for %s (sync method)\n", hostname);
    } else {
        fprintf(stderr, "Failed to generate certificate for %s\n", hostname);
    }

    unsetenv("CERT_SERIAL_OVERRIDE");
    pthread_mutex_unlock(&generation_mutex);
    return result;
}
// SSL connection cleanup helper.
//
// Free an SSL object and/or an SSL_CTX. Either argument may be NULL.
// A close_notify (SSL_shutdown) is sent only when the handshake completed.
// Several callers use this right after SSL_connect()/SSL_accept() failed, and
// OpenSSL documents that SSL_shutdown() must not be called after a fatal
// error on the connection.
void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx) {
    if (ssl) {
        if (SSL_is_init_finished(ssl)) {
            SSL_shutdown(ssl);
        }
        SSL_free(ssl);
    }
    if (ctx) {
        SSL_CTX_free(ctx);
    }
}
// DNS cache functions (hash table versions)

// Insert or refresh the DNS cache entry for hostname. ip_addr is copied
// (truncated to IP_ADDR_MAX_LEN-1 characters). is_valid records whether the
// resolution succeeded, and the timestamp is reset to now either way.
// Allocation failure for a new entry is ignored silently.
void add_to_dns_cache(const char* hostname, const char* ip_addr, int is_valid) {
    dns_cache_entry* slot = NULL;
    pthread_rwlock_wrlock(&dns_cache_rwlock);
    HASH_FIND_STR(dns_cache, hostname, slot);

    int is_new = (slot == NULL);
    if (is_new) {
        slot = malloc(sizeof(*slot));
        if (slot == NULL) {
            pthread_rwlock_unlock(&dns_cache_rwlock);
            return;
        }
        strncpy(slot->hostname, hostname, MAX_URL_LENGTH - 1);
        slot->hostname[MAX_URL_LENGTH - 1] = '\0';
    }

    strncpy(slot->ip_addr, ip_addr, IP_ADDR_MAX_LEN - 1);
    slot->ip_addr[IP_ADDR_MAX_LEN - 1] = '\0';
    slot->is_valid = is_valid;
    slot->timestamp = time(NULL);

    if (is_new) {
        HASH_ADD_STR(dns_cache, hostname, slot);
    }
    pthread_rwlock_unlock(&dns_cache_rwlock);
}
// Look up hostname in the DNS cache.
//
// Returns:
//    0  fresh, successful resolution; the address is copied into ip_buffer
//       (NUL-terminated, truncated to buffer_size-1 characters)
//    1  fresh entry recording a failed resolution (ip_buffer untouched)
//   -1  no entry, entry older than DNS_CACHE_EXPIRY_TIME, or ip_buffer is
//       NULL / zero-sized so a successful hit cannot be delivered
// Expired entries are left in place; cleanup_dns_cache() removes them.
int check_dns_cache(const char* hostname, char* ip_buffer, size_t buffer_size) {
    int status = -1;
    dns_cache_entry* entry = NULL;

    pthread_rwlock_rdlock(&dns_cache_rwlock);
    HASH_FIND_STR(dns_cache, hostname, entry);
    if (entry != NULL && time(NULL) - entry->timestamp <= DNS_CACHE_EXPIRY_TIME) {
        if (!entry->is_valid) {
            status = 1;
        } else if (ip_buffer != NULL && buffer_size > 0) {
            // The old buffer_size-1 arithmetic wrapped to SIZE_MAX for a
            // zero-sized buffer. Copy only when there is room for the NUL.
            snprintf(ip_buffer, buffer_size, "%s", entry->ip_addr);
            status = 0;
        }
    }
    pthread_rwlock_unlock(&dns_cache_rwlock);
    return status;
}
// Evict every DNS cache entry older than DNS_CACHE_EXPIRY_TIME seconds.
void cleanup_dns_cache() {
    pthread_rwlock_wrlock(&dns_cache_rwlock);
    time_t now = time(NULL);
    dns_cache_entry *cur = NULL, *next = NULL;
    HASH_ITER(hh, dns_cache, cur, next) {
        int expired = (now - cur->timestamp) > DNS_CACHE_EXPIRY_TIME;
        if (!expired) {
            continue;
        }
        HASH_DEL(dns_cache, cur); // safe: HASH_ITER already saved next
        free(cur);
    }
    pthread_rwlock_unlock(&dns_cache_rwlock);
}
// Remove and free every DNS cache entry, regardless of age.
void free_dns_cache() {
    pthread_rwlock_wrlock(&dns_cache_rwlock);
    while (dns_cache != NULL) {
        dns_cache_entry* head = dns_cache;
        HASH_DEL(dns_cache, head); // advances dns_cache to the next entry
        free(head);
    }
    pthread_rwlock_unlock(&dns_cache_rwlock);
}
// Helper function to print SSL errors
void print_ssl_errors(const char* context) {
unsigned long err;
char err_buf[256];
while ((err = ERR_get_error()) != 0) {
ERR_error_string_n(err, err_buf, sizeof(err_buf));
fprintf(stderr, "SSL Error (%s): %s\n", context, err_buf);
}
}
// DANE cache functions
dane_cache_entry* find_cache_entry(const char* hostname) {
dane_cache_entry* entry;
HASH_FIND_STR(dane_cache, hostname, entry);
return entry;
}
// Insert or refresh the DANE verification result for hostname
// (1 = verified, 0 = not verified) and stamp it with the current time.
// Allocation failure for a new entry is ignored silently.
void add_to_cache(const char* hostname, int verified) {
    dane_cache_entry* slot = NULL;
    pthread_rwlock_wrlock(&cache_rwlock);
    HASH_FIND_STR(dane_cache, hostname, slot);
    if (slot != NULL) {
        slot->verified = verified;
        slot->timestamp = time(NULL);
    } else if ((slot = malloc(sizeof(*slot))) != NULL) {
        strncpy(slot->hostname, hostname, MAX_URL_LENGTH - 1);
        slot->hostname[MAX_URL_LENGTH - 1] = '\0';
        slot->verified = verified;
        slot->timestamp = time(NULL);
        HASH_ADD_STR(dane_cache, hostname, slot);
    }
    pthread_rwlock_unlock(&cache_rwlock);
}
// Return the cached DANE verdict for hostname (the stored `verified` value),
// or -1 when there is no entry or it is older than CACHE_EXPIRY_TIME seconds.
int check_cache(const char* hostname) {
    int verdict = -1;
    dane_cache_entry* hit = NULL;
    pthread_rwlock_rdlock(&cache_rwlock);
    HASH_FIND_STR(dane_cache, hostname, hit);
    if (hit != NULL && (time(NULL) - hit->timestamp) <= CACHE_EXPIRY_TIME) {
        verdict = hit->verified;
    }
    pthread_rwlock_unlock(&cache_rwlock);
    return verdict;
}
// Drop DANE cache entries older than CACHE_EXPIRY_TIME seconds.
void cleanup_cache() {
    pthread_rwlock_wrlock(&cache_rwlock);
    time_t now = time(NULL);
    dane_cache_entry *cur = NULL, *next = NULL;
    HASH_ITER(hh, dane_cache, cur, next) {
        if ((now - cur->timestamp) <= CACHE_EXPIRY_TIME) {
            continue; // still fresh
        }
        HASH_DEL(dane_cache, cur);
        free(cur);
    }
    pthread_rwlock_unlock(&cache_rwlock);
}
// Remove and free every DANE cache entry, regardless of age.
void free_cache() {
    pthread_rwlock_wrlock(&cache_rwlock);
    while (dane_cache != NULL) {
        dane_cache_entry* head = dane_cache;
        HASH_DEL(dane_cache, head); // advances dane_cache to the next entry
        free(head);
    }
    pthread_rwlock_unlock(&cache_rwlock);
}
// Return a pointer to the value of the first "Host" header field in request
// (leading spaces/tabs skipped), or NULL if there is none.
// Only header lines are searched: the request line is skipped and the scan
// stops at the blank line that ends the headers, so text in a body or inside
// another header such as "X-Forwarded-Host:" is never matched. The field name
// is compared case-insensitively, as HTTP requires.
static const char* extract_host_find_host_value(const char* request) {
    static const char name[] = "host:";
    const char* line = strchr(request, '\n');
    while (line != NULL) {
        line++; // first character of the next line
        if (*line == '\0' || *line == '\r' || *line == '\n') {
            return NULL; // end of headers
        }
        size_t i = 0;
        while (name[i] != '\0' && tolower((unsigned char)line[i]) == name[i]) {
            i++;
        }
        if (name[i] == '\0') {
            const char* value = line + i;
            while (*value == ' ' || *value == '\t') {
                value++;
            }
            return value;
        }
        line = strchr(line, '\n');
    }
    return NULL;
}

// Extract the target hostname (without port) from an HTTP request.
// For "CONNECT host:port ..." the host comes from the request line; otherwise
// it comes from the Host header. The result is truncated to
// MAX_URL_LENGTH-1 characters. Returns NULL when request is NULL or a
// non-CONNECT request has no Host header.
//
// The returned buffer is per-thread and is overwritten by the next call on
// the same thread. Worker threads call this concurrently, so a single shared
// static buffer would let one thread clobber another's hostname mid-use.
char* extract_host(const char* request) {
    static _Thread_local char host[MAX_URL_LENGTH];
    const char* start;
    const char* stop_chars;

    if (request == NULL) {
        return NULL;
    }
    if (strncmp(request, "CONNECT ", 8) == 0) {
        start = request + 8;
        stop_chars = " :\r\n";
    } else {
        start = extract_host_find_host_value(request);
        if (start == NULL) {
            return NULL;
        }
        stop_chars = ":\r\n \t";
    }

    size_t len = strcspn(start, stop_chars);
    if (len > sizeof(host) - 1) {
        len = sizeof(host) - 1;
    }
    memcpy(host, start, len);
    host[len] = '\0';
    return host;
}
// Return a pointer to the value of the first "Host" header field in request
// (leading spaces/tabs skipped), or NULL if there is none. Only header lines
// are searched (request line skipped, scan stops at the blank line), and the
// field name is compared case-insensitively.
static const char* extract_port_find_host_value(const char* request) {
    static const char name[] = "host:";
    const char* line = strchr(request, '\n');
    while (line != NULL) {
        line++; // first character of the next line
        if (*line == '\0' || *line == '\r' || *line == '\n') {
            return NULL; // end of headers
        }
        size_t i = 0;
        while (name[i] != '\0' && tolower((unsigned char)line[i]) == name[i]) {
            i++;
        }
        if (name[i] == '\0') {
            const char* value = line + i;
            while (*value == ' ' || *value == '\t') {
                value++;
            }
            return value;
        }
        line = strchr(line, '\n');
    }
    return NULL;
}

// Parse the decimal digits in [p, end) as a TCP port. Returns 1..65535, or 0
// if the range is empty, contains a non-digit, or is out of range.
static int extract_port_parse_number(const char* p, const char* end) {
    if (p >= end) {
        return 0;
    }
    long value = 0;
    for (; p < end; p++) {
        if (!isdigit((unsigned char)*p)) {
            return 0;
        }
        value = value * 10 + (*p - '0');
        if (value > 65535) {
            return 0;
        }
    }
    return (int)value;
}

// Extract the destination port from an HTTP request.
//
// For "CONNECT host[:port] ..." the port is taken from the request target
// (default 443). Otherwise it is taken from the Host header value (default 80).
// Only the authority token itself is examined: the old code used strchr(':')
// over the rest of the request, so a later header such as "X-Foo:123" could
// be read as the port. The last ':' in the token is used, so a bracketed IPv6
// literal like "[::1]:8443" works. A missing, non-numeric, zero or
// out-of-range port yields the default.
int extract_port(const char* request) {
    int default_port = 80;
    const char* authority;

    if (strncmp(request, "CONNECT ", 8) == 0) {
        default_port = 443;
        authority = request + 8;
    } else {
        authority = extract_port_find_host_value(request);
        if (authority == NULL) {
            return default_port;
        }
    }

    const char* end = authority + strcspn(authority, " \t\r\n");
    const char* colon = NULL;
    for (const char* p = authority; p < end; p++) {
        if (*p == ':') {
            colon = p;
        }
    }
    if (colon == NULL) {
        return default_port;
    }

    int port = extract_port_parse_number(colon + 1, end);
    return port > 0 ? port : default_port;
}
// Initialize a server-side SSL context for intercepting HTTPS, loading the
// PEM certificate at cert_path and the matching private key at key_path.
//
// The returned object is zero-initialized, so its `ssl` member is NULL until
// the caller creates one. On failure, a message is printed, the OpenSSL error
// queue is drained to stderr (so stale errors cannot confuse a later
// SSL_get_error() on this thread), and NULL is returned.
ssl_context_t* init_ssl_context(const char* cert_path, const char* key_path) {
    ssl_context_t* ssl_ctx = calloc(1, sizeof(*ssl_ctx));
    if (!ssl_ctx) {
        fprintf(stderr, "Failed to allocate memory for SSL context\n");
        return NULL;
    }

    ssl_ctx->ctx = SSL_CTX_new(TLS_server_method());
    if (!ssl_ctx->ctx) {
        fprintf(stderr, "Failed to create SSL context\n");
        print_ssl_errors("SSL_CTX_new");
        free(ssl_ctx);
        return NULL;
    }

    SSL_CTX_set_options(ssl_ctx->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION);

    if (SSL_CTX_use_certificate_file(ssl_ctx->ctx, cert_path, SSL_FILETYPE_PEM) <= 0) {
        fprintf(stderr, "Failed to load certificate: %s\n", cert_path);
        print_ssl_errors("SSL_CTX_use_certificate_file");
        goto fail;
    }
    if (SSL_CTX_use_PrivateKey_file(ssl_ctx->ctx, key_path, SSL_FILETYPE_PEM) <= 0) {
        fprintf(stderr, "Failed to load private key: %s\n", key_path);
        print_ssl_errors("SSL_CTX_use_PrivateKey_file");
        goto fail;
    }
    if (!SSL_CTX_check_private_key(ssl_ctx->ctx)) {
        fprintf(stderr, "Private key does not match the certificate\n");
        print_ssl_errors("SSL_CTX_check_private_key");
        goto fail;
    }
    return ssl_ctx;

fail:
    SSL_CTX_free(ssl_ctx->ctx);
    free(ssl_ctx);
    return NULL;
}
// Returns 1 when hostname can be embedded in "certs/<hostname>.crt" / ".key"
// without escaping certs/: non-empty, not starting with '.', and free of
// '/', '\\' and "..". The hostname is presumably taken from the client's
// CONNECT request, so it must not be able to steer generate_unique_cert()'s
// unlink()/write calls elsewhere.
static int https_tunnel_name_is_path_safe(const char* hostname) {
    if (hostname == NULL || hostname[0] == '\0' || hostname[0] == '.') {
        return 0;
    }
    if (strchr(hostname, '/') != NULL || strchr(hostname, '\\') != NULL) {
        return 0;
    }
    return strstr(hostname, "..") == NULL;
}

// Open a new TCP connection to the literal address ip_addr (IPv4 or IPv6)
// on port. Returns the connected socket, or -1 with errno set. Unlike
// inet_addr(), an unparsable address is rejected instead of silently
// becoming 255.255.255.255.
static int https_tunnel_reconnect(const char* ip_addr, int port) {
    struct sockaddr_storage addr;
    socklen_t addr_len;
    struct in_addr a4;
    struct in6_addr a6;

    memset(&addr, 0, sizeof(addr));
    if (ip_addr != NULL && inet_pton(AF_INET, ip_addr, &a4) == 1) {
        struct sockaddr_in* sin = (struct sockaddr_in*)&addr;
        sin->sin_family = AF_INET;
        sin->sin_port = htons((in_port_t)port);
        sin->sin_addr = a4;
        addr_len = sizeof(*sin);
    } else if (ip_addr != NULL && inet_pton(AF_INET6, ip_addr, &a6) == 1) {
        struct sockaddr_in6* sin6 = (struct sockaddr_in6*)&addr;
        sin6->sin6_family = AF_INET6;
        sin6->sin6_port = htons((in_port_t)port);
        sin6->sin6_addr = a6;
        addr_len = sizeof(*sin6);
    } else {
        errno = EINVAL;
        return -1;
    }

    int fd = socket(addr.ss_family, SOCK_STREAM, 0);
    if (fd < 0) {
        return -1;
    }
    if (connect(fd, (struct sockaddr*)&addr, addr_len) < 0) {
        int saved_errno = errno;
        close(fd);
        errno = saved_errno;
        return -1;
    }
    return fd;
}

// Intercept a connection whose DANE verification is cached as successful.
// Steps: regenerate the forged certificate, open TLS to the server (resuming
// a cached session if available), answer the CONNECT, terminate the client's
// TLS with the forged certificate, and relay data. Falls back to a plain
// tunnel when setup fails before the client has been answered.
static void https_tunnel_intercept_cached(int client_sock, int server_sock, const char* hostname,
                                          const char* cert_path, const char* key_path) {
    printf("Using cached DANE verification for %s (verified)\n", hostname);
    // NOTE(review): the certificate presented by the server on this connection
    // is not re-checked against DANE; the cached verdict is trusted until it
    // expires (CACHE_EXPIRY_TIME).

    // Always regenerate certificate to avoid browser cache issues
    printf("Regenerating certificate for %s to avoid browser certificate cache issues\n", hostname);
    // Generate trusted certificate BEFORE responding to the client
    if (!generate_unique_cert(hostname, cert_path, key_path)) {
        fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path);
    if (!client_ctx) {
        fprintf(stderr, "Failed to initialize SSL context\n");
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    SSL_CTX* server_ctx = SSL_CTX_new(TLS_client_method());
    if (!server_ctx) {
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }
    SSL* server_ssl = SSL_new(server_ctx);
    if (!server_ssl) {
        fprintf(stderr, "Failed to create server SSL object\n");
        SSL_CTX_free(server_ctx);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }
    SSL_set_fd(server_ssl, server_sock);
    SSL_set_tlsext_host_name(server_ssl, hostname);

    // Resume a previously cached session when one exists
    SSL_SESSION* session = get_ssl_session(hostname);
    if (session) {
        SSL_set_session(server_ssl, session);
        printf("Reusing SSL session for %s\n", hostname);
    }

    if (SSL_connect(server_ssl) <= 0) {
        print_ssl_errors("SSL_connect");
        cleanup_ssl_connection(server_ssl, server_ctx);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        // NOTE(review): server_sock has already carried a partial TLS
        // handshake, so a plain relay over it is unlikely to succeed.
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }
    // Cache the new session for future connections
    SSL_SESSION* new_session = SSL_get1_session(server_ssl);
    if (new_session) {
        cache_ssl_session(hostname, new_session);
    }

    client_ctx->ssl = SSL_new(client_ctx->ctx);
    if (!client_ctx->ssl) {
        fprintf(stderr, "Failed to create SSL object\n");
        cleanup_ssl_connection(server_ssl, server_ctx);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }
    SSL_set_fd(client_ctx->ssl, client_sock);

    const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
    if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
        perror("Failed to send connection established response");
        cleanup_ssl_connection(client_ctx->ssl, NULL);
        cleanup_ssl_connection(server_ssl, server_ctx);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        return;
    }

    // Clear stale state so the diagnostics below describe SSL_accept only.
    ERR_clear_error();
    errno = 0;
    int accept_result = SSL_accept(client_ctx->ssl);
    int accept_errno = errno; // saved before fprintf() can overwrite it
    if (accept_result <= 0) {
        int ssl_err = SSL_get_error(client_ctx->ssl, accept_result);
        fprintf(stderr, "SSL accept failed with error code: %d\n", ssl_err);
        print_ssl_errors("SSL_accept");
        if (ssl_err == SSL_ERROR_SYSCALL && accept_errno == 0) {
            fprintf(stderr, "Client may have closed the connection\n");
        }
        cleanup_ssl_connection(client_ctx->ssl, NULL);
        cleanup_ssl_connection(server_ssl, server_ctx);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        return;
    }

    printf("SSL connection established with client for %s (using cached verification)\n", hostname);
    ssl_tunnel_data(client_ctx->ssl, server_ssl);

    cleanup_ssl_connection(client_ctx->ssl, NULL);
    cleanup_ssl_connection(server_ssl, server_ctx);
    SSL_CTX_free(client_ctx->ctx);
    free(client_ctx);
}

// Full DANE path for a host with DANE records:
//  1. Open TLS to the server and check its certificate with
//     verify_cert_against_dane(), caching the verdict.
//  2. On failure, reconnect over a fresh socket and relay bytes unmodified.
//  3. On success, forge a certificate, terminate the client's TLS and relay
//     data. The upstream connection is then handed to the connection pool.
static void https_tunnel_verify_and_intercept(int client_sock, int server_sock, const char* hostname,
                                              const char* ip_addr, const char* cert_path,
                                              const char* key_path) {
    printf("DANE records found for %s, will verify certificate\n", hostname);

    // First, establish a connection to the server to get its certificate
    SSL_CTX* server_ctx = SSL_CTX_new(TLS_client_method());
    if (!server_ctx) {
        fprintf(stderr, "Failed to create server SSL context\n");
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }
    SSL* server_ssl = SSL_new(server_ctx);
    if (!server_ssl) {
        fprintf(stderr, "Failed to create server SSL object\n");
        SSL_CTX_free(server_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }
    SSL_set_fd(server_ssl, server_sock);
    SSL_set_tlsext_host_name(server_ssl, hostname); // SNI

    if (SSL_connect(server_ssl) <= 0) {
        fprintf(stderr, "SSL connection to server failed\n");
        cleanup_ssl_connection(server_ssl, server_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    X509* server_cert = SSL_get_peer_certificate(server_ssl);
    if (!server_cert) {
        fprintf(stderr, "Failed to get server certificate\n");
        cleanup_ssl_connection(server_ssl, server_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    int dane_verified = verify_cert_against_dane(hostname, server_cert);
    add_to_cache(hostname, dane_verified > 0 ? 1 : 0);

    if (dane_verified <= 0) {
        fprintf(stderr, "DANE verification failed for %s - using direct tunneling\n", hostname);
        X509_free(server_cert);
        cleanup_ssl_connection(server_ssl, server_ctx);
        // server_sock has already carried our own TLS handshake, so the raw
        // relay needs a fresh connection.
        // NOTE(review): confirm the caller does not also close server_sock.
        close(server_sock);
        // NOTE(review): the CONNECT port is not passed to this function, so
        // 443 is assumed, as before.
        int new_server_sock = https_tunnel_reconnect(ip_addr, 443);
        if (new_server_sock < 0) {
            perror("Cannot reconnect to server");
            close(client_sock);
            return;
        }
        handle_regular_https_tunnel(client_sock, new_server_sock);
        return;
    }

    printf("DANE verification successful for %s - generating trusted certificate\n", hostname);
    // Generate the certificate BEFORE answering the client so it is ready
    // when the client starts its TLS handshake.
    if (!generate_unique_cert(hostname, cert_path, key_path)) {
        fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname);
        X509_free(server_cert);
        cleanup_ssl_connection(server_ssl, server_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path);
    if (!client_ctx) {
        fprintf(stderr, "Failed to initialize SSL context\n");
        X509_free(server_cert);
        cleanup_ssl_connection(server_ssl, server_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    client_ctx->ssl = SSL_new(client_ctx->ctx);
    if (!client_ctx->ssl) {
        fprintf(stderr, "Failed to create SSL object\n");
        X509_free(server_cert);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        cleanup_ssl_connection(server_ssl, server_ctx);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    // Only after everything is set up, send the Connection Established response
    const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
    if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
        perror("Failed to send connection established response");
        X509_free(server_cert);
        cleanup_ssl_connection(server_ssl, server_ctx);
        SSL_free(client_ctx->ssl);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        return;
    }

    SSL_set_fd(client_ctx->ssl, client_sock);
    ERR_clear_error(); // so print_ssl_errors() below reports SSL_accept only
    int accept_result = SSL_accept(client_ctx->ssl);
    if (accept_result <= 0) {
        int ssl_err = SSL_get_error(client_ctx->ssl, accept_result);
        fprintf(stderr, "SSL_accept failed for %s with error code: %d\n", hostname, ssl_err);
        print_ssl_errors("SSL_accept");
        X509_free(server_cert);
        cleanup_ssl_connection(client_ctx->ssl, NULL);
        cleanup_ssl_connection(server_ssl, server_ctx);
        SSL_CTX_free(client_ctx->ctx);
        free(client_ctx);
        return;
    }

    printf("SSL connection successfully established with client for %s\n", hostname);
    ssl_tunnel_data(client_ctx->ssl, server_ssl);

    // The pool takes ownership of server_sock/server_ssl/server_ctx.
    add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx);

    X509_free(server_cert);
    cleanup_ssl_connection(client_ctx->ssl, NULL);
    SSL_CTX_free(client_ctx->ctx);
    free(client_ctx);
}

// Handle an HTTPS CONNECT tunnel for hostname.
//
// The connection is intercepted with a locally generated certificate only
// when the server's certificate is (or recently was) verified against its
// DANE records. Otherwise the bytes are relayed unmodified through
// handle_regular_https_tunnel(). The DANE verdict is cached per hostname
// (check_cache / add_to_cache).
//
// client_sock - the client's socket; the intercepting paths send the
//               "200 Connection Established" reply themselves
// server_sock - TCP connection to the origin server
// hostname    - CONNECT target; used for SNI, DANE lookups and the
//               certs/<hostname>.crt / .key paths
// ip_addr     - server address literal, used to reconnect after a failed
//               DANE verification
void handle_https_tunnel(int client_sock, int server_sock, const char* hostname, const char* ip_addr) {
    printf("Starting HTTPS tunnel for %s\n", hostname);

    // Forged certificate paths. Never intercept a hostname that is unsafe as
    // a file name or too long to fit, because a truncated path could be
    // shared with a different host.
    char cert_path[256];
    char key_path[256];
    if (!https_tunnel_name_is_path_safe(hostname)) {
        fprintf(stderr, "Not intercepting %s: hostname is not usable as a certificate file name - using direct tunneling\n", hostname);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }
    int cert_len = snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname);
    int key_len = snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname);
    if (cert_len < 0 || (size_t)cert_len >= sizeof(cert_path) ||
        key_len < 0 || (size_t)key_len >= sizeof(key_path)) {
        fprintf(stderr, "Not intercepting %s: certificate path too long - using direct tunneling\n", hostname);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    int cached_result = check_cache(hostname);
    if (cached_result == 1) {
        https_tunnel_intercept_cached(client_sock, server_sock, hostname, cert_path, key_path);
        return;
    }
    if (cached_result == 0) {
        printf("Using cached DANE verification for %s (not verified)\n", hostname);
        handle_regular_https_tunnel(client_sock, server_sock);
        return;
    }

    // Not in cache or expired: perform full verification if DANE records exist.
    if (is_dane_available(hostname)) {
        https_tunnel_verify_and_intercept(client_sock, server_sock, hostname, ip_addr,
                                          cert_path, key_path);
    } else {
        add_to_cache(hostname, 0);
        handle_regular_https_tunnel(client_sock, server_sock);
    }
}
/*
 * Run the proxy server on `port` (IPv4, all interfaces).
 *
 * Initialises the proxy subsystems and the worker thread pool, starts the
 * periodic cleanup thread, then accepts connections forever and queues each
 * one for a worker thread.
 *
 * Returns 1 if initialisation, socket(), setsockopt(), bind() or listen()
 * fails. Otherwise the accept loop only ends via SIGINT (handle_signal), so
 * the trailing close/return is not normally reached.
 */
int start_proxy_server(int port) {
    // Initialize proxy components (DANE, DNS and certificate pools)
    if (proxy_init() != 0) {
        fprintf(stderr, "Failed to initialize proxy server\n");
        return 1;
    }

    // Initialize the thread pool that services accepted clients
    init_thread_pool();

    // Writing to a socket whose peer has already closed raises SIGPIPE, and
    // the default action terminates the whole process. Ignore it so the
    // offending send()/SSL_write() fails with EPIPE and only that connection
    // is torn down.
    signal(SIGPIPE, SIG_IGN);

    // Create periodic cleanup thread (not critical if it fails)
    pthread_t cleanup_thread;
    int rc = pthread_create(&cleanup_thread, NULL, periodic_cleanup, NULL);
    if (rc != 0) {
        // pthread_create reports its error via the return value, not errno
        fprintf(stderr, "Failed to create cleanup thread: %s\n", strerror(rc));
    } else {
        pthread_detach(cleanup_thread);
    }

    int server_sock = socket(AF_INET, SOCK_STREAM, 0);
    if (server_sock < 0) {
        perror("Cannot create socket");
        return 1;
    }

    // Allow quick restarts without waiting for old sockets in TIME_WAIT
    int opt = 1;
    if (setsockopt(server_sock, SOL_SOCKET, SO_REUSEADDR, &opt, sizeof(opt)) < 0) {
        perror("setsockopt failed");
        close(server_sock);
        return 1;
    }

    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_addr.s_addr = INADDR_ANY;
    server_addr.sin_port = htons(port);

    if (bind(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
        perror("Bind failed");
        close(server_sock);
        return 1;
    }

    if (listen(server_sock, 10) < 0) {
        perror("Listen failed");
        close(server_sock);
        return 1;
    }

    // Set up signal handler for graceful termination
    signal(SIGINT, handle_signal);

    printf("Proxy server listening on port %d\n", port);
    printf("Press Ctrl+C to stop the server\n");

    for (;;) {
        struct sockaddr_in client_addr;
        // accept() overwrites the length with the size of the address it
        // stored, so reset it before every call.
        socklen_t client_len = sizeof(client_addr);
        int client_sock = accept(server_sock, (struct sockaddr*)&client_addr, &client_len);
        if (client_sock < 0) {
            perror("Accept failed");
            continue;
        }

        printf("New connection from %s:%d\n",
               inet_ntoa(client_addr.sin_addr), ntohs(client_addr.sin_port));

        // Hand the connection to a worker thread
        queue_client_connection(client_sock);
    }

    close(server_sock);
    return 0;
}
/*
 * One-time initialisation of proxy-wide state.
 *
 * Creates the certs/ directory, loads OpenSSL error strings, initialises DANE
 * support, starts the DNS and certificate-generation worker pools, and seeds
 * the shared certificate serial counter (next_serial) from RAND_bytes(). If
 * RAND_bytes() fails, the seed falls back to the wall clock.
 *
 * @return 0 on success, 1 if dane_init() fails.
 */
int proxy_init() {
    // certs/ receives the generated certificates and keys. EEXIST is the
    // normal case. Any other failure is reported rather than silently
    // ignored, because every later certificate write would fail as well.
    if (mkdir("certs", 0755) != 0 && errno != EEXIST) {
        perror("Warning: could not create certs directory");
    }

    // Initialize OpenSSL error strings
    SSL_load_error_strings();
    ERR_load_crypto_strings();

    if (!dane_init()) {
        fprintf(stderr, "Failed to initialize DANE support\n");
        return 1;
    }

    init_dns_pool();
    // init_cert_gen_pool() also seeds next_serial (from time(NULL)). That
    // value is overwritten below with a better seed.
    init_cert_gen_pool();

    // Seed certificate serials from the CSPRNG, falling back to the clock.
    unsigned char random_bytes[8];
    unsigned long long seed = 0;
    if (RAND_bytes(random_bytes, sizeof(random_bytes)) == 1) {
        for (size_t i = 0; i < sizeof(random_bytes); i++) {
            seed = (seed << 8) | random_bytes[i];
        }
    } else {
        struct timespec ts;
        clock_gettime(CLOCK_REALTIME, &ts);
        seed = ((unsigned long long)ts.tv_sec << 32) | (unsigned long long)ts.tv_nsec;
    }
    atomic_init(&next_serial, seed);

    // Certificates are generated on demand for DANE-verified hosts only;
    // pregen_common_certificates() is intentionally not called here.
    return 0;
}
/*
 * Release process-wide proxy resources (called from handle_signal on SIGINT).
 *
 * The certificate-generation pool is stopped and its workers joined first.
 * That way no in-flight generate_trusted_cert() job is still running while
 * the caches, serial bookkeeping and DANE state are freed underneath it.
 */
void proxy_cleanup() {
    // Stop background certificate generation before tearing down shared state
    cleanup_cert_gen_pool();

    // Clean up caches
    free_cache();
    free_dns_cache();
    free_cert_serials();

    // Clean up DANE resources last
    dane_cleanup();
}
/*
 * SIGINT handler: announce shutdown, release proxy resources and exit(0).
 * Any other signal number is ignored.
 *
 * NOTE(review): printf(), exit() and the mutex locking / pthread_join() done
 * inside proxy_cleanup() are not async-signal-safe. A safer design would set
 * a flag here and let the accept loop perform the shutdown.
 */
void handle_signal(int sig) {
    if (sig != SIGINT) {
        return;
    }

    printf("\nShutting down proxy server...\n");
    printf("Cleaning up cache and temporary certificates...\n");

    proxy_cleanup();
    exit(0);
}
/*
 * Background housekeeping thread body.
 *
 * Every 300 seconds (5 minutes) it runs the cleanup routines for:
 *   - the DANE result cache
 *   - the DNS cache
 *   - the certificate-serial table
 *   - the connection pool
 *   - the TLS session cache
 * It then logs a line. The loop never terminates.
 *
 * @param arg  unused (pthread entry-point signature)
 */
void* periodic_cleanup(void* arg) {
    (void)arg;
    const unsigned int interval_seconds = 300;

    for (;;) {
        sleep(interval_seconds);

        cleanup_cache();
        cleanup_dns_cache();
        cleanup_cert_serials();
        cleanup_connection_pool();
        cleanup_ssl_sessions();

        printf("Performed periodic cache cleanup\n");
    }

    return NULL;
}
/*
 * Initialise the DoH resolution queue and start MAX_PARALLEL_DNS_QUERIES
 * dns_worker threads.
 *
 * The queue holds up to MAX_PARALLEL_DNS_QUERIES * 2 pending tasks. If that
 * array cannot be allocated, the queue is left stopped (running = 0,
 * capacity 0) rather than pairing a NULL array with a non-zero capacity.
 * In that case:
 *   - queue_dns_resolution() returns NULL instead of writing through NULL;
 *   - the workers exit immediately;
 *   - the workers are still created, so cleanup_dns_pool() has threads to join.
 */
void init_dns_pool(void) {
    const int capacity = MAX_PARALLEL_DNS_QUERIES * 2;

    pthread_mutex_init(&doh_queue.mutex, NULL);
    pthread_cond_init(&doh_queue.not_empty, NULL);
    pthread_cond_init(&doh_queue.not_full, NULL);

    doh_queue.task_count = 0;
    doh_queue.tasks = malloc(sizeof(*doh_queue.tasks) * (size_t)capacity);
    if (doh_queue.tasks) {
        doh_queue.task_capacity = capacity;
        doh_queue.running = 1;
    } else {
        fprintf(stderr, "Failed to allocate DNS task queue; DoH worker pool disabled\n");
        doh_queue.task_capacity = 0;
        doh_queue.running = 0;
    }

    for (int i = 0; i < MAX_PARALLEL_DNS_QUERIES; i++) {
        int rc = pthread_create(&dns_threads[i], NULL, dns_worker, NULL);
        if (rc != 0) {
            // pthread_create returns the error code; errno is not set.
            // Continue anyway. NOTE(review): cleanup_dns_pool() still joins
            // this slot even though no thread was started in it.
            fprintf(stderr, "Failed to create DNS resolution thread: %s\n", strerror(rc));
        }
    }
}
/*
 * Worker loop of the DoH resolution pool.
 *
 * Each iteration:
 *   1. Blocks until a task is queued or the pool is stopped.
 *   2. Pops the oldest task.
 *   3. Resolves it with resolve_doh() outside the queue lock.
 *   4. Publishes the result and wakes the thread blocked in
 *      wait_for_dns_resolution().
 *
 * doh_queue.running is only read while doh_queue.mutex is held, because
 * cleanup_dns_pool() writes it under that lock. An unlocked loop condition
 * on it would be a data race.
 *
 * @param arg  unused (pthread entry-point signature)
 */
void* dns_worker(void* arg) {
    (void)arg;

    for (;;) {
        pthread_mutex_lock(&doh_queue.mutex);
        while (doh_queue.task_count == 0 && doh_queue.running) {
            pthread_cond_wait(&doh_queue.not_empty, &doh_queue.mutex);
        }
        if (!doh_queue.running) {
            pthread_mutex_unlock(&doh_queue.mutex);
            break;
        }

        // Dequeue the oldest task (FIFO)
        dns_task* task = doh_queue.tasks[0];
        doh_queue.task_count--;
        if (doh_queue.task_count > 0) {
            memmove(&doh_queue.tasks[0], &doh_queue.tasks[1],
                    sizeof(*doh_queue.tasks) * doh_queue.task_count);
        }
        pthread_cond_signal(&doh_queue.not_full);
        pthread_mutex_unlock(&doh_queue.mutex);

        // Resolve without holding the queue lock
        int result = resolve_doh(task->hostname, task->ip_addr, IP_ADDR_MAX_LEN);

        // Publish completion to the waiter
        pthread_mutex_lock(&task->mutex);
        task->result = result;
        task->completed = 1;
        pthread_cond_signal(&task->cond);
        pthread_mutex_unlock(&task->mutex);
    }

    return NULL;
}
/*
 * Queue an asynchronous DoH lookup for `hostname`.
 *
 * Blocks while the queue is full. Returns NULL in three cases:
 *   - hostname is NULL;
 *   - allocation fails;
 *   - the pool is (or becomes) stopped.
 * On success the returned task must be passed to wait_for_dns_resolution(),
 * which waits for the result and frees the task.
 */
dns_task* queue_dns_resolution(const char* hostname) {
    if (!hostname) {
        return NULL;
    }

    dns_task* task = malloc(sizeof *task);
    if (!task) {
        return NULL;
    }

    // snprintf always NUL-terminates (truncating over-long names)
    snprintf(task->hostname, sizeof(task->hostname), "%s", hostname);
    task->ip_addr[0] = '\0';
    task->result = 0;
    task->completed = 0;
    pthread_mutex_init(&task->mutex, NULL);
    pthread_cond_init(&task->cond, NULL);

    pthread_mutex_lock(&doh_queue.mutex);
    while (doh_queue.task_count >= doh_queue.task_capacity && doh_queue.running) {
        pthread_cond_wait(&doh_queue.not_full, &doh_queue.mutex);
    }
    if (!doh_queue.running) {
        pthread_mutex_unlock(&doh_queue.mutex);
        // Release everything initialised above, not just the allocation
        pthread_cond_destroy(&task->cond);
        pthread_mutex_destroy(&task->mutex);
        free(task);
        return NULL;
    }

    doh_queue.tasks[doh_queue.task_count++] = task;
    pthread_cond_signal(&doh_queue.not_empty);
    pthread_mutex_unlock(&doh_queue.mutex);

    return task;
}
/*
 * Wait for a task from queue_dns_resolution() to finish, then free it.
 *
 * @param task         task to wait on (ownership passes here); NULL allowed
 * @param ip_buffer    receives the resolved address on success; may be NULL
 * @param buffer_size  size of ip_buffer in bytes; 0 means nothing is copied
 * @return resolve_doh()'s result (0 on success), or -1 for a NULL task
 */
int wait_for_dns_resolution(dns_task* task, char* ip_buffer, size_t buffer_size) {
    if (!task) {
        return -1;
    }

    pthread_mutex_lock(&task->mutex);
    while (!task->completed) {
        pthread_cond_wait(&task->cond, &task->mutex);
    }
    int result = task->result;
    pthread_mutex_unlock(&task->mutex);

    // A zero buffer_size must not reach "buffer_size - 1", which would wrap
    // to SIZE_MAX and overrun ip_buffer.
    if (result == 0 && ip_buffer && buffer_size > 0) {
        snprintf(ip_buffer, buffer_size, "%s", task->ip_addr);
    }

    pthread_mutex_destroy(&task->mutex);
    pthread_cond_destroy(&task->cond);
    free(task);

    return result;
}
// Cleanup DNS resolution pool
void cleanup_dns_pool() {
// Signal shutdown
pthread_mutex_lock(&doh_queue.mutex);
doh_queue.running = 0;
pthread_cond_broadcast(&doh_queue.not_empty);
pthread_cond_broadcast(&doh_queue.not_full);
pthread_mutex_unlock(&doh_queue.mutex);
// Wait for threads to exit
for (int i = 0; i < MAX_PARALLEL_DNS_QUERIES; i++) {
pthread_join(dns_threads[i], NULL);
}
// Free any remaining tasks
pthread_mutex_lock(&doh_queue.mutex);
for (int i = 0; i < doh_queue.task_count; i++) {
free(doh_queue.tasks[i]);
}
free(doh_queue.tasks);
pthread_mutex_unlock(&doh_queue.mutex);
// Destroy sync primitives
pthread_mutex_destroy(&doh_queue.mutex);
pthread_cond_destroy(&doh_queue.not_empty);
pthread_cond_destroy(&doh_queue.not_full);
}
/*
 * Initialise the certificate-generation queue (capacity THREAD_POOL_SIZE)
 * and start CERT_GEN_THREAD_POOL_SIZE cert_gen_worker threads.
 *
 * next_serial is seeded (from time(NULL)) before any worker exists, so no
 * worker can observe it uninitialised.
 *
 * If the queue array cannot be allocated, the queue is left stopped
 * (running = 0, capacity 0):
 *   - queue_cert_generation() returns NULL;
 *   - the workers exit at once;
 *   - the workers are still created, so cleanup_cert_gen_pool() can join them.
 */
void init_cert_gen_pool(void) {
    const int capacity = THREAD_POOL_SIZE;

    atomic_init(&next_serial, ((unsigned long long)time(NULL) << 32));

    pthread_mutex_init(&cert_queue.mutex, NULL);
    pthread_cond_init(&cert_queue.not_empty, NULL);
    pthread_cond_init(&cert_queue.not_full, NULL);

    cert_queue.task_count = 0;
    cert_queue.tasks = malloc(sizeof(*cert_queue.tasks) * (size_t)capacity);
    if (cert_queue.tasks) {
        cert_queue.task_capacity = capacity;
        cert_queue.running = 1;
    } else {
        fprintf(stderr, "Failed to allocate certificate task queue; generation pool disabled\n");
        cert_queue.task_capacity = 0;
        cert_queue.running = 0;
    }

    for (int i = 0; i < CERT_GEN_THREAD_POOL_SIZE; i++) {
        int rc = pthread_create(&cert_gen_threads[i], NULL, cert_gen_worker, NULL);
        if (rc != 0) {
            // pthread_create returns the error code; errno is not set.
            // Continue anyway. NOTE(review): cleanup_cert_gen_pool() still
            // joins this slot even though no thread was started in it.
            fprintf(stderr, "Failed to create certificate generation thread: %s\n", strerror(rc));
        }
    }
}
/*
 * Worker loop of the certificate-generation pool.
 *
 * Each iteration:
 *   1. Pops the oldest queued task.
 *   2. Deletes any existing files at its cert/key paths.
 *   3. Regenerates them with generate_trusted_cert().
 *   4. Publishes the result to wait_for_cert_generation().
 *
 * cert_queue.running is only read under cert_queue.mutex, because
 * cleanup_cert_gen_pool() writes it under that lock.
 *
 * @param arg  unused (pthread entry-point signature)
 */
void* cert_gen_worker(void* arg) {
    (void)arg;

    for (;;) {
        pthread_mutex_lock(&cert_queue.mutex);
        while (cert_queue.task_count == 0 && cert_queue.running) {
            pthread_cond_wait(&cert_queue.not_empty, &cert_queue.mutex);
        }
        if (!cert_queue.running) {
            pthread_mutex_unlock(&cert_queue.mutex);
            break;
        }

        // Dequeue the oldest task (FIFO)
        cert_gen_task* task = cert_queue.tasks[0];
        cert_queue.task_count--;
        if (cert_queue.task_count > 0) {
            memmove(&cert_queue.tasks[0], &cert_queue.tasks[1],
                    sizeof(*cert_queue.tasks) * cert_queue.task_count);
        }
        pthread_cond_signal(&cert_queue.not_full);
        pthread_mutex_unlock(&cert_queue.mutex);

        // Draw a unique serial from the shared counter
        unsigned long long serial = atomic_fetch_add(&next_serial, 1);

        // Remove any existing certificate/key before regenerating them
        unlink(task->cert_path);
        unlink(task->key_path);

        // Reseed rand() from the serial, presumably for use inside
        // generate_trusted_cert(). NOTE(review): srand()/rand() state is
        // process-global, so concurrent workers reseed each other.
        srand((unsigned int)(serial & 0xFFFFFFFF));

        int result = generate_trusted_cert(task->hostname, task->cert_path, task->key_path);

        // Publish completion to the waiter
        pthread_mutex_lock(&task->mutex);
        task->result = result;
        task->completed = 1;
        pthread_cond_signal(&task->cond);
        pthread_mutex_unlock(&task->mutex);
    }

    return NULL;
}
/*
 * Queue background generation of a certificate/key pair for `hostname`,
 * written to cert_path and key_path.
 *
 * Blocks while the queue is full. Returns NULL if any argument is NULL, if
 * allocation fails, or if the pool is stopped. A returned task must be passed
 * to wait_for_cert_generation(), which frees it.
 */
cert_gen_task* queue_cert_generation(const char* hostname, const char* cert_path, const char* key_path) {
    if (!hostname || !cert_path || !key_path) {
        return NULL;
    }

    cert_gen_task* task = malloc(sizeof *task);
    if (!task) {
        return NULL;
    }

    // snprintf always NUL-terminates (truncating over-long input)
    snprintf(task->hostname, sizeof(task->hostname), "%s", hostname);
    snprintf(task->cert_path, sizeof(task->cert_path), "%s", cert_path);
    snprintf(task->key_path, sizeof(task->key_path), "%s", key_path);
    task->result = 0;
    task->completed = 0;
    pthread_mutex_init(&task->mutex, NULL);
    pthread_cond_init(&task->cond, NULL);

    pthread_mutex_lock(&cert_queue.mutex);
    while (cert_queue.task_count >= cert_queue.task_capacity && cert_queue.running) {
        pthread_cond_wait(&cert_queue.not_full, &cert_queue.mutex);
    }
    if (!cert_queue.running) {
        pthread_mutex_unlock(&cert_queue.mutex);
        // Release everything initialised above, not just the allocation
        pthread_cond_destroy(&task->cond);
        pthread_mutex_destroy(&task->mutex);
        free(task);
        return NULL;
    }

    cert_queue.tasks[cert_queue.task_count++] = task;
    pthread_cond_signal(&cert_queue.not_empty);
    pthread_mutex_unlock(&cert_queue.mutex);

    return task;
}
/*
 * Block until a queued certificate-generation task has finished, then
 * destroy its synchronisation objects and free it.
 *
 * @param task  handle from queue_cert_generation(); this call takes
 *              ownership and frees it. NULL is tolerated.
 * @return the task's generate_trusted_cert() result, or 0 when task is NULL.
 */
int wait_for_cert_generation(cert_gen_task* task) {
    if (task == NULL) {
        return 0;
    }

    pthread_mutex_lock(&task->mutex);
    while (task->completed == 0) {
        pthread_cond_wait(&task->cond, &task->mutex);
    }
    const int outcome = task->result;
    pthread_mutex_unlock(&task->mutex);

    pthread_cond_destroy(&task->cond);
    pthread_mutex_destroy(&task->mutex);
    free(task);

    return outcome;
}
// Cleanup certificate generation pool
void cleanup_cert_gen_pool() {
// Signal shutdown
pthread_mutex_lock(&cert_queue.mutex);
cert_queue.running = 0;
pthread_cond_broadcast(&cert_queue.not_empty);
pthread_cond_broadcast(&cert_queue.not_full);
pthread_mutex_unlock(&cert_queue.mutex);
// Wait for threads to exit
for (int i = 0; i < CERT_GEN_THREAD_POOL_SIZE; i++) {
pthread_join(cert_gen_threads[i], NULL);
}
// Free any remaining tasks
pthread_mutex_lock(&cert_queue.mutex);
for (int i = 0; i < cert_queue.task_count; i++) {
free(cert_queue.tasks[i]);
}
free(cert_queue.tasks);
pthread_mutex_unlock(&cert_queue.mutex);
// Destroy sync primitives
pthread_mutex_destroy(&cert_queue.mutex);
pthread_cond_destroy(&cert_queue.not_empty);
pthread_cond_destroy(&cert_queue.not_full);
}
/*
 * Start a detached background thread (pregen_cert_thread) that pre-generates
 * certificates for a fixed list of popular domains. At most
 * PRE_GEN_CERT_COUNT domains are processed.
 *
 * The list here must stay in sync with the copy in pregen_cert_thread(),
 * which receives only the count.
 */
void pregen_common_certificates() {
    const char* common_domains[] = {
        "www.google.com", "www.facebook.com", "www.youtube.com",
        "www.amazon.com", "www.twitter.com", "www.instagram.com",
        "www.linkedin.com", "www.reddit.com", "www.netflix.com",
        "www.github.com"
    };
    int num_domains = (int)(sizeof(common_domains) / sizeof(common_domains[0]));
    int domains_to_pregen = num_domains < PRE_GEN_CERT_COUNT ? num_domains : PRE_GEN_CERT_COUNT;

    printf("Pre-generating certificates for %d common domains...\n", domains_to_pregen);

    pthread_t pregen_thread;
    int rc = pthread_create(&pregen_thread, NULL, pregen_cert_thread, (void*)(long)domains_to_pregen);
    if (rc != 0) {
        // Without a thread there is no valid id to detach
        fprintf(stderr, "Failed to create certificate pre-generation thread: %s\n", strerror(rc));
        return;
    }
    pthread_detach(pregen_thread);
}
/*
 * Thread body for pregen_common_certificates().
 *
 * For the first `count` domains of the list below, it generates
 * certs/<domain>.crt and certs/<domain>.key with generate_unique_cert().
 * Pairs that already exist and that should_renew_cert() does not flag are
 * skipped. After each generation attempt it sleeps 100 ms to limit load.
 *
 * @param arg  domain count, passed as (void*)(long)
 */
void* pregen_cert_thread(void* arg) {
    int count = (int)(long)arg;

    const char* common_domains[] = {
        "www.google.com", "www.facebook.com", "www.youtube.com",
        "www.amazon.com", "www.twitter.com", "www.instagram.com",
        "www.linkedin.com", "www.reddit.com", "www.netflix.com",
        "www.github.com"
    };
    const int num_domains = (int)(sizeof(common_domains) / sizeof(common_domains[0]));

    // The count comes from a separately maintained copy of this list; clamp
    // it so a mismatch can never index past the end of this array.
    if (count > num_domains) {
        count = num_domains;
    }

    for (int i = 0; i < count; i++) {
        const char* domain = common_domains[i];
        char cert_path[256];
        char key_path[256];
        snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", domain);
        snprintf(key_path, sizeof(key_path), "certs/%s.key", domain);

        // Don't regenerate if it exists and is recent
        if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0 &&
            !should_renew_cert(domain)) {
            printf("Certificate for %s already exists and is recent\n", domain);
            continue;
        }

        printf("Pre-generating certificate for %s\n", domain);
        if (!generate_unique_cert(domain, cert_path, key_path)) {
            fprintf(stderr, "Failed to pre-generate certificate for %s\n", domain);
        }

        // Sleep a bit to not overwhelm the system
        usleep(100000); // 100ms
    }

    printf("Finished pre-generating certificates\n");
    return NULL;
}
/*
 * Send every byte of `data` on `sock`, retrying after short writes and
 * EINTR. Returns 0 on success, -1 on failure.
 */
static int tunnel_send_all(int sock, const char* data, size_t len) {
    while (len > 0) {
        ssize_t sent = send(sock, data, len, 0);
        if (sent < 0 && errno == EINTR) {
            continue;
        }
        if (sent <= 0) {
            return -1;
        }
        data += sent;
        len -= (size_t)sent;
    }
    return 0;
}

/*
 * Blind (non-intercepting) relay for a CONNECT tunnel.
 *
 * Sends "HTTP/1.1 200 Connection Established" to the client, then copies raw
 * bytes between client_sock and server_sock until one of:
 *   - either side closes;
 *   - a recv()/send() fails;
 *   - select() fails with anything other than EINTR.
 * send() may accept only part of a buffer, so every chunk is written in full
 * via tunnel_send_all(). Neither socket is closed here.
 */
void handle_regular_https_tunnel(int client_sock, int server_sock) {
    char buffer[MAX_REQUEST_SIZE];
    const int max_fd = (client_sock > server_sock) ? client_sock : server_sock;

    const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
    if (tunnel_send_all(client_sock, success_response, strlen(success_response)) < 0) {
        perror("Failed to send connection established response");
        return;
    }

    for (;;) {
        fd_set read_fds;
        FD_ZERO(&read_fds);
        FD_SET(client_sock, &read_fds);
        FD_SET(server_sock, &read_fds);

        if (select(max_fd + 1, &read_fds, NULL, NULL, NULL) < 0) {
            // A signal interrupting select() is not a tunnel failure
            if (errno == EINTR) {
                continue;
            }
            perror("Select failed");
            break;
        }

        if (FD_ISSET(client_sock, &read_fds)) {
            ssize_t n = recv(client_sock, buffer, sizeof(buffer), 0);
            if (n <= 0 || tunnel_send_all(server_sock, buffer, (size_t)n) < 0) {
                break;
            }
        }

        if (FD_ISSET(server_sock, &read_fds)) {
            ssize_t n = recv(server_sock, buffer, sizeof(buffer), 0);
            if (n <= 0 || tunnel_send_all(client_sock, buffer, (size_t)n) < 0) {
                break;
            }
        }
    }
}
/*
 * Perform one SSL_read() on `from` and write what was read to `to`.
 *
 * Returns 0 to keep tunnelling and -1 when the tunnel should be torn down
 * (clean TLS close, fatal error, or failed write). SSL_ERROR_WANT_READ/WRITE
 * are treated as "try again", as before.
 */
static int ssl_relay_once(SSL* from, SSL* to, char* buffer, int buffer_size) {
    int bytes_received = SSL_read(from, buffer, buffer_size);
    if (bytes_received <= 0) {
        int err = SSL_get_error(from, bytes_received);
        return (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) ? 0 : -1;
    }

    int bytes_sent = SSL_write(to, buffer, bytes_received);
    if (bytes_sent <= 0) {
        int err = SSL_get_error(to, bytes_sent);
        return (err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE) ? 0 : -1;
    }
    return 0;
}

/*
 * Relay decrypted application data between two established TLS connections
 * until either side closes or fails.
 *
 * OpenSSL can already hold readable bytes in its own buffer. For example, a
 * 16 KB record only partly consumed by an 8 KB SSL_read() leaves the rest
 * buffered. Those bytes do not make the socket readable, so relying on
 * select() alone can stall the tunnel indefinitely. SSL_pending() is
 * therefore checked first, and select() is only used when nothing is
 * buffered.
 */
void ssl_tunnel_data(SSL* client_ssl, SSL* server_ssl) {
    char buffer[MAX_REQUEST_SIZE];
    const int client_fd = SSL_get_fd(client_ssl);
    const int server_fd = SSL_get_fd(server_ssl);

    if (client_fd < 0 || server_fd < 0) {
        fprintf(stderr, "ssl_tunnel_data: SSL object has no file descriptor\n");
        return;
    }
    const int max_fd = (client_fd > server_fd) ? client_fd : server_fd;

    for (;;) {
        int client_ready = SSL_pending(client_ssl) > 0;
        int server_ready = SSL_pending(server_ssl) > 0;

        if (!client_ready && !server_ready) {
            fd_set read_fds;
            FD_ZERO(&read_fds);
            FD_SET(client_fd, &read_fds);
            FD_SET(server_fd, &read_fds);

            if (select(max_fd + 1, &read_fds, NULL, NULL, NULL) < 0) {
                if (errno == EINTR) {
                    continue;
                }
                perror("Select failed");
                break;
            }
            client_ready = FD_ISSET(client_fd, &read_fds);
            server_ready = FD_ISSET(server_fd, &read_fds);
        }

        if (client_ready &&
            ssl_relay_once(client_ssl, server_ssl, buffer, (int)sizeof(buffer)) < 0) {
            break;
        }
        if (server_ready &&
            ssl_relay_once(server_ssl, client_ssl, buffer, (int)sizeof(buffer)) < 0) {
            break;
        }
    }
}
// Modify HTTP request for direct server communication
char* rewrite_http_request(const char* original_request, size_t* new_length) {
// Find the first line ending
const char* first_line_end = strstr(original_request, "\r\n");
if (!first_line_end) {
printf("Malformed HTTP request: no line ending found\n");
return NULL;
}
// Parse the request line
char method[16] = {0};
char url[MAX_URL_LENGTH] = {0};
char version[16] = {0};
sscanf(original_request, "%15s %2047s %15s", method, url, version);
printf("Original request: %s %s %s\n", method, url, version);
// Extract path from URL (remove http://hostname)
char* path = NULL;
if (strncmp(url, "http://", 7) == 0) {
path = strchr(url + 7, '/');
if (!path) {
path = "/";
printf("No path in URL, using default: %s\n", path);
} else {
printf("Extracted path from URL: %s\n", path);
}
} else {
path = url; // Already a path or malformed
printf("Using URL as path: %s\n", path);
}
// Calculate new request size
const char* headers_start = first_line_end + 2; // Skip \r\n
size_t headers_length = strlen(headers_start);
size_t request_line_length = strlen(method) + strlen(path) + strlen(version) + 4; // +4 for spaces and \r\n
size_t total_length = request_line_length + headers_length;
// Allocate memory for new request
char* new_request = malloc(total_length + 1);
if (!new_request) {
printf("Failed to allocate memory for HTTP request rewrite\n");
return NULL;
}
// Write the new request
int written = snprintf(new_request, total_length + 1, "%s %s %s\r\n%s",
method, path, version, headers_start);
if (written < 0 || (size_t)written > total_length) {
printf("Failed to rewrite HTTP request\n");
free(new_request);
return NULL;
}
printf("Rewritten request first line: %s %s %s\n", method, path, version);
*new_length = written;
return new_request;
}
/*
 * Send every byte of `data` on `sock`, retrying after short writes and
 * EINTR. Returns 0 on success, -1 on failure (errno from send()).
 */
static int proxy_send_all(int sock, const char* data, size_t len) {
    while (len > 0) {
        ssize_t sent = send(sock, data, len, 0);
        if (sent < 0 && errno == EINTR) {
            continue;
        }
        if (sent <= 0) {
            return -1;
        }
        data += sent;
        len -= (size_t)sent;
    }
    return 0;
}

/*
 * Serve one accepted client connection.
 *
 * Steps:
 *   1. Read one request, using a single recv() of at most
 *      MAX_REQUEST_SIZE - 1 bytes.
 *   2. Extract the target host and port.
 *   3. Resolve the host: DNS cache first, then DoH, caching successful
 *      DoH results.
 *   4. Connect to the target over IPv4.
 *   5. CONNECT requests go to handle_https_tunnel(). Plain HTTP requests are
 *      rewritten to origin form and sent, and the response is streamed back
 *      to the client.
 * Both sockets are closed before returning.
 */
void handle_client_optimized(int client_sock) {
    char request[MAX_REQUEST_SIZE];
    ssize_t bytes_received = recv(client_sock, request, sizeof(request) - 1, 0);
    if (bytes_received <= 0) {
        close(client_sock);
        return;
    }
    request[bytes_received] = '\0';

    int is_connect = (strncmp(request, "CONNECT ", 8) == 0);

    // NOTE(review): `host` is never freed in this function. If
    // extract_host() returns heap memory, this leaks once per request;
    // confirm against its definition.
    char* host = extract_host(request);
    if (!host) {
        printf("Failed to extract host from request\n");
        close(client_sock);
        return;
    }

    int port = extract_port(request);
    if (port <= 0 || port > 65535) {
        port = is_connect ? 443 : 80;
        printf("Invalid port detected, using default port: %d\n", port);
    }

    printf("Proxying %s request to: %s (port %d)\n",
           is_connect ? "HTTPS" : "HTTP", host, port);

    // Resolve: DNS cache first, then DoH (only successes are cached)
    char ip_addr[IP_ADDR_MAX_LEN];
    if (check_dns_cache(host, ip_addr, sizeof(ip_addr)) == 0) {
        printf("Using cached DNS for %s: %s\n", host, ip_addr);
    } else {
        if (resolve_doh(host, ip_addr, sizeof(ip_addr)) != 0) {
            printf("Failed to resolve hostname using DoH: %s\n", host);
            close(client_sock);
            return;
        }
        printf("Resolved %s to %s\n", host, ip_addr);
        add_to_dns_cache(host, ip_addr, 1);
    }

    // Only IPv4 is supported below. inet_addr() would silently map a
    // non-IPv4 string (e.g. an IPv6 literal) to INADDR_NONE, i.e.
    // 255.255.255.255, so validate the address explicitly.
    struct sockaddr_in server_addr;
    memset(&server_addr, 0, sizeof(server_addr));
    server_addr.sin_family = AF_INET;
    server_addr.sin_port = htons(port);
    if (inet_pton(AF_INET, ip_addr, &server_addr.sin_addr) != 1) {
        printf("Resolved address for %s is not a valid IPv4 address: %s\n", host, ip_addr);
        close(client_sock);
        return;
    }

    int server_sock = socket(AF_INET, SOCK_STREAM, 0);
    if (server_sock < 0) {
        perror("Cannot create socket to server");
        close(client_sock);
        return;
    }

    if (connect(server_sock, (struct sockaddr*)&server_addr, sizeof(server_addr)) < 0) {
        perror("Cannot connect to server");
        close(server_sock);
        close(client_sock);
        return;
    }

    if (is_connect) {
        // HTTPS: Handle CONNECT tunnel with DANE support.
        // NOTE(review): handle_https_tunnel appears to close these sockets
        // itself on some paths and to hand server_sock to the connection
        // pool. The close() calls at the end of this function may then
        // double-close; confirm socket ownership.
        handle_https_tunnel(client_sock, server_sock, host, ip_addr);
    } else {
        // HTTP: rewrite the request line and forward it.
        // NOTE(review): only the bytes from the single recv() above are
        // forwarded, so a request body larger than that is truncated.
        printf("HTTP request received, rewriting for server...\n");

        size_t new_length = 0;
        char* modified_request = rewrite_http_request(request, &new_length);
        if (!modified_request) {
            printf("Failed to rewrite HTTP request\n");
            close(server_sock);
            close(client_sock);
            return;
        }

        printf("Forwarding modified HTTP request (%zu bytes) to server...\n", new_length);

        // Set timeouts for sending and receiving
        struct timeval tv;
        tv.tv_sec = 5; // 5 seconds timeout
        tv.tv_usec = 0;
        setsockopt(server_sock, SOL_SOCKET, SO_SNDTIMEO, (const char*)&tv, sizeof tv);
        setsockopt(server_sock, SOL_SOCKET, SO_RCVTIMEO, (const char*)&tv, sizeof tv);

        // Disable Nagle before writing. Setting it after the request has
        // been sent would have no effect on that request.
        int flag = 1;
        setsockopt(server_sock, IPPROTO_TCP, TCP_NODELAY, (char*)&flag, sizeof(int));

        if (proxy_send_all(server_sock, modified_request, new_length) < 0) {
            perror("Failed to send request to server");
            free(modified_request);
            close(server_sock);
            close(client_sock);
            return;
        }
        free(modified_request);

        printf("Waiting for server response (timeout: 5s)...\n");
        char buffer[MAX_REQUEST_SIZE];
        fd_set read_fds;
        FD_ZERO(&read_fds);
        FD_SET(server_sock, &read_fds);

        // Use select with timeout to wait for the first response bytes
        if (select(server_sock + 1, &read_fds, NULL, NULL, &tv) <= 0) {
            printf("Timeout or error waiting for server response\n");
            close(server_sock);
            close(client_sock);
            return;
        }

        ssize_t n;
        while ((n = recv(server_sock, buffer, sizeof(buffer), 0)) > 0) {
            printf("Received %zd bytes from server\n", n);
            if (proxy_send_all(client_sock, buffer, (size_t)n) < 0) {
                perror("Failed to send response to client");
                break;
            }
        }

        if (n < 0) {
            perror("Error receiving from server");
        } else if (n == 0) {
            printf("Server closed connection normally\n");
        }
    }

    printf("Connection closed: %s (port %d)\n", host, port);
    close(server_sock);
    close(client_sock);
}
// HTTPS tunnel with DANE verification, optimized version.
//
// Serves a CONNECT tunnel to `hostname` over the already-connected
// `server_sock`. It either intercepts TLS with a locally generated
// certificate (certs/<hostname>.crt and certs/<hostname>.key) or falls back
// to a blind relay via handle_regular_https_tunnel():
//   - check_cache(hostname) == 1 (cached DANE success), the cert/key files
//     exist, and should_renew_cert() is false: intercept immediately,
//     reusing existing_ssl/existing_ctx as the server-side session when
//     provided.
//   - check_cache(hostname) == 0 (cached DANE failure): blind relay.
//   - any other cache result (presumably "not cached"), or a cached success
//     whose cert files are missing or due for renewal: perform a fresh TLS
//     handshake, check for DANE records, and intercept only if
//     verify_cert_against_dane() returns > 0. The outcome is recorded with
//     add_to_cache().
//
// client_sock   - socket of the proxy client that sent CONNECT
// server_sock   - connected socket to the target server
// ip_addr       - unused
// existing_ssl / existing_ctx - optional server-side TLS session to reuse
//                 (presumably from the connection pool); NULL means one is
//                 created here
// No close() is issued on either socket by this function itself.
void handle_https_tunnel_optimized(int client_sock, int server_sock, const char* hostname,
const char* ip_addr, SSL* existing_ssl, SSL_CTX* existing_ctx) {
// Mark ip_addr as used to avoid warning
(void)ip_addr;
// Use existing SSL session if provided
SSL* server_ssl = existing_ssl;
SSL_CTX* server_ctx = existing_ctx;
// need_ssl_setup == 1 means server_ssl/server_ctx are (or will be) created
// here. The error paths below only free them in that case.
int need_ssl_setup = (server_ssl == NULL);
// Check the DANE cache first
int cached_result = check_cache(hostname);
// Generate certificate paths
char cert_path[256];
char key_path[256];
snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname);
snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname);
// If in cache and verified as successful, skip verification
if (cached_result == 1) {
// Cached DANE verification successful
// Check if certificate exists and doesn't need renewal
// (otherwise fall through to the full DANE verification below)
if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0 && !should_renew_cert(hostname)) {
// Initialize SSL context with our certificate: in-memory cache first, then disk
ssl_context_t* client_ctx = NULL;
if (!get_cached_cert(hostname, &client_ctx)) {
// Fall back to disk
client_ctx = init_ssl_context(cert_path, key_path);
if (client_ctx) {
// Now that we've loaded from disk, cache it for next time
cache_cert_in_memory(hostname, cert_path, key_path);
}
}
if (!client_ctx) {
// Fall back to regular tunnel
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// If we need to set up SSL, do it
if (need_ssl_setup) {
server_ctx = SSL_CTX_new(TLS_client_method());
if (!server_ctx) {
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// NOTE(review): SSL_new() result is not checked before use.
server_ssl = SSL_new(server_ctx);
SSL_set_fd(server_ssl, server_sock);
// SNI so the server presents the certificate for this hostname
SSL_set_tlsext_host_name(server_ssl, hostname);
// Enable session caching for future connections
SSL_CTX_set_session_cache_mode(server_ctx, SSL_SESS_CACHE_CLIENT);
// NOTE(review): no SSL_CTX_set_verify() is configured on server_ctx in
// this function; on this path the server is trusted on the strength of
// the cached DANE result alone.
if (SSL_connect(server_ssl) <= 0) {
cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
}
// Send 200 Connection Established to client
const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
if (!need_ssl_setup) {
// Don't clean up if we're using a cached connection
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
} else {
cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
}
return;
}
// Set up client SSL (the proxy acts as the TLS server toward the client)
// NOTE(review): SSL_new() result is not checked before use.
client_ctx->ssl = SSL_new(client_ctx->ctx);
SSL_set_fd(client_ctx->ssl, client_sock);
if (SSL_accept(client_ctx->ssl) <= 0) {
cleanup_ssl_connection(client_ctx->ssl, NULL);
if (need_ssl_setup) {
cleanup_ssl_connection(server_ssl, server_ctx);
}
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
return;
}
// Tunnel data
ssl_tunnel_data(client_ctx->ssl, server_ssl);
// Clean up
cleanup_ssl_connection(client_ctx->ssl, NULL);
if (need_ssl_setup) {
// If we didn't use a pooled connection, add this one to the pool.
// NOTE(review): ssl_tunnel_data() only returns once a side has closed or
// failed, so the connection being pooled may already be unusable.
add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx);
}
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
return;
}
} else if (cached_result == 0) {
// Cached DANE verification failed, use regular tunnel
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Need to perform DANE verification
// Clean up any existing SSL session since we need to validate from scratch.
// NOTE(review): only the pooled SSL object is released here (the ctx argument
// is NULL), so existing_ctx is presumably neither freed nor returned to the
// pool. The new handshake below also runs on the same server_sock. If that
// socket still carries the pooled TLS connection, SSL_connect() is likely
// to fail. Confirm how pooled connections relate to server_sock.
if (!need_ssl_setup) {
// Not using server_ssl/ctx directly as they belong to pool
cleanup_ssl_connection(server_ssl, NULL);
need_ssl_setup = 1;
}
// Set up a new SSL connection to server
server_ctx = SSL_CTX_new(TLS_client_method());
if (!server_ctx) {
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// NOTE(review): SSL_new() result is not checked before use.
server_ssl = SSL_new(server_ctx);
SSL_set_fd(server_ssl, server_sock);
SSL_set_tlsext_host_name(server_ssl, hostname);
if (SSL_connect(server_ssl) <= 0) {
cleanup_ssl_connection(server_ssl, server_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Check for DANE records
int has_dane = is_dane_available(hostname);
if (has_dane) {
// Get server certificate and verify against DANE.
// server_cert is released with X509_free() on every path below.
X509* server_cert = SSL_get_peer_certificate(server_ssl);
if (!server_cert) {
cleanup_ssl_connection(server_ssl, server_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
int dane_verified = verify_cert_against_dane(hostname, server_cert);
// Remember the outcome so later connections take the cached branches above
add_to_cache(hostname, dane_verified > 0 ? 1 : 0);
if (dane_verified <= 0) {
// DANE verification failed
X509_free(server_cert);
cleanup_ssl_connection(server_ssl, server_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Generate trusted certificate
if (!generate_unique_cert(hostname, cert_path, key_path)) {
X509_free(server_cert);
cleanup_ssl_connection(server_ssl, server_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Set up client connection with our certificate
ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path);
if (!client_ctx) {
X509_free(server_cert);
cleanup_ssl_connection(server_ssl, server_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Send 200 Connection Established to client
const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
X509_free(server_cert);
cleanup_ssl_connection(server_ssl, server_ctx);
// NOTE(review): client_ctx->ssl is only created by SSL_new() further down.
// This passes whatever init_ssl_context() left in that field; confirm it
// is NULL.
cleanup_ssl_connection(client_ctx->ssl, NULL);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
return;
}
// Setup client SSL
// NOTE(review): SSL_new() result is not checked before use.
client_ctx->ssl = SSL_new(client_ctx->ctx);
SSL_set_fd(client_ctx->ssl, client_sock);
if (SSL_accept(client_ctx->ssl) <= 0) {
X509_free(server_cert);
cleanup_ssl_connection(client_ctx->ssl, NULL);
cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
return;
}
// Tunnel data
ssl_tunnel_data(client_ctx->ssl, server_ssl);
// Add connection to pool for future reuse.
// NOTE(review): pooled only after the tunnel has ended; it may already be closed.
add_to_connection_pool(hostname, 443, server_sock, server_ssl, server_ctx);
// Clean up
X509_free(server_cert);
cleanup_ssl_connection(client_ctx->ssl, NULL);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
} else {
// No DANE records: cache as "not verified" and relay blindly
add_to_cache(hostname, 0);
cleanup_ssl_connection(server_ssl, server_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
}
}
/*
 * Look up the cached TLS client session for `hostname`.
 *
 * If the entry is older than 600 seconds (10 minutes), it is evicted and
 * NULL is returned. Otherwise the cached SSL_SESSION pointer is returned
 * WITHOUT taking an extra reference.
 *
 * NOTE(review): the pointer is returned after ssl_session_mutex has been
 * released. A concurrent cache_ssl_session() or cleanup_ssl_sessions() can
 * therefore free it while the caller still uses it. A safer contract would
 * SSL_SESSION_up_ref() under the lock and have callers SSL_SESSION_free().
 */
SSL_SESSION* get_ssl_session(const char* hostname) {
    SSL_SESSION* found = NULL;
    ssl_session_cache_entry* entry = NULL;

    pthread_mutex_lock(&ssl_session_mutex);

    HASH_FIND_STR(ssl_session_cache, hostname, entry);
    if (entry != NULL) {
        const time_t now = time(NULL);
        if (now - entry->timestamp > 600) {
            // Expired: evict the entry and report a miss
            SSL_SESSION_free(entry->session);
            HASH_DEL(ssl_session_cache, entry);
            free(entry);
        } else {
            found = entry->session;
        }
    }

    pthread_mutex_unlock(&ssl_session_mutex);
    return found;
}
/*
 * Store `session` as the cached TLS client session for `hostname`, stamped
 * with the current time. A NULL session is ignored.
 *
 * The cache owns the session reference it stores: a previously cached
 * session for the same host is released with SSL_SESSION_free(). For the
 * same reason, `session` itself is released when it cannot be stored
 * (NULL hostname or allocation failure) instead of being leaked.
 */
void cache_ssl_session(const char* hostname, SSL_SESSION* session) {
    if (!session) {
        return;
    }
    if (!hostname) {
        SSL_SESSION_free(session);
        return;
    }

    pthread_mutex_lock(&ssl_session_mutex);

    ssl_session_cache_entry* entry = NULL;
    HASH_FIND_STR(ssl_session_cache, hostname, entry);
    if (entry) {
        // Replace the existing session, dropping the reference we held
        SSL_SESSION_free(entry->session);
        entry->session = session;
        entry->timestamp = time(NULL);
    } else {
        entry = malloc(sizeof *entry);
        if (entry) {
            snprintf(entry->hostname, sizeof(entry->hostname), "%s", hostname);
            entry->session = session;
            entry->timestamp = time(NULL);
            HASH_ADD_STR(ssl_session_cache, hostname, entry);
        } else {
            // Could not store it; release the reference handed to us
            SSL_SESSION_free(session);
        }
    }

    pthread_mutex_unlock(&ssl_session_mutex);
}
/*
 * Evict every TLS session-cache entry older than 600 seconds (10 minutes).
 * For each one, the cache's session reference is released and the entry is
 * unlinked and freed. ssl_session_mutex is held for the whole scan.
 */
void cleanup_ssl_sessions(void) {
    pthread_mutex_lock(&ssl_session_mutex);
    time_t now = time(NULL);
    ssl_session_cache_entry *entry, *tmp;
    /* HASH_ITER keeps the next element in `tmp`, so deleting `entry` is safe. */
    HASH_ITER(hh, ssl_session_cache, entry, tmp) {
        if (now - entry->timestamp > 600) {
            SSL_SESSION_free(entry->session);
            HASH_DEL(ssl_session_cache, entry);
            free(entry);
        }
    }
    pthread_mutex_unlock(&ssl_session_mutex);
}
/*
 * Build a fresh server-side SSL context for `hostname` from the in-memory
 * certificate cache.
 *
 * On success, *ctx receives a newly allocated ssl_context_t. Its `ctx`
 * member is an SSL_CTX using TLS_server_method(), with SSLv2, SSLv3 and
 * compression disabled, loaded with the cached PEM certificate and private
 * key. Every other member is zero-initialised. The caller owns both the
 * SSL_CTX and the wrapper. Returns 1.
 *
 * Returns 0 in these cases:
 *   - there is no entry;
 *   - the entry is older than CERT_RENEWAL_TIME;
 *   - any allocation, parse or load step fails.
 * On a miss or expiry *ctx is left untouched. On a later failure it is
 * set to NULL.
 *
 * cert_cache_mutex is held for the whole call because the cached PEM
 * buffers are read directly.
 */
int get_cached_cert(const char* hostname, ssl_context_t** ctx) {
    int ok = 0;
    cert_cache_entry* entry = NULL;
    X509* cert = NULL;
    EVP_PKEY* key = NULL;
    BIO* bio = NULL;

    if (!hostname || !ctx) return 0;

    pthread_mutex_lock(&cert_cache_mutex);
    HASH_FIND_STR(cert_memory_cache, hostname, entry);
    if (!entry || time(NULL) - entry->timestamp > CERT_RENEWAL_TIME) {
        goto out_unlock;
    }

    /* calloc so members other than `ctx` (e.g. the `ssl` handle callers use
     * on this wrapper) start out NULL instead of indeterminate. */
    *ctx = calloc(1, sizeof **ctx);
    if (!*ctx) goto out_unlock;

    (*ctx)->ctx = SSL_CTX_new(TLS_server_method());
    if (!(*ctx)->ctx) goto out_release;

    SSL_CTX_set_options((*ctx)->ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION);

    /* Parse the cached PEM blobs via read-only memory BIOs. */
    bio = BIO_new_mem_buf(entry->cert_data, entry->cert_len);
    if (bio) {
        cert = PEM_read_bio_X509(bio, NULL, NULL, NULL);
        BIO_free(bio);
    }
    bio = BIO_new_mem_buf(entry->key_data, entry->key_len);
    if (bio) {
        key = PEM_read_bio_PrivateKey(bio, NULL, NULL, NULL);
        BIO_free(bio);
    }

    /* The certificate is installed first, so SSL_CTX_use_PrivateKey()
     * also rejects a key that does not match it. */
    if (cert && key &&
        SSL_CTX_use_certificate((*ctx)->ctx, cert) > 0 &&
        SSL_CTX_use_PrivateKey((*ctx)->ctx, key) > 0) {
        ok = 1;
    }

out_release:
    /* The SSL_CTX holds its own references; drop ours (NULL-safe). */
    X509_free(cert);
    EVP_PKEY_free(key);
    if (!ok) {
        SSL_CTX_free((*ctx)->ctx); /* NULL-safe */
        free(*ctx);
        *ctx = NULL;
    }
out_unlock:
    pthread_mutex_unlock(&cert_cache_mutex);
    return ok;
}
/*
 * Read a whole file into a new heap buffer.
 *
 * Returns NULL if the file cannot be opened, sized, read, or is empty.
 * On success *out_len receives the byte count and the caller owns the
 * buffer. The data is not NUL-terminated.
 */
static unsigned char* cert_cache_read_file(const char* path, size_t* out_len) {
    FILE* fp = fopen(path, "rb");
    if (!fp) return NULL;

    unsigned char* buf = NULL;
    long size = -1;
    /* ftell() reports -1 on failure; never let that become a huge size_t. */
    if (fseek(fp, 0, SEEK_END) == 0) size = ftell(fp);
    if (size > 0 && fseek(fp, 0, SEEK_SET) == 0) {
        buf = malloc((size_t)size);
        if (buf && fread(buf, 1, (size_t)size, fp) != (size_t)size) {
            free(buf);
            buf = NULL;
        }
    }
    fclose(fp);

    if (buf) *out_len = (size_t)size;
    return buf;
}

/*
 * Load the PEM certificate at `cert_path` and the private key at `key_path`
 * into the in-memory certificate cache under `hostname`.
 *
 * Both files are read in full before cert_cache_mutex is taken. An existing
 * entry has its buffers replaced and its timestamp refreshed. Otherwise a
 * new entry is created. The cache takes ownership of the buffers.
 *
 * On any failure the cache is left unchanged. Failures include:
 *   - NULL arguments;
 *   - a hostname too long for the entry's key buffer;
 *   - an unreadable or empty file;
 *   - allocation failure.
 */
void cache_cert_in_memory(const char* hostname, const char* cert_path, const char* key_path) {
    if (!hostname || !cert_path || !key_path) return;

    /* A truncated key could never be matched by HASH_FIND_STR on the full
     * name, so each call would add another unreachable duplicate entry. */
    size_t host_len = strlen(hostname);
    if (host_len >= MAX_URL_LENGTH) return;

    size_t cert_size = 0;
    unsigned char* cert_data = cert_cache_read_file(cert_path, &cert_size);
    if (!cert_data) return;

    size_t key_size = 0;
    unsigned char* key_data = cert_cache_read_file(key_path, &key_size);
    if (!key_data) {
        free(cert_data);
        return;
    }

    pthread_mutex_lock(&cert_cache_mutex);
    cert_cache_entry* entry;
    HASH_FIND_STR(cert_memory_cache, hostname, entry);
    if (entry) {
        /* Update existing entry in place. */
        free(entry->cert_data);
        free(entry->key_data);
        entry->cert_data = cert_data;
        entry->cert_len = cert_size;
        entry->key_data = key_data;
        entry->key_len = key_size;
        entry->timestamp = time(NULL);
    } else {
        entry = malloc(sizeof *entry);
        if (entry) {
            memcpy(entry->hostname, hostname, host_len + 1);
            entry->cert_data = cert_data;
            entry->cert_len = cert_size;
            entry->key_data = key_data;
            entry->key_len = key_size;
            entry->timestamp = time(NULL);
            HASH_ADD_STR(cert_memory_cache, hostname, entry);
        } else {
            free(cert_data);
            free(key_data);
        }
    }
    pthread_mutex_unlock(&cert_cache_mutex);
}