fix: Update scheduler for automated tests

This commit is contained in:
Nathan Woodburn 2025-03-28 22:39:05 +11:00
parent 7e1ceecd0c
commit 068ba3519c
Signed by: nathanwoodburn
GPG Key ID: 203B000478AD0EF1
2 changed files with 307 additions and 75 deletions

View File

@ -5,4 +5,5 @@ dnspython
dnslib
python-dateutil
python-dotenv
schedule
schedule
apscheduler>=3.9.1

379
server.py
View File

@ -1,5 +1,5 @@
from collections import defaultdict
from functools import cache
from functools import cache, wraps
import json
from flask import (
Flask,
@ -26,11 +26,25 @@ import socket
from datetime import datetime
from dateutil import relativedelta
import dotenv
import time
import logging
import signal
import sys
from apscheduler.schedulers.background import BackgroundScheduler
from apscheduler.triggers.interval import IntervalTrigger
from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
dotenv.load_dotenv()
app = Flask(__name__)
# Configure scheduler - use coalescing to prevent missed jobs piling up
scheduler = BackgroundScheduler(daemon=True, job_defaults={'coalesce': True, 'max_instances': 1})
node_names = {
"18.169.98.42": "Easy HNS",
"172.233.46.92": "EZ Domains",
@ -144,9 +158,34 @@ def get_node_list() -> list:
return ips
# Add retry decorator for network operations
def retry(max_attempts=3, delay_seconds=1):
    """Decorator that retries a network operation on transient failures.

    Retries up to *max_attempts* times, sleeping *delay_seconds* between
    tries.  Only transient network errors (socket, DNS, requests) trigger a
    retry; any other exception propagates.  If every attempt fails, the last
    error is logged and False is returned so health-check callers see a
    failed check instead of an exception.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            last_error = None
            for attempt in range(1, max_attempts + 1):
                try:
                    return func(*args, **kwargs)
                except (socket.timeout, socket.error, dns.exception.Timeout, requests.exceptions.RequestException) as e:
                    last_error = e
                    logger.warning(f"Attempt {attempt} failed with error: {e} - retrying in {delay_seconds} seconds")
                    # Do not sleep after the final failed attempt
                    if attempt < max_attempts:
                        time.sleep(delay_seconds)
            logger.error(f"All {max_attempts} attempts failed. Last error: {last_error}")
            return False  # Return False as a fallback for checks
        return wrapper
    return decorator
@retry(max_attempts=3, delay_seconds=2)
def check_plain_dns(ip: str) -> bool:
resolver = dns.resolver.Resolver()
resolver.nameservers = [ip]
resolver.timeout = 5 # Set a reasonable timeout
resolver.lifetime = 5 # Total timeout for the query
try:
result = resolver.resolve("1.wdbrn", "TXT")
@ -154,8 +193,17 @@ def check_plain_dns(ip: str) -> bool:
if "Test 1" in txt.to_text():
return True
return False
except dns.resolver.NXDOMAIN:
logger.info(f"Domain not found for plain DNS check on {ip}")
return False
except dns.resolver.NoAnswer:
logger.info(f"No answer received for plain DNS check on {ip}")
return False
except (dns.exception.Timeout, socket.timeout):
logger.warning(f"Timeout during plain DNS check on {ip}")
raise # Re-raise for retry decorator
except Exception as e:
print(e)
logger.error(f"Error during plain DNS check on {ip}: {e}")
return False
@ -167,9 +215,13 @@ def build_dns_query(domain: str, qtype: str = "A"):
return q.pack()
@retry(max_attempts=3, delay_seconds=2)
def check_doh(ip: str) -> dict:
status = False
server_name = []
sock = None
ssock = None
try:
dns_query = build_dns_query("2.wdbrn", "TXT")
request = (
@ -181,43 +233,83 @@ def check_doh(ip: str) -> dict:
"\r\n"
)
wireframe_request = request.encode() + dns_query
sock = socket.create_connection((ip, 443))
# Create socket with timeout
sock = socket.create_connection((ip, 443), timeout=10)
context = ssl.create_default_context()
context.check_hostname = False # Skip hostname verification for IP-based connection
ssock = context.wrap_socket(sock, server_hostname="hnsdoh.com")
ssock.settimeout(10) # Set a timeout for socket operations
ssock.sendall(wireframe_request)
response_data = b""
while True:
data = ssock.recv(4096)
if not data:
break
response_data += data
try:
data = ssock.recv(4096)
if not data:
break
response_data += data
except socket.timeout:
logger.warning(f"Socket timeout while receiving data from {ip}")
if response_data: # We might have partial data
break
else:
raise
response_str = response_data.decode("latin-1")
if not response_data:
logger.warning(f"No data received from {ip}")
return {"status": status, "server": server_name}
response_str = response_data.decode("latin-1", errors="replace")
# Check if we have a complete HTTP response with headers and body
if "\r\n\r\n" not in response_str:
logger.warning(f"Incomplete HTTP response from {ip}")
return {"status": status, "server": server_name}
headers, body = response_str.split("\r\n\r\n", 1)
# Try to get server from headers
for header in headers.split("\r\n"):
if header.startswith("Server:"):
server_name.append(header.split(":")[1].strip())
if header.lower().startswith("server:"):
server_name.append(header.split(":", 1)[1].strip())
dns_response: dnslib.DNSRecord = dnslib.DNSRecord.parse(body.encode("latin-1"))
for rr in dns_response.rr:
if "Test 2" in str(rr):
status = True
try:
dns_response = dnslib.DNSRecord.parse(body.encode("latin-1"))
for rr in dns_response.rr:
if "Test 2" in str(rr):
status = True
break
except Exception as e:
logger.error(f"Error parsing DNS response from {ip}: {e}")
except (socket.timeout, socket.error) as e:
logger.warning(f"Socket error during DoH check on {ip}: {e}")
raise # Re-raise for retry decorator
except ssl.SSLError as e:
logger.error(f"SSL error during DoH check on {ip}: {e}")
return {"status": False, "server": server_name}
except Exception as e:
print(e)
logger.error(f"Unexpected error during DoH check on {ip}: {e}")
return {"status": False, "server": server_name}
finally:
# Close the socket connection
# Check if ssock is defined
if "ssock" in locals():
ssock.close()
# Ensure sockets are always closed
if ssock:
try:
ssock.close()
except:
pass
if sock and sock != ssock:
try:
sock.close()
except:
pass
return {"status": status, "server": server_name}
@retry(max_attempts=3, delay_seconds=2)
def check_dot(ip: str) -> bool:
qname = dns.name.from_text("3.wdbrn")
q = dns.message.make_query(qname, dns.rdatatype.TXT)
@ -231,39 +323,73 @@ def check_dot(ip: str) -> bool:
if "Test 3" in rr.to_text():
return True
return False
except dns.exception.Timeout:
logger.warning(f"Timeout during DoT check on {ip}")
raise # Re-raise for retry decorator
except ssl.SSLError as e:
logger.error(f"SSL error during DoT check on {ip}: {e}")
return False
except Exception as e:
print(e)
logger.error(f"Error during DoT check on {ip}: {e}")
return False
def verify_cert(ip: str, port: int) -> bool:
@retry(max_attempts=3, delay_seconds=2)
def verify_cert(ip: str, port: int) -> dict:
expires = "ERROR"
valid = False
expiry_date_str = (datetime.now() - relativedelta.relativedelta(years=1)).strftime("%b %d %H:%M:%S %Y GMT")
sock = None
ssock = None
try:
sock = socket.create_connection((ip, port))
sock = socket.create_connection((ip, port), timeout=10)
# Wrap the socket in SSL/TLS
context = ssl.create_default_context()
context.check_hostname = False # Skip hostname verification for IP-based connection
ssock = context.wrap_socket(sock, server_hostname="hnsdoh.com")
ssock.settimeout(10) # Set timeout for socket operations
# Retrieve the server's certificate
cert = ssock.getpeercert()
if not cert:
logger.error(f"No certificate returned from {ip}:{port}")
return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
# Extract the expiry date from the certificate
if "notAfter" not in cert:
logger.error(f"Certificate from {ip}:{port} missing notAfter field")
return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
expiry_date_str = cert["notAfter"]
# Convert the expiry date string to a datetime object
expiry_date = datetime.strptime(expiry_date_str, "%b %d %H:%M:%S %Y GMT")
expires = format_relative_time(expiry_date)
valid = expiry_date > datetime.now()
except Exception as e:
print(e)
except (socket.timeout, socket.error) as e:
logger.warning(f"Socket error during certificate check on {ip}:{port}: {e}")
raise # Re-raise for retry decorator
except ssl.SSLError as e:
logger.error(f"SSL error during certificate check on {ip}:{port}: {e}")
return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
except Exception as e:
logger.error(f"Error during certificate check on {ip}:{port}: {e}")
return {"valid": False, "expires": "ERROR", "expiry_date": expiry_date_str}
finally:
# Close the SSL and socket connection
if "ssock" in locals():
ssock.close()
# Ensure sockets are always closed
if ssock:
try:
ssock.close()
except:
pass
if sock and sock != ssock:
try:
sock.close()
except:
pass
return {"valid": valid, "expires": expires, "expiry_date": expiry_date_str}
@ -321,8 +447,17 @@ def check_nodes() -> list:
else:
if len(nodes) == 0:
nodes = get_node_list()
node_status = []
for ip in nodes:
node_status = []
for ip in nodes:
logger.info(f"Checking node {ip}")
try:
plain_dns_result = check_plain_dns(ip)
doh_check = check_doh(ip)
dot_result = check_dot(ip)
cert_result = verify_cert(ip, 443)
cert_853_result = verify_cert(ip, 853)
node_status.append(
{
"ip": ip,
@ -330,17 +465,18 @@ def check_nodes() -> list:
"location": (
node_locations[ip] if ip in node_locations else "Unknown"
),
"plain_dns": check_plain_dns(ip),
"doh": check_doh(ip)["status"],
"doh_server": check_doh(ip)["server"],
"dot": check_dot(ip),
"cert": verify_cert(ip, 443),
"cert_853": verify_cert(ip, 853),
"plain_dns": plain_dns_result,
"doh": doh_check["status"],
"doh_server": doh_check["server"],
"dot": dot_result,
"cert": cert_result,
"cert_853": cert_853_result,
}
)
else:
node_status = []
for ip in nodes:
logger.info(f"Node {ip} check complete")
except Exception as e:
logger.error(f"Error checking node {ip}: {e}")
# Add a failed entry for this node to ensure it's still included
node_status.append(
{
"ip": ip,
@ -348,17 +484,18 @@ def check_nodes() -> list:
"location": (
node_locations[ip] if ip in node_locations else "Unknown"
),
"plain_dns": check_plain_dns(ip),
"doh": check_doh(ip)["status"],
"doh_server": check_doh(ip)["server"],
"dot": check_dot(ip),
"cert": verify_cert(ip, 443),
"cert_853": verify_cert(ip, 853),
"plain_dns": False,
"doh": False,
"doh_server": [],
"dot": False,
"cert": {"valid": False, "expires": "ERROR", "expiry_date": "ERROR"},
"cert_853": {"valid": False, "expires": "ERROR", "expiry_date": "ERROR"},
}
)
# Save the node status to a file
log_status(node_status)
print("Finished checking nodes", flush=True)
logger.info("Finished checking nodes")
# Send notifications if any nodes are down
for node in node_status:
@ -372,39 +509,63 @@ def check_nodes() -> list:
send_down_notification(node)
continue
# Check if cert is expiring in 7 days
cert_expiry = datetime.strptime(
node["cert"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
)
if cert_expiry < datetime.now() + relativedelta.relativedelta(days=7):
send_down_notification(node)
continue
cert_853_expiry = datetime.strptime(
node["cert_853"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
)
if cert_853_expiry < datetime.now() + relativedelta.relativedelta(days=7):
send_down_notification(node)
try:
cert_expiry = datetime.strptime(
node["cert"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
)
if cert_expiry < datetime.now() + relativedelta.relativedelta(days=7):
send_down_notification(node)
continue
cert_853_expiry = datetime.strptime(
node["cert_853"]["expiry_date"], "%b %d %H:%M:%S %Y GMT"
)
if cert_853_expiry < datetime.now() + relativedelta.relativedelta(days=7):
send_down_notification(node)
except Exception as e:
logger.error(f"Error processing certificate expiry for {node['ip']}: {e}")
return node_status
def check_nodes_from_log() -> list:
global last_log
# Load the last log
with open(f"{log_dir}/node_status.json", "r") as file:
data = json.load(file)
newest = {
"date": datetime.now() - relativedelta.relativedelta(years=1),
"nodes": [],
}
for entry in data:
if datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S") > newest["date"]:
newest = entry
newest["date"] = datetime.strptime(newest["date"], "%Y-%m-%d %H:%M:%S")
node_status = newest["nodes"]
if datetime.now() > newest["date"] + relativedelta.relativedelta(minutes=10):
print("Failed to get a new enough log, checking nodes", flush=True)
try:
with open(f"{log_dir}/node_status.json", "r") as file:
data = json.load(file)
newest = {
"date": datetime.now() - relativedelta.relativedelta(years=1),
"nodes": [],
}
for entry in data:
if datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S") > newest["date"]:
newest = entry
newest["date"] = datetime.strptime(newest["date"], "%Y-%m-%d %H:%M:%S")
node_status = newest["nodes"]
# Get check staleness threshold from environment variable or use default (15 minutes)
staleness_threshold_str = os.getenv("STALENESS_THRESHOLD_MINUTES", "15")
try:
staleness_threshold = int(staleness_threshold_str)
except ValueError:
logger.warning(f"Invalid STALENESS_THRESHOLD_MINUTES value: {staleness_threshold_str}, using default of 15")
staleness_threshold = 15
if datetime.now() > newest["date"] + relativedelta.relativedelta(minutes=staleness_threshold):
logger.warning(f"Data is stale (older than {staleness_threshold} minutes), triggering immediate check")
node_status = check_nodes()
else:
last_log = newest["date"]
logger.info(f"Using cached node status from {format_last_check(last_log)}")
except (FileNotFoundError, json.JSONDecodeError) as e:
logger.error(f"Error reading node status file: {e}")
logger.info("Running initial node check")
node_status = check_nodes()
else:
last_log = newest["date"]
return node_status
@ -1020,5 +1181,75 @@ def not_found(e):
# endregion
# After defining check_nodes() function
def scheduled_node_check():
    """Scheduler entry point: refresh the node list and run a full node check.

    Any exception is caught and logged so a single failed run never kills
    the background scheduler job.
    """
    global nodes
    try:
        logger.info("Running scheduled node check")
        # Clear the cached node list so check_nodes() re-resolves it,
        # picking up any DNS changes since the previous run.
        nodes = []
        check_nodes()
        logger.info("Completed scheduled node check")
    except Exception as e:
        logger.error(f"Error in scheduled node check: {e}")
def scheduler_listener(event):
    """Log the outcome of each scheduled job run (APScheduler event hook)."""
    # event.exception is set only for EVENT_JOB_ERROR events
    if event.exception:
        logger.error(f"Error in scheduled job: {event.exception}")
        return
    logger.debug("Scheduled job completed successfully")
# Function to start the scheduler
def start_scheduler():
    """Configure and start the background APScheduler job for node checks.

    Reads CHECK_INTERVAL_MINUTES from the environment (default 5 minutes)
    and schedules scheduled_node_check() on that interval.  Safe to call
    more than once: the job id is fixed with replace_existing=True, and the
    scheduler is only started when it is not already running.
    """
    # Get check interval from environment variable or use default (5 minutes)
    check_interval_str = os.getenv("CHECK_INTERVAL_MINUTES", "5")
    try:
        check_interval = int(check_interval_str)
    except ValueError:
        logger.warning(f"Invalid CHECK_INTERVAL_MINUTES value: {check_interval_str}, using default of 5")
        check_interval = 5
    # Bug fix: "0" or a negative value parses as a valid int but makes
    # IntervalTrigger raise at job-registration time — guard against it.
    if check_interval <= 0:
        logger.warning(f"CHECK_INTERVAL_MINUTES must be positive, got {check_interval}, using default of 5")
        check_interval = 5
    logger.info(f"Setting up scheduler to run every {check_interval} minutes")
    # Add (or replace) the periodic node-check job
    scheduler.add_job(
        scheduled_node_check,
        IntervalTrigger(minutes=check_interval),
        id='node_check_job',
        replace_existing=True
    )
    # Add listener for job events
    scheduler.add_listener(scheduler_listener, EVENT_JOB_ERROR | EVENT_JOB_EXECUTED)
    # Start the scheduler if it's not already running
    if not scheduler.running:
        scheduler.start()
        logger.info("Scheduler started")
# Register signal handlers for graceful shutdown in Docker
def signal_handler(sig, frame):
    """Handle SIGINT/SIGTERM: stop the scheduler cleanly, then exit."""
    logger.info(f"Received signal {sig}, shutting down...")
    if scheduler.running:
        scheduler.shutdown()
        logger.info("Scheduler shut down")
    sys.exit(0)


# Hook up graceful shutdown for Docker (docker stop delivers SIGTERM)
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)
# Initialize the scheduler at import time instead of relying on
# @before_first_request, which is deprecated in newer Flask versions.
with app.app_context():
    # Idempotent: start_scheduler() only starts the scheduler if it is not
    # already running, so re-imports (e.g. under a reloader) are safe.
    start_scheduler()
    # Run an immediate check so status data exists before the first interval
    scheduled_node_check()
if __name__ == "__main__":
    # The scheduler is already started in the app context above
    # Run the Flask app; host 0.0.0.0 makes it reachable from outside the
    # container.
    # NOTE(review): debug=True is a development setting — confirm it is
    # disabled (or overridden by a WSGI server) in production deployments.
    app.run(debug=True, port=5000, host="0.0.0.0")