diff --git a/requirements.txt b/requirements.txt
index dadab87..d7e0dc5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,6 @@ dnslib
 python-dateutil
 python-dotenv
 schedule
-apscheduler>=3.9.1
\ No newline at end of file
+apscheduler>=3.9.1
+flask-caching
+brotli
\ No newline at end of file
diff --git a/server.py b/server.py
index 2d65d2d..0fe2288 100644
--- a/server.py
+++ b/server.py
@@ -33,14 +33,28 @@ import sys
 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.interval import IntervalTrigger
 from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
+from flask_caching import Cache
+import functools
+import io
+import brotli
+from io import BytesIO
 
-# Set up logging
+# Set up logging BEFORE attempting imports that might fail
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
 dotenv.load_dotenv()
 
+# Configure caching
+cache_config = {
+    'CACHE_TYPE': 'SimpleCache',  # In-memory cache
+    'CACHE_DEFAULT_TIMEOUT': 300,  # 5 minutes default
+    'CACHE_THRESHOLD': 500  # Maximum number of items the cache will store
+}
+
 app = Flask(__name__)
+app.config.from_mapping(cache_config)
+cache = Cache(app)
 
 # Configure scheduler - use coalescing to prevent missed jobs piling up
 scheduler = BackgroundScheduler(daemon=True, job_defaults={'coalesce': True, 'max_instances': 1})
@@ -89,6 +103,9 @@ if (os.getenv("NODES")):
 
 print(f"Log directory: {log_dir}", flush=True)
 
+# In-memory cache for node status to reduce file I/O
+_node_status_cache = None
+_node_status_cache_time = None
 
 def find(name, path):
     for root, dirs, files in os.walk(path):
@@ -96,8 +113,30 @@
             return os.path.join(root, name)
 
 
-# Assets routes
+# Add a cache control decorator for static assets
+def add_cache_headers(max_age=3600):
+    def decorator(view_func):
+        @functools.wraps(view_func)
+        def wrapper(*args, **kwargs):
+            response = view_func(*args, **kwargs)
+            if isinstance(response, tuple):
+                response_obj = response[0]
+            else:
+                response_obj = response
+
+            if hasattr(response_obj, 'cache_control'):
+                response_obj.cache_control.max_age = max_age
+                response_obj.cache_control.public = True
+                # Also set Expires header
+                response_obj.expires = int(time.time() + max_age)
+
+            return response
+        return wrapper
+    return decorator
+
+# Assets routes with caching
 @app.route("/assets/<path:path>")
+@add_cache_headers(max_age=86400)  # Cache static assets for 1 day
 def send_report(path):
     if path.endswith(".json"):
         return send_from_directory(
@@ -528,9 +567,27 @@ def check_nodes() -> list:
     return node_status
 
 
+# Optimize check_nodes_from_log function with in-memory caching
 def check_nodes_from_log() -> list:
-    global last_log
-    # Load the last log
+    global last_log, _node_status_cache, _node_status_cache_time
+
+    # Check if we have a valid cache
+    current_time = datetime.now()
+    staleness_threshold_str = os.getenv("STALENESS_THRESHOLD_MINUTES", "15")
+
+    try:
+        staleness_threshold = int(staleness_threshold_str)
+    except ValueError:
+        logger.warning(f"Invalid STALENESS_THRESHOLD_MINUTES value: {staleness_threshold_str}")
+        staleness_threshold = 15
+
+    # Use in-memory cache if it's fresh enough
+    if (_node_status_cache is not None and _node_status_cache_time is not None and
+            current_time < _node_status_cache_time + relativedelta.relativedelta(minutes=staleness_threshold/2)):
+        logger.info(f"Using in-memory cache from {format_last_check(_node_status_cache_time)}")
+        return _node_status_cache
+
+    # Otherwise load from disk or run a new check
     try:
         with open(f"{log_dir}/node_status.json", "r") as file:
             data = json.load(file)
@@ -547,24 +604,25 @@ def check_nodes_from_log() -> list:
 
         node_status = newest["nodes"]
 
-        # Get check staleness threshold from environment variable or use default (15 minutes)
-        staleness_threshold_str = os.getenv("STALENESS_THRESHOLD_MINUTES", "15")
-        try:
-            staleness_threshold = int(staleness_threshold_str)
-        except ValueError:
-            logger.warning(f"Invalid STALENESS_THRESHOLD_MINUTES value: {staleness_threshold_str}, using default of 15")
-            staleness_threshold = 15
-
-        if datetime.now() > newest["date"] + relativedelta.relativedelta(minutes=staleness_threshold):
+        if current_time > newest["date"] + relativedelta.relativedelta(minutes=staleness_threshold):
             logger.warning(f"Data is stale (older than {staleness_threshold} minutes), triggering immediate check")
             node_status = check_nodes()
         else:
             last_log = newest["date"]
             logger.info(f"Using cached node status from {format_last_check(last_log)}")
+
+        # Update the in-memory cache
+        _node_status_cache = node_status
+        _node_status_cache_time = current_time
+
     except (FileNotFoundError, json.JSONDecodeError) as e:
         logger.error(f"Error reading node status file: {e}")
         logger.info("Running initial node check")
         node_status = check_nodes()
+
+        # Update the in-memory cache
+        _node_status_cache = node_status
+        _node_status_cache_time = current_time
 
     return node_status
 
@@ -702,6 +760,7 @@ def log_status(node_status: list):
 # endregion
 
 # region History functions
+@cache.memoize(timeout=300)  # Cache for 5 minutes
 def get_history(days: int) -> list:
     log_files = [
         f
@@ -721,17 +780,29 @@ def get_history(days: int) -> list:
     return history
 
 
+# Create factory functions to replace lambdas that can't be pickled
+def create_default_node_dict():
+    return {
+        "name": "",
+        "location": "",
+        "ip": "",
+        "plain_dns": {"last_down": "Never", "percentage": 0},
+        "doh": {"last_down": "Never", "percentage": 0},
+        "dot": {"last_down": "Never", "percentage": 0},
+    }
+
+def create_default_counts_dict():
+    return {
+        "plain_dns": {"down": 0, "total": 0},
+        "doh": {"down": 0, "total": 0},
+        "dot": {"down": 0, "total": 0},
+    }
+
+@cache.memoize(timeout=600)  # Cache for 10 minutes
 def summarize_history(history: list) -> dict:
-    nodes_status = defaultdict(
-        lambda: {
-            "name": "",
-            "location": "",
-            "ip": "",
-            "plain_dns": {"last_down": "Never", "percentage": 0},
-            "doh": {"last_down": "Never", "percentage": 0},
-            "dot": {"last_down": "Never", "percentage": 0},
-        }
-    )
+    # Replace lambda with named function
+    nodes_status = defaultdict(create_default_node_dict)
+
     overall_status = {
         "plain_dns": {"last_down": "Never", "percentage": 0},
         "doh": {"last_down": "Never", "percentage": 0},
@@ -739,13 +810,8 @@ def summarize_history(history: list) -> dict:
     }
 
     # Collect data
-    total_counts = defaultdict(
-        lambda: {
-            "plain_dns": {"down": 0, "total": 0},
-            "doh": {"down": 0, "total": 0},
-            "dot": {"down": 0, "total": 0},
-        }
-    )
+    # Replace lambda with named function
+    total_counts = defaultdict(create_default_counts_dict)
 
     for entry in history:
         date = datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S")
@@ -882,13 +948,17 @@ def api_index():
     return render_template("api.html",endpoints=endpoints)
 
 
+# Cache node status for API requests
 @app.route("/api/nodes")
+@cache.cached(timeout=60)  # Cache for 1 minute
 def api_nodes():
     node_status = check_nodes_from_log()
     return jsonify(node_status)
 
 
+# Cache history API responses
 @app.route("/api/history")
+@cache.cached(timeout=300, query_string=True)  # Cache for 5 minutes, respect query params
 def api_history():
     history_days = 7
     if "days" in request.args:
@@ -902,6 +972,7 @@
 
 
 @app.route("/api/history/<int:days>")
+@cache.memoize(timeout=300)  # Cache for 5 minutes
 def api_history_days(days: int):
     history = get_history(days)
     history_summary = summarize_history(history)
@@ -909,6 +980,7 @@
 
 
 @app.route("/api/full")
+@cache.cached(timeout=300, query_string=True)  # Cache for 5 minutes, respect query params
 def api_all():
     history_days = 7
     if "history" in request.args:
@@ -931,6 +1003,7 @@ def api_refresh():
     return jsonify(node_status)
 
 
 @app.route("/api/latest")
+@cache.cached(timeout=60)  # Cache for 1 minute
 def api_errors():
     node_status = check_nodes_from_log()
@@ -1036,7 +1109,9 @@ def api_errors():
 
 
 # region Main routes
 
+# Cache the main page rendering
 @app.route("/")
+@cache.cached(timeout=60, query_string=True)  # Cache for 1 minute, respect query params
 def index():
     node_status = check_nodes_from_log()
@@ -1140,7 +1215,9 @@ def index():
     )
 
 
+# Add cache headers to manifest.json
 @app.route("/manifest.json")
+@add_cache_headers(max_age=86400)  # Cache for 1 day
 def manifest():
     with open("templates/manifest.json", "r") as file:
         manifest = json.load(file)
@@ -1188,10 +1265,20 @@ def scheduled_node_check():
     try:
         logger.info("Running scheduled node check")
         # Get fresh node list on each check to pick up DNS changes
-        global nodes
+        global nodes, _node_status_cache, _node_status_cache_time
        nodes = []  # Reset node list to force refresh
-        check_nodes()
-        logger.info("Completed scheduled node check")
+
+        # Run the check and update in-memory cache
+        node_status = check_nodes()
+        _node_status_cache = node_status
+        _node_status_cache_time = datetime.now()
+
+        # Clear the Flask-Caching entries so the next requests see fresh data
+        # (these views use @cache.cached, which delete_memoized would not
+        # invalidate, so clear the whole cache instead)
+        cache.clear()
+
+        logger.info("Completed scheduled node check and updated caches")
     except Exception as e:
         logger.error(f"Error in scheduled node check: {e}")
 
@@ -1249,6 +1336,56 @@ with app.app_context():
     # Run an immediate check
     scheduled_node_check()
 
+# Custom Brotli compression for responses
+@app.after_request
+def add_compression(response):
+    # Skip compression for responses that are:
+    # 1. Already compressed
+    # 2. Too small (< 500 bytes)
+    # 3. In direct passthrough mode (like static files)
+    # 4. Not a compressible MIME type
+    if (response.content_length is None or
+            response.content_length < 500 or
+            'Content-Encoding' in response.headers or
+            response.direct_passthrough):
+        return response
+
+    # Only compress specific MIME types
+    content_type = response.headers.get('Content-Type', '')
+    compressible_types = [
+        'text/html',
+        'text/css',
+        'text/plain',
+        'application/javascript',
+        'application/json',
+        'application/xml',
+        'text/xml'
+    ]
+
+    if not any(t in content_type for t in compressible_types):
+        return response
+
+    accept_encoding = request.headers.get('Accept-Encoding', '')
+
+    if 'br' in accept_encoding:
+        try:
+            # Get the response content
+            response_data = response.get_data()
+
+            # Compress with Brotli
+            compressed_data = brotli.compress(response_data, quality=6)
+
+            # Only apply Brotli if it results in smaller size
+            if len(compressed_data) < len(response_data):
+                response.set_data(compressed_data)
+                response.headers['Content-Encoding'] = 'br'
+                response.headers['Content-Length'] = len(compressed_data)
+        except Exception as e:
+            logger.warning(f"Brotli compression failed: {e}")
+            # If compression fails, we just return the uncompressed response
+
+    return response
+
 if __name__ == "__main__":
     # The scheduler is already started in the app context above
     # Run the Flask app
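As a rough smoke test of the caching and compression changes above (not part of the patch; the localhost:5000 base URL is an assumption about the local setup, the routes are the ones added in the diff), something like this can be run against a running instance:

```python
# Hypothetical check against a locally running server.py instance.
# Assumes the Flask app listens on http://localhost:5000; adjust BASE_URL if not.
import requests

BASE_URL = "http://localhost:5000"

# JSON responses over ~500 bytes should come back Brotli-encoded
# when the client advertises support for it.
nodes = requests.get(f"{BASE_URL}/api/nodes", headers={"Accept-Encoding": "br"})
print("Content-Encoding:", nodes.headers.get("Content-Encoding"))

# manifest.json and /assets/* go through add_cache_headers, so they
# should carry long-lived Cache-Control and Expires headers.
manifest = requests.get(f"{BASE_URL}/manifest.json")
print("Cache-Control:", manifest.headers.get("Cache-Control"))
print("Expires:", manifest.headers.get("Expires"))
```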