From 068ba3519c3a8da01e5ec6dfa0c0f302b9216eb3 Mon Sep 17 00:00:00 2001
From: Nathan Woodburn <github@nathan.woodburn.au>
Date: Fri, 28 Mar 2025 22:39:05 +1100
Subject: [PATCH] fix: Update scheduler for automated tests

---
 requirements.txt |   3 +-
 server.py        | 379 ++++++++++++++++++++++++++++++++++++++---------
 2 files changed, 307 insertions(+), 75 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index d7c884a..dadab87 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -5,4 +5,5 @@ dnspython
 dnslib
 python-dateutil
 python-dotenv
-schedule
\ No newline at end of file
+schedule
+apscheduler>=3.9.1
\ No newline at end of file
diff --git a/server.py b/server.py
index e967932..2d65d2d 100644
--- a/server.py
+++ b/server.py
@@ -1,5 +1,5 @@
 from collections import defaultdict
-from functools import cache
+from functools import cache, wraps
 import json
 from flask import (
     Flask,
@@ -26,11 +26,25 @@ import socket
 from datetime import datetime
 from dateutil import relativedelta
 import dotenv
+import time
+import logging
+import signal
+import sys
+from apscheduler.schedulers.background import BackgroundScheduler
+from apscheduler.triggers.interval import IntervalTrigger
+from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
+
+# Set up logging
+logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+logger = logging.getLogger(__name__)
 
 dotenv.load_dotenv()
 
 app = Flask(__name__)
 
+# Configure scheduler - use coalescing to prevent missed jobs from piling up
+scheduler = BackgroundScheduler(daemon=True, job_defaults={'coalesce': True, 'max_instances': 1})
+
 node_names = {
     "18.169.98.42": "Easy HNS",
     "172.233.46.92": "EZ Domains",
@@ -144,9 +158,34 @@ def get_node_list() -> list:
     return ips
 
 
+# Retry decorator for network operations
+def retry(max_attempts=3, delay_seconds=1):
+    def decorator(func):
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            attempts = 0
+            last_error = None
+            while attempts < max_attempts:
+                try:
+                    return func(*args, **kwargs)
+                except (socket.timeout, socket.error, dns.exception.Timeout, requests.exceptions.RequestException) as e:
+                    attempts += 1
+                    last_error = e
+                    logger.warning(f"Attempt {attempts} failed with error: {e} - retrying in {delay_seconds} seconds")
+                    if attempts < max_attempts:
+                        time.sleep(delay_seconds)
+            logger.error(f"All {max_attempts} attempts failed. Last error: {last_error}")
+            # Fall back to False so callers see a failed check rather than an
+            # exception (check_nodes() guards the dict-returning checks)
+            return False
+        return wrapper
+    return decorator
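+
+
+# Usage sketch (illustrative only; the real call sites follow below): any
+# network-facing check can opt in to transparent retries on transient
+# socket/DNS/HTTP errors:
+#
+#   @retry(max_attempts=3, delay_seconds=2)
+#   def check_example(ip: str) -> bool:   # hypothetical check
+#       return bool(probe(ip))            # hypothetical network call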
+
+
+@retry(max_attempts=3, delay_seconds=2)
 def check_plain_dns(ip: str) -> bool:
     resolver = dns.resolver.Resolver()
     resolver.nameservers = [ip]
+    resolver.timeout = 5  # Set a reasonable timeout
+    resolver.lifetime = 5  # Total timeout for the query
 
     try:
         result = resolver.resolve("1.wdbrn", "TXT")
@@ -154,8 +193,17 @@ def check_plain_dns(ip: str) -> bool:
             if "Test 1" in txt.to_text():
                 return True
         return False
+    except dns.resolver.NXDOMAIN:
+        logger.info(f"Domain not found for plain DNS check on {ip}")
+        return False
+    except dns.resolver.NoAnswer:
+        logger.info(f"No answer received for plain DNS check on {ip}")
+        return False
+    except (dns.exception.Timeout, socket.timeout):
+        logger.warning(f"Timeout during plain DNS check on {ip}")
+        raise  # Re-raise for retry decorator
     except Exception as e:
-        print(e)
+        logger.error(f"Error during plain DNS check on {ip}: {e}")
         return False
 
 
@@ -167,9 +215,13 @@ def build_dns_query(domain: str, qtype: str = "A"):
     return q.pack()
 
 
+@retry(max_attempts=3, delay_seconds=2)
 def check_doh(ip: str) -> dict:
     status = False
     server_name = []
+    sock = None
+    ssock = None
+
     try:
         dns_query = build_dns_query("2.wdbrn", "TXT")
         request = (
@@ -181,43 +233,83 @@
             "\r\n"
         )
         wireframe_request = request.encode() + dns_query
-        sock = socket.create_connection((ip, 443))
+
+        # Create socket with timeout
+        sock = socket.create_connection((ip, 443), timeout=10)
         context = ssl.create_default_context()
+        context.check_hostname = False  # Skip hostname verification for IP-based connection
         ssock = context.wrap_socket(sock, server_hostname="hnsdoh.com")
+        ssock.settimeout(10)  # Set a timeout for socket operations
         ssock.sendall(wireframe_request)
+
         response_data = b""
         while True:
-            data = ssock.recv(4096)
-            if not data:
-                break
-            response_data += data
+            try:
+                data = ssock.recv(4096)
+                if not data:
+                    break
+                response_data += data
+            except socket.timeout:
+                logger.warning(f"Socket timeout while receiving data from {ip}")
+                if response_data:  # We might have partial data
+                    break
+                else:
+                    raise
 
-        response_str = response_data.decode("latin-1")
+        if not response_data:
+            logger.warning(f"No data received from {ip}")
+            return {"status": status, "server": server_name}
+
+        response_str = response_data.decode("latin-1", errors="replace")
+
+        # Check if we have a complete HTTP response with headers and body
+        if "\r\n\r\n" not in response_str:
+            logger.warning(f"Incomplete HTTP response from {ip}")
+            return {"status": status, "server": server_name}
+
         headers, body = response_str.split("\r\n\r\n", 1)
 
         # Try to get server from headers
         for header in headers.split("\r\n"):
-            if header.startswith("Server:"):
-                server_name.append(header.split(":")[1].strip())
-
+            if header.lower().startswith("server:"):
+                server_name.append(header.split(":", 1)[1].strip())
 
-        dns_response: dnslib.DNSRecord = dnslib.DNSRecord.parse(body.encode("latin-1"))
-        for rr in dns_response.rr:
-            if "Test 2" in str(rr):
-                status = True
+        try:
+            dns_response = dnslib.DNSRecord.parse(body.encode("latin-1"))
+            for rr in dns_response.rr:
+                if "Test 2" in str(rr):
+                    status = True
+                    break
+        except Exception as e:
+            logger.error(f"Error parsing DNS response from {ip}: {e}")
+
+    except (socket.timeout, socket.error) as e:
+        logger.warning(f"Socket error during DoH check on {ip}: {e}")
+        raise  # Re-raise for retry decorator
+    except ssl.SSLError as e:
+        logger.error(f"SSL error during DoH check on {ip}: {e}")
+        return {"status": False, "server": server_name}
     except Exception as e:
-        print(e)
-
+        logger.error(f"Unexpected error during DoH check on {ip}: {e}")
+        return {"status": False, "server": server_name}
     finally:
-        # Close the socket connection
-        # Check if ssock is defined
-        if "ssock" in locals():
-            ssock.close()
+        # Ensure sockets are always closed; catch Exception rather than using a
+        # bare except so shutdown signals are not swallowed
+        if ssock:
+            try:
+                ssock.close()
+            except Exception:
+                pass
+        if sock and sock != ssock:
+            try:
+                sock.close()
+            except Exception:
+                pass
+
     return {"status": status, "server": server_name}
 
 
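+# Note: check_doh() above speaks raw HTTP/1.1 over TLS, which keeps the Server
+# response header and the socket timeouts under direct control. A rough
+# equivalent using the `requests` library (illustrative sketch only; the Host
+# header stands in for the virtual host and certificate verification against a
+# bare IP fails, hence verify=False) would be:
+#
+#   resp = requests.post(f"https://{ip}/dns-query", data=dns_query, timeout=10,
+#                        headers={"Host": "hnsdoh.com",
+#                                 "Content-Type": "application/dns-message"},
+#                        verify=False)
+#   status = any("Test 2" in str(rr)
+#                for rr in dnslib.DNSRecord.parse(resp.content).rr)
+
+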
+@retry(max_attempts=3, delay_seconds=2)
 def check_dot(ip: str) -> bool:
     qname = dns.name.from_text("3.wdbrn")
     q = dns.message.make_query(qname, dns.rdatatype.TXT)
@@ -231,39 +323,73 @@
             if "Test 3" in rr.to_text():
                 return True
         return False
+    except dns.exception.Timeout:
+        logger.warning(f"Timeout during DoT check on {ip}")
+        raise  # Re-raise for retry decorator
+    except ssl.SSLError as e:
+        logger.error(f"SSL error during DoT check on {ip}: {e}")
+        return False
     except Exception as e:
-        print(e)
+        logger.error(f"Error during DoT check on {ip}: {e}")
         return False
 
 
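+# The rewritten verify_cert() below returns a structured dict rather than a
+# bare bool, e.g. (illustrative values only):
+#   {"valid": True, "expires": "in 2 months",
+#    "expiry_date": "Jun 01 12:00:00 2025 GMT"}
+# If every retry attempt fails, the @retry decorator returns False instead of
+# this dict, so check_nodes() wraps its call sites in try/except.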
-def verify_cert(ip: str, port: int) -> bool:
+@retry(max_attempts=3, delay_seconds=2)
+def verify_cert(ip: str, port: int) -> dict:
     expires = "ERROR"
     valid = False
+    expiry_date_str = (datetime.now() - relativedelta.relativedelta(years=1)).strftime("%b %d %H:%M:%S %Y GMT")
+    sock = None
+    ssock = None
+
     try:
-        sock = socket.create_connection((ip, port))
+        sock = socket.create_connection((ip, port), timeout=10)
         # Wrap the socket in SSL/TLS
         context = ssl.create_default_context()
+        context.check_hostname = False  # Skip hostname verification for IP-based connection
         ssock = context.wrap_socket(sock, server_hostname="hnsdoh.com")
-
+        ssock.settimeout(10)  # Set timeout for socket operations
         # Retrieve the server's certificate
         cert = ssock.getpeercert()
+        if not cert:
+            logger.error(f"No certificate returned from {ip}:{port}")
+            return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
 
         # Extract the expiry date from the certificate
+        if "notAfter" not in cert:
+            logger.error(f"Certificate from {ip}:{port} missing notAfter field")
+            return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
+
         expiry_date_str = cert["notAfter"]
 
         # Convert the expiry date string to a datetime object
         expiry_date = datetime.strptime(expiry_date_str, "%b %d %H:%M:%S %Y GMT")
         expires = format_relative_time(expiry_date)
         valid = expiry_date > datetime.now()
-    except Exception as e:
-        print(e)
+    except (socket.timeout, socket.error) as e:
+        logger.warning(f"Socket error during certificate check on {ip}:{port}: {e}")
+        raise  # Re-raise for retry decorator
+    except ssl.SSLError as e:
+        logger.error(f"SSL error during certificate check on {ip}:{port}: {e}")
+        return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
+    except Exception as e:
+        logger.error(f"Error during certificate check on {ip}:{port}: {e}")
+        return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
     finally:
-        # Close the SSL and socket connection
-        if "ssock" in locals():
-            ssock.close()
+        # Ensure sockets are always closed; catch Exception rather than using a
+        # bare except so shutdown signals are not swallowed
+        if ssock:
+            try:
+                ssock.close()
+            except Exception:
+                pass
+        if sock and sock != ssock:
+            try:
+                sock.close()
+            except Exception:
+                pass
+
     return {"valid": valid, "expires": expires, "expiry_date": expiry_date_str}
 
 
@@ -321,8 +447,17 @@ def check_nodes() -> list:
     else:
         if len(nodes) == 0:
             nodes = get_node_list()
-        node_status = []
-        for ip in nodes:
+
+        node_status = []
+        for ip in nodes:
+            logger.info(f"Checking node {ip}")
+            try:
+                # Run each probe once and reuse the results (the old code
+                # called check_doh() twice per node)
+                plain_dns_result = check_plain_dns(ip)
+                doh_check = check_doh(ip)
+                dot_result = check_dot(ip)
+                cert_result = verify_cert(ip, 443)
+                cert_853_result = verify_cert(ip, 853)
+
                 node_status.append(
                     {
                         "ip": ip,
@@ -330,17 +465,18 @@
                         "location": (
                             node_locations[ip] if ip in node_locations else "Unknown"
                         ),
-                        "plain_dns": check_plain_dns(ip),
-                        "doh": check_doh(ip)["status"],
-                        "doh_server": check_doh(ip)["server"],
-                        "dot": check_dot(ip),
-                        "cert": verify_cert(ip, 443),
-                        "cert_853": verify_cert(ip, 853),
+                        "plain_dns": plain_dns_result,
+                        "doh": doh_check["status"],
+                        "doh_server": doh_check["server"],
+                        "dot": dot_result,
+                        "cert": cert_result,
+                        "cert_853": cert_853_result,
                     }
                 )
-        else:
-            node_status = []
-            for ip in nodes:
+                logger.info(f"Node {ip} check complete")
+            except Exception as e:
+                logger.error(f"Error checking node {ip}: {e}")
+                # Add a failed entry for this node to ensure it's still included
                 node_status.append(
                     {
                         "ip": ip,
@@ -348,17 +484,18 @@
                         "location": (
                             node_locations[ip] if ip in node_locations else "Unknown"
                        ),
-                        "plain_dns": check_plain_dns(ip),
-                        "doh": check_doh(ip)["status"],
-                        "doh_server": check_doh(ip)["server"],
-                        "dot": check_dot(ip),
-                        "cert": verify_cert(ip, 443),
-                        "cert_853": verify_cert(ip, 853),
+                        "plain_dns": False,
+                        "doh": False,
+                        "doh_server": [],
+                        "dot": False,
+                        "cert": {"valid": False, "expires": "ERROR", "expiry_date": "ERROR"},
+                        "cert_853": {"valid": False, "expires": "ERROR", "expiry_date": "ERROR"},
                     }
                 )
+
     # Save the node status to a file
     log_status(node_status)
-    print("Finished checking nodes", flush=True)
+    logger.info("Finished checking nodes")
 
     # Send notifications if any nodes are down
     for node in node_status:
@@ -372,39 +509,63 @@
             send_down_notification(node)
             continue
         # Check if cert is expiring in 7 days
-        cert_expiry = datetime.strptime(
-            node["cert"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
-        )
-        if cert_expiry < datetime.now() + relativedelta.relativedelta(days=7):
-            send_down_notification(node)
-            continue
-        cert_853_expiry = datetime.strptime(
-            node["cert_853"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
-        )
-        if cert_853_expiry < datetime.now() + relativedelta.relativedelta(days=7):
-            send_down_notification(node)
+        try:
+            cert_expiry = datetime.strptime(
+                node["cert"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
+            )
+            if cert_expiry < datetime.now() + relativedelta.relativedelta(days=7):
+                send_down_notification(node)
+                continue
+
+            cert_853_expiry = datetime.strptime(
+                node["cert_853"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
+            )
+            if cert_853_expiry < datetime.now() + relativedelta.relativedelta(days=7):
+                send_down_notification(node)
+        except Exception as e:
+            logger.error(f"Error processing certificate expiry for {node['ip']}: {e}")
+
     return node_status
 
 
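+# Shape of each entry produced by check_nodes(), whether the checks ran or the
+# fallback path fired (values illustrative; the fallback entry uses False, []
+# and "ERROR" throughout):
+#   {"ip": "18.169.98.42", "name": "Easy HNS", "location": "...",
+#    "plain_dns": True, "doh": True, "doh_server": ["nginx"], "dot": True,
+#    "cert": {"valid": True, "expires": "...", "expiry_date": "..."},
+#    "cert_853": {"valid": True, "expires": "...", "expiry_date": "..."}}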
 def check_nodes_from_log() -> list:
     global last_log
     # Load the last log
-    with open(f"{log_dir}/node_status.json", "r") as file:
-        data = json.load(file)
-    newest = {
-        "date": datetime.now() - relativedelta.relativedelta(years=1),
-        "nodes": [],
-    }
-    for entry in data:
-        if datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S") > newest["date"]:
-            newest = entry
-    newest["date"] = datetime.strptime(newest["date"], "%Y-%m-%d %H:%M:%S")
-    node_status = newest["nodes"]
-    if datetime.now() > newest["date"] + relativedelta.relativedelta(minutes=10):
-        print("Failed to get a new enough log, checking nodes", flush=True)
+    try:
+        with open(f"{log_dir}/node_status.json", "r") as file:
+            data = json.load(file)
+
+        newest = {
+            "date": datetime.now() - relativedelta.relativedelta(years=1),
+            "nodes": [],
+        }
+
+        for entry in data:
+            if datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S") > newest["date"]:
+                newest = entry
+        newest["date"] = datetime.strptime(newest["date"], "%Y-%m-%d %H:%M:%S")
+
+        node_status = newest["nodes"]
+
+        # Get the staleness threshold from the environment or use the default (15 minutes)
+        staleness_threshold_str = os.getenv("STALENESS_THRESHOLD_MINUTES", "15")
+        try:
+            staleness_threshold = int(staleness_threshold_str)
+        except ValueError:
+            logger.warning(f"Invalid STALENESS_THRESHOLD_MINUTES value: {staleness_threshold_str}, using default of 15")
+            staleness_threshold = 15
+
+        if datetime.now() > newest["date"] + relativedelta.relativedelta(minutes=staleness_threshold):
+            logger.warning(f"Data is stale (older than {staleness_threshold} minutes), triggering immediate check")
+            node_status = check_nodes()
+        else:
+            last_log = newest["date"]
+            logger.info(f"Using cached node status from {format_last_check(last_log)}")
+    except (FileNotFoundError, json.JSONDecodeError) as e:
+        logger.error(f"Error reading node status file: {e}")
+        logger.info("Running initial node check")
        node_status = check_nodes()
-    else:
-        last_log = newest["date"]
+
     return node_status
@@ -1020,5 +1181,75 @@ def not_found(e):
 
 # endregion
 
+
+# Scheduled wrapper around check_nodes() for the background scheduler
+def scheduled_node_check():
+    """Function to be called by the scheduler to check all nodes"""
+    try:
+        logger.info("Running scheduled node check")
+        # Get fresh node list on each check to pick up DNS changes
+        global nodes
+        nodes = []  # Reset node list to force refresh
+        check_nodes()
+        logger.info("Completed scheduled node check")
+    except Exception as e:
+        logger.error(f"Error in scheduled node check: {e}")
+
+def scheduler_listener(event):
+    """Listener for scheduler events"""
+    if event.exception:
+        logger.error(f"Error in scheduled job: {event.exception}")
+    else:
+        logger.debug("Scheduled job completed successfully")
+
+# Function to start the scheduler
+def start_scheduler():
+    # Get the check interval from the environment or use the default (5 minutes)
+    check_interval_str = os.getenv("CHECK_INTERVAL_MINUTES", "5")
+    try:
+        check_interval = int(check_interval_str)
+    except ValueError:
+        logger.warning(f"Invalid CHECK_INTERVAL_MINUTES value: {check_interval_str}, using default of 5")
+        check_interval = 5
+
+    logger.info(f"Setting up scheduler to run every {check_interval} minutes")
+
+    # Add the job to the scheduler
+    scheduler.add_job(
+        scheduled_node_check,
+        IntervalTrigger(minutes=check_interval),
+        id='node_check_job',
+        replace_existing=True
+    )
+
+    # Add listener for job events
+    scheduler.add_listener(scheduler_listener, EVENT_JOB_ERROR | EVENT_JOB_EXECUTED)
+
+    # Start the scheduler if it's not already running
+    if not scheduler.running:
+        scheduler.start()
+        logger.info("Scheduler started")
+
+# Register signal handlers for graceful shutdown in Docker
+def signal_handler(sig, frame):
+    logger.info(f"Received signal {sig}, shutting down...")
+    if scheduler.running:
+        scheduler.shutdown()
+        logger.info("Scheduler shut down")
+    sys.exit(0)
+
+# Register the signal handlers for Docker
+signal.signal(signal.SIGINT, signal_handler)
+signal.signal(signal.SIGTERM, signal_handler)
+
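+# Tuning note: both intervals are environment-driven, so automated tests can
+# tighten them without code changes, e.g. (hypothetical invocation):
+#   CHECK_INTERVAL_MINUTES=1 STALENESS_THRESHOLD_MINUTES=5 python server.py
+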
+# Initialize the scheduler when the app starts, without relying on
+# @before_first_request, which is deprecated in newer Flask versions
+with app.app_context():
+    start_scheduler()
+    # Run an immediate check
+    scheduled_node_check()
+
 if __name__ == "__main__":
+    # The scheduler is already started in the app context above
+    # Run the Flask app
     app.run(debug=True, port=5000, host="0.0.0.0")
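+
+# Deployment note (assumption: single-process serving): under a multi-process
+# WSGI server such as gunicorn, each worker imports this module and would start
+# its own BackgroundScheduler. Run one worker, or guard start_scheduler(), to
+# avoid duplicate node checks.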