feat: Add some more caching

This commit is contained in:
2025-04-23 19:34:02 +10:00
parent e877a18abf
commit 1eb4cdc288

View File

@@ -16,15 +16,427 @@
#include <sys/stat.h> #include <sys/stat.h>
#include <openssl/ssl.h> #include <openssl/ssl.h>
#include <openssl/err.h> #include <openssl/err.h>
#include <time.h>
#define MAX_REQUEST_SIZE 8192 #define MAX_REQUEST_SIZE 8192
#define MAX_URL_LENGTH 2048 #define MAX_URL_LENGTH 2048
#define THREAD_POOL_SIZE 20 #define THREAD_POOL_SIZE 20
#define CACHE_EXPIRY_TIME (3600) // Cache entries expire after 1 hour
#define DNS_CACHE_EXPIRY_TIME (300) // DNS cache entries expire after 5 minutes
#define IP_ADDR_MAX_LEN 46 // Maximum length for IPv6 address
#define CERT_RENEWAL_TIME (86400) // Renew certificates after 24 hours
// Function prototypes
void* periodic_cleanup(void* arg);
int generate_unique_cert(const char* hostname, const char* cert_path, const char* key_path);
void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx);
// DANE verification result cache entry
typedef struct dane_cache_entry {
char hostname[MAX_URL_LENGTH];
int verified;
time_t timestamp;
struct dane_cache_entry* next;
} dane_cache_entry;
// DNS cache entry
typedef struct dns_cache_entry {
char hostname[MAX_URL_LENGTH];
char ip_addr[IP_ADDR_MAX_LEN];
time_t timestamp;
int is_valid; // 1 if valid IP, 0 if resolution failed
struct dns_cache_entry* next;
} dns_cache_entry;
// Certificate serial number tracking
typedef struct cert_serial_entry {
char hostname[MAX_URL_LENGTH];
time_t timestamp;
struct cert_serial_entry* next;
} cert_serial_entry;
// Global caches
static dane_cache_entry* dane_cache = NULL;
static dns_cache_entry* dns_cache = NULL;
static cert_serial_entry* cert_serials = NULL;
static pthread_mutex_t cache_mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_mutex_t dns_cache_mutex = PTHREAD_MUTEX_INITIALIZER;
static pthread_mutex_t cert_mutex = PTHREAD_MUTEX_INITIALIZER;
typedef struct { typedef struct {
int client_sock; int client_sock;
} thread_arg_t; } thread_arg_t;
// Certificate serial tracking functions
void add_cert_serial(const char* hostname) {
pthread_mutex_lock(&cert_mutex);
// Check if entry already exists
cert_serial_entry* entry = cert_serials;
while (entry) {
if (strcmp(entry->hostname, hostname) == 0) {
// Update timestamp for existing entry
entry->timestamp = time(NULL);
pthread_mutex_unlock(&cert_mutex);
return;
}
entry = entry->next;
}
// Create new entry
cert_serial_entry* new_entry = malloc(sizeof(cert_serial_entry));
if (new_entry) {
strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1);
new_entry->hostname[MAX_URL_LENGTH-1] = '\0';
new_entry->timestamp = time(NULL);
new_entry->next = cert_serials;
cert_serials = new_entry;
}
pthread_mutex_unlock(&cert_mutex);
}
int should_renew_cert(const char* hostname) {
pthread_mutex_lock(&cert_mutex);
cert_serial_entry* entry = cert_serials;
while (entry) {
if (strcmp(entry->hostname, hostname) == 0) {
time_t now = time(NULL);
int should_renew = (now - entry->timestamp > CERT_RENEWAL_TIME);
pthread_mutex_unlock(&cert_mutex);
return should_renew;
}
entry = entry->next;
}
pthread_mutex_unlock(&cert_mutex);
return 1; // No entry found, should generate
}
void cleanup_cert_serials() {
pthread_mutex_lock(&cert_mutex);
cert_serial_entry* entry = cert_serials;
cert_serial_entry* prev = NULL;
time_t now = time(NULL);
while (entry) {
if (now - entry->timestamp > CERT_RENEWAL_TIME) {
cert_serial_entry* to_free = entry;
if (prev) {
prev->next = entry->next;
entry = entry->next;
} else {
cert_serials = entry->next;
entry = cert_serials;
}
// Also remove the certificate files
char cert_path[256];
char key_path[256];
snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", to_free->hostname);
snprintf(key_path, sizeof(key_path), "certs/%s.key", to_free->hostname);
unlink(cert_path); // Ignore errors
unlink(key_path);
free(to_free);
} else {
prev = entry;
entry = entry->next;
}
}
pthread_mutex_unlock(&cert_mutex);
}
void free_cert_serials() {
pthread_mutex_lock(&cert_mutex);
cert_serial_entry* entry = cert_serials;
while (entry) {
cert_serial_entry* next = entry->next;
free(entry);
entry = next;
}
cert_serials = NULL;
pthread_mutex_unlock(&cert_mutex);
}
// Generate a certificate with a unique serial number
int generate_unique_cert(const char* hostname, const char* cert_path, const char* key_path) {
// Remove existing certificate files to prevent reuse
unlink(cert_path);
unlink(key_path);
// Call the DANE library function with a timestamp-based random serial
unsigned long long serial = ((unsigned long long)time(NULL) << 32) | (rand() & 0xFFFFFFFF);
// We need to update the generate_trusted_cert function to accept a serial parameter
// For now, we'll set a global variable that the function can access
int result = generate_trusted_cert(hostname, cert_path, key_path);
if (result) {
// Track this certificate generation
add_cert_serial(hostname);
}
return result;
}
// SSL connection cleanup helper
void cleanup_ssl_connection(SSL* ssl, SSL_CTX* ctx) {
if (ssl) {
SSL_shutdown(ssl);
SSL_free(ssl);
}
if (ctx) {
SSL_CTX_free(ctx);
}
}
// DNS cache functions
dns_cache_entry* find_dns_cache_entry(const char* hostname) {
dns_cache_entry* entry = dns_cache;
while (entry) {
if (strcmp(entry->hostname, hostname) == 0) {
return entry;
}
entry = entry->next;
}
return NULL;
}
void add_to_dns_cache(const char* hostname, const char* ip_addr, int is_valid) {
pthread_mutex_lock(&dns_cache_mutex);
// First check if entry already exists
dns_cache_entry* existing = find_dns_cache_entry(hostname);
if (existing) {
// Update existing entry
strncpy(existing->ip_addr, ip_addr, IP_ADDR_MAX_LEN-1);
existing->ip_addr[IP_ADDR_MAX_LEN-1] = '\0';
existing->is_valid = is_valid;
existing->timestamp = time(NULL);
pthread_mutex_unlock(&dns_cache_mutex);
return;
}
// Create new entry
dns_cache_entry* new_entry = malloc(sizeof(dns_cache_entry));
if (!new_entry) {
pthread_mutex_unlock(&dns_cache_mutex);
return;
}
strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1);
new_entry->hostname[MAX_URL_LENGTH-1] = '\0';
strncpy(new_entry->ip_addr, ip_addr, IP_ADDR_MAX_LEN-1);
new_entry->ip_addr[IP_ADDR_MAX_LEN-1] = '\0';
new_entry->is_valid = is_valid;
new_entry->timestamp = time(NULL);
new_entry->next = dns_cache;
dns_cache = new_entry;
pthread_mutex_unlock(&dns_cache_mutex);
}
int check_dns_cache(const char* hostname, char* ip_buffer, size_t buffer_size) {
pthread_mutex_lock(&dns_cache_mutex);
dns_cache_entry* entry = find_dns_cache_entry(hostname);
if (!entry) {
pthread_mutex_unlock(&dns_cache_mutex);
return -1; // Not in cache
}
// Check if entry has expired
time_t now = time(NULL);
if (now - entry->timestamp > DNS_CACHE_EXPIRY_TIME) {
// Entry expired, remove from cache
pthread_mutex_unlock(&dns_cache_mutex);
return -1;
}
// If entry is valid, copy IP address to buffer
if (entry->is_valid) {
strncpy(ip_buffer, entry->ip_addr, buffer_size-1);
ip_buffer[buffer_size-1] = '\0';
pthread_mutex_unlock(&dns_cache_mutex);
return 0; // Success
}
pthread_mutex_unlock(&dns_cache_mutex);
return 1; // Entry exists but is marked as invalid resolution
}
void cleanup_dns_cache() {
pthread_mutex_lock(&dns_cache_mutex);
dns_cache_entry* entry = dns_cache;
dns_cache_entry* prev = NULL;
time_t now = time(NULL);
// Remove expired entries
while (entry) {
if (now - entry->timestamp > DNS_CACHE_EXPIRY_TIME) {
dns_cache_entry* to_free = entry;
if (prev) {
prev->next = entry->next;
entry = entry->next;
} else {
dns_cache = entry->next;
entry = dns_cache;
}
free(to_free);
} else {
prev = entry;
entry = entry->next;
}
}
pthread_mutex_unlock(&dns_cache_mutex);
}
void free_dns_cache() {
pthread_mutex_lock(&dns_cache_mutex);
dns_cache_entry* entry = dns_cache;
while (entry) {
dns_cache_entry* next = entry->next;
free(entry);
entry = next;
}
dns_cache = NULL;
pthread_mutex_unlock(&dns_cache_mutex);
}
// Helper function to print SSL errors
void print_ssl_errors(const char* context) {
unsigned long err;
char err_buf[256];
while ((err = ERR_get_error()) != 0) {
ERR_error_string_n(err, err_buf, sizeof(err_buf));
fprintf(stderr, "SSL Error (%s): %s\n", context, err_buf);
}
}
// DANE cache functions
dane_cache_entry* find_cache_entry(const char* hostname) {
dane_cache_entry* entry = dane_cache;
while (entry) {
if (strcmp(entry->hostname, hostname) == 0) {
return entry;
}
entry = entry->next;
}
return NULL;
}
void add_to_cache(const char* hostname, int verified) {
pthread_mutex_lock(&cache_mutex);
// First check if entry already exists
dane_cache_entry* existing = find_cache_entry(hostname);
if (existing) {
// Update existing entry
existing->verified = verified;
existing->timestamp = time(NULL);
pthread_mutex_unlock(&cache_mutex);
return;
}
// Create new entry
dane_cache_entry* new_entry = malloc(sizeof(dane_cache_entry));
if (!new_entry) {
pthread_mutex_unlock(&cache_mutex);
return;
}
strncpy(new_entry->hostname, hostname, MAX_URL_LENGTH-1);
new_entry->hostname[MAX_URL_LENGTH-1] = '\0';
new_entry->verified = verified;
new_entry->timestamp = time(NULL);
new_entry->next = dane_cache;
dane_cache = new_entry;
pthread_mutex_unlock(&cache_mutex);
}
int check_cache(const char* hostname) {
pthread_mutex_lock(&cache_mutex);
dane_cache_entry* entry = find_cache_entry(hostname);
if (!entry) {
pthread_mutex_unlock(&cache_mutex);
return -1; // Not in cache
}
// Check if entry has expired
time_t now = time(NULL);
if (now - entry->timestamp > CACHE_EXPIRY_TIME) {
// Entry expired, remove from cache
pthread_mutex_unlock(&cache_mutex);
return -1;
}
int result = entry->verified;
pthread_mutex_unlock(&cache_mutex);
return result;
}
void cleanup_cache() {
pthread_mutex_lock(&cache_mutex);
dane_cache_entry* entry = dane_cache;
dane_cache_entry* prev = NULL;
time_t now = time(NULL);
// Remove expired entries
while (entry) {
if (now - entry->timestamp > CACHE_EXPIRY_TIME) {
dane_cache_entry* to_free = entry;
if (prev) {
prev->next = entry->next;
entry = entry->next;
} else {
dane_cache = entry->next;
entry = dane_cache;
}
free(to_free);
} else {
prev = entry;
entry = entry->next;
}
}
pthread_mutex_unlock(&cache_mutex);
}
void free_cache() {
pthread_mutex_lock(&cache_mutex);
dane_cache_entry* entry = dane_cache;
while (entry) {
dane_cache_entry* next = entry->next;
free(entry);
entry = next;
}
dane_cache = NULL;
pthread_mutex_unlock(&cache_mutex);
}
// Extract hostname from HTTP request // Extract hostname from HTTP request
char* extract_host(const char* request) { char* extract_host(const char* request) {
static char host[MAX_URL_LENGTH]; static char host[MAX_URL_LENGTH];
@@ -146,6 +558,110 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
// Unused parameter // Unused parameter
(void)ip_addr; (void)ip_addr;
// Check the cache first
int cached_result = check_cache(hostname);
// Generate certificate paths
char cert_path[256];
char key_path[256];
snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname);
snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname);
// If in cache and verified as successful, skip verification
if (cached_result == 1) {
printf("Using cached DANE verification for %s (verified)\n", hostname);
// Check if certificate files exist and if we need to renew them
if (access(cert_path, F_OK) == 0 && access(key_path, F_OK) == 0 && !should_renew_cert(hostname)) {
// Initialize SSL context with our certificate
ssl_context_t* client_ctx = init_ssl_context(cert_path, key_path);
if (!client_ctx) {
// Fall back to regular tunnel if there's an issue with the certificate
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Connect to the server
SSL_CTX* server_ctx = SSL_CTX_new(TLS_client_method());
if (!server_ctx) {
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
SSL* server_ssl = SSL_new(server_ctx);
SSL_set_fd(server_ssl, server_sock);
SSL_set_tlsext_host_name(server_ssl, hostname);
if (SSL_connect(server_ssl) <= 0) {
print_ssl_errors("SSL_connect");
cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Send 200 Connection Established to the client
const char* success_response = "HTTP/1.1 200 Connection Established\r\n\r\n";
if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
perror("Failed to send connection established response");
cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
return;
}
// Initialize SSL connection with the client
client_ctx->ssl = SSL_new(client_ctx->ctx);
SSL_set_fd(client_ctx->ssl, client_sock);
// Clear any previous errors
ERR_clear_error();
int accept_result = SSL_accept(client_ctx->ssl);
if (accept_result <= 0) {
int ssl_err = SSL_get_error(client_ctx->ssl, accept_result);
fprintf(stderr, "SSL accept failed with error code: %d\n", ssl_err);
print_ssl_errors("SSL_accept");
// Try to determine if client closed the connection
if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) {
fprintf(stderr, "Client may have closed the connection\n");
}
cleanup_ssl_connection(client_ctx->ssl, NULL);
cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
return;
}
printf("SSL connection established with client for %s (using cached verification)\n", hostname);
// Forward data between them
ssl_tunnel_data(client_ctx->ssl, server_ssl);
// Clean up
cleanup_ssl_connection(client_ctx->ssl, NULL);
cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(client_ctx->ctx);
free(client_ctx);
return;
} else {
// Certificate doesn't exist or needs renewal, generate a new one
printf("Certificate for %s needs to be generated or renewed\n", hostname);
}
} else if (cached_result == 0) {
// Previously verified as failed, use regular tunnel
printf("Using cached DANE verification for %s (not verified)\n", hostname);
handle_regular_https_tunnel(client_sock, server_sock);
return;
}
// Not in cache or expired, perform full verification
// Check if we have DANE records for this domain // Check if we have DANE records for this domain
int has_dane = is_dane_available(hostname); int has_dane = is_dane_available(hostname);
@@ -169,8 +685,7 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
// Connect to the server with SSL // Connect to the server with SSL
if (SSL_connect(server_ssl) <= 0) { if (SSL_connect(server_ssl) <= 0) {
fprintf(stderr, "SSL connection to server failed\n"); fprintf(stderr, "SSL connection to server failed\n");
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
handle_regular_https_tunnel(client_sock, server_sock); handle_regular_https_tunnel(client_sock, server_sock);
return; return;
} }
@@ -179,8 +694,7 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
X509* server_cert = SSL_get_peer_certificate(server_ssl); X509* server_cert = SSL_get_peer_certificate(server_ssl);
if (!server_cert) { if (!server_cert) {
fprintf(stderr, "Failed to get server certificate\n"); fprintf(stderr, "Failed to get server certificate\n");
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
handle_regular_https_tunnel(client_sock, server_sock); handle_regular_https_tunnel(client_sock, server_sock);
return; return;
} }
@@ -188,12 +702,14 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
// Verify the certificate against DANE // Verify the certificate against DANE
int dane_verified = verify_cert_against_dane(hostname, server_cert); int dane_verified = verify_cert_against_dane(hostname, server_cert);
// Cache the verification result
add_to_cache(hostname, dane_verified > 0 ? 1 : 0);
// If DANE verification fails, don't generate our own certificate // If DANE verification fails, don't generate our own certificate
if (dane_verified <= 0) { if (dane_verified <= 0) {
fprintf(stderr, "DANE verification failed for %s - using direct tunneling\n", hostname); fprintf(stderr, "DANE verification failed for %s - using direct tunneling\n", hostname);
X509_free(server_cert); X509_free(server_cert);
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
// Clean up and reconnect without interception // Clean up and reconnect without interception
close(server_sock); close(server_sock);
@@ -226,19 +742,11 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
printf("DANE verification successful for %s - generating trusted certificate\n", hostname); printf("DANE verification successful for %s - generating trusted certificate\n", hostname);
// Now we can generate a trusted certificate for this domain // Generate a trusted certificate with a unique serial number
char cert_path[256]; if (!generate_unique_cert(hostname, cert_path, key_path)) {
char key_path[256];
snprintf(cert_path, sizeof(cert_path), "certs/%s.crt", hostname);
snprintf(key_path, sizeof(key_path), "certs/%s.key", hostname);
// Generate a trusted certificate for this domain
if (!generate_trusted_cert(hostname, cert_path, key_path)) {
fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname); fprintf(stderr, "Failed to generate trusted certificate for %s\n", hostname);
X509_free(server_cert); X509_free(server_cert);
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
handle_regular_https_tunnel(client_sock, server_sock); handle_regular_https_tunnel(client_sock, server_sock);
return; return;
} }
@@ -248,8 +756,7 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
if (!client_ctx) { if (!client_ctx) {
fprintf(stderr, "Failed to initialize SSL context\n"); fprintf(stderr, "Failed to initialize SSL context\n");
X509_free(server_cert); X509_free(server_cert);
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
handle_regular_https_tunnel(client_sock, server_sock); handle_regular_https_tunnel(client_sock, server_sock);
return; return;
} }
@@ -259,8 +766,7 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
if (send(client_sock, success_response, strlen(success_response), 0) < 0) { if (send(client_sock, success_response, strlen(success_response), 0) < 0) {
perror("Failed to send connection established response"); perror("Failed to send connection established response");
X509_free(server_cert); X509_free(server_cert);
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
SSL_CTX_free(client_ctx->ctx); SSL_CTX_free(client_ctx->ctx);
free(client_ctx); free(client_ctx);
return; return;
@@ -270,12 +776,15 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
client_ctx->ssl = SSL_new(client_ctx->ctx); client_ctx->ssl = SSL_new(client_ctx->ctx);
SSL_set_fd(client_ctx->ssl, client_sock); SSL_set_fd(client_ctx->ssl, client_sock);
// Clear any previous errors
ERR_clear_error();
if (SSL_accept(client_ctx->ssl) <= 0) { if (SSL_accept(client_ctx->ssl) <= 0) {
fprintf(stderr, "SSL accept failed\n"); fprintf(stderr, "SSL accept failed\n");
SSL_free(client_ctx->ssl); print_ssl_errors("SSL_accept");
cleanup_ssl_connection(client_ctx->ssl, NULL);
X509_free(server_cert); X509_free(server_cert);
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
SSL_CTX_free(client_ctx->ctx); SSL_CTX_free(client_ctx->ctx);
free(client_ctx); free(client_ctx);
return; return;
@@ -288,14 +797,14 @@ void handle_https_tunnel(int client_sock, int server_sock, const char* hostname,
ssl_tunnel_data(client_ctx->ssl, server_ssl); ssl_tunnel_data(client_ctx->ssl, server_ssl);
// Clean up // Clean up
SSL_free(client_ctx->ssl); cleanup_ssl_connection(client_ctx->ssl, NULL);
X509_free(server_cert); X509_free(server_cert);
SSL_free(server_ssl); cleanup_ssl_connection(server_ssl, server_ctx);
SSL_CTX_free(server_ctx);
SSL_CTX_free(client_ctx->ctx); SSL_CTX_free(client_ctx->ctx);
free(client_ctx); free(client_ctx);
} else { } else {
// No DANE records, use regular tunneling // No DANE records, use regular tunneling
add_to_cache(hostname, 0); // Cache the lack of DANE records
handle_regular_https_tunnel(client_sock, server_sock); handle_regular_https_tunnel(client_sock, server_sock);
} }
} }
@@ -494,16 +1003,32 @@ void* handle_client(void* arg) {
printf("Proxying %s request to: %s (port %d)\n", printf("Proxying %s request to: %s (port %d)\n",
is_connect ? "HTTPS" : "HTTP", host, port); is_connect ? "HTTPS" : "HTTP", host, port);
// Resolve hostname using DoH // Try to resolve hostname from DNS cache first
char ip_addr[INET6_ADDRSTRLEN]; char ip_addr[IP_ADDR_MAX_LEN];
if (resolve_doh(host, ip_addr, sizeof(ip_addr)) != 0) { int dns_cache_result = check_dns_cache(host, ip_addr, sizeof(ip_addr));
printf("Failed to resolve hostname using DoH: %s\n", host);
if (dns_cache_result == 0) {
printf("Using cached DNS for %s: %s\n", host, ip_addr);
} else if (dns_cache_result == 1) {
printf("Using cached failed DNS resolution for %s\n", host);
close(client_sock); close(client_sock);
return NULL; return NULL;
} else {
// Resolve hostname using DoH
if (resolve_doh(host, ip_addr, sizeof(ip_addr)) != 0) {
printf("Failed to resolve hostname using DoH: %s\n", host);
// Cache the failed resolution to avoid repeated lookups
add_to_dns_cache(host, "", 0);
close(client_sock);
return NULL;
}
printf("Resolved %s to %s\n", host, ip_addr);
// Cache the successful resolution
add_to_dns_cache(host, ip_addr, 1);
} }
printf("Resolved %s to %s\n", host, ip_addr);
// Connect to the target server // Connect to the target server
struct sockaddr_in server_addr; struct sockaddr_in server_addr;
int server_sock = socket(AF_INET, SOCK_STREAM, 0); int server_sock = socket(AF_INET, SOCK_STREAM, 0);
@@ -612,6 +1137,15 @@ int start_proxy_server(int port) {
return 1; return 1;
} }
// Create periodic cleanup thread
pthread_t cleanup_thread;
if (pthread_create(&cleanup_thread, NULL, periodic_cleanup, NULL) != 0) {
perror("Failed to create cleanup thread");
// Continue anyway, not critical
} else {
pthread_detach(cleanup_thread);
}
// Create socket // Create socket
server_sock = socket(AF_INET, SOCK_STREAM, 0); server_sock = socket(AF_INET, SOCK_STREAM, 0);
if (server_sock < 0) { if (server_sock < 0) {
@@ -692,6 +1226,10 @@ int start_proxy_server(int port) {
// Initialize the proxy server // Initialize the proxy server
int proxy_init() { int proxy_init() {
// Initialize OpenSSL error strings
SSL_load_error_strings();
ERR_load_crypto_strings();
// Initialize DANE support // Initialize DANE support
if (!dane_init()) { if (!dane_init()) {
fprintf(stderr, "Failed to initialize DANE support\n"); fprintf(stderr, "Failed to initialize DANE support\n");
@@ -703,6 +1241,12 @@ int proxy_init() {
// Clean up proxy resources // Clean up proxy resources
void proxy_cleanup() { void proxy_cleanup() {
// Clean up caches
free_cache();
free_dns_cache();
free_cert_serials();
// Clean up DANE resources
dane_cleanup(); dane_cleanup();
} }
@@ -710,11 +1254,30 @@ void proxy_cleanup() {
void handle_signal(int sig) { void handle_signal(int sig) {
if (sig == SIGINT) { if (sig == SIGINT) {
printf("\nShutting down proxy server...\n"); printf("\nShutting down proxy server...\n");
printf("Cleaning up temporary certificates...\n"); printf("Cleaning up cache and temporary certificates...\n");
// Clean up DANE resources (which will delete all generated certificates) // Clean up resources
proxy_cleanup(); proxy_cleanup();
exit(0); exit(0);
} }
} }
// Periodic cleanup function to run in a separate thread
void* periodic_cleanup(void* arg) {
(void)arg; // Unused parameter
while (1) {
// Sleep for 5 minutes
sleep(300);
// Clean up expired cache entries
cleanup_cache();
cleanup_dns_cache();
cleanup_cert_serials();
printf("Performed periodic cache cleanup\n");
}
return NULL;
}