From b1532a0951bebf5e3ec322cd265f6b0d019be8d6 Mon Sep 17 00:00:00 2001
From: Nathan Woodburn <github@nathan.woodburn.au>
Date: Fri, 28 Mar 2025 23:29:35 +1100
Subject: [PATCH] feat: Add some performance improvements

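- Cache rendered pages and API responses with Flask-Caching (SimpleCache)
- Add Cache-Control/Expires headers for static assets and manifest.json
- Keep node status in an in-memory cache to reduce file I/O, refreshed by
  the scheduled check
- Replace defaultdict lambdas with module-level factory functions so
  memoized results can be pickled
- Compress larger text/JSON responses with Brotli when the client accepts it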
---
 requirements.txt |   4 +-
 server.py        | 203 +++++++++++++++++++++++++++++++++++++++--------
 2 files changed, 173 insertions(+), 34 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index dadab87..d7e0dc5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -6,4 +6,6 @@ dnslib
 python-dateutil
 python-dotenv
 schedule
-apscheduler>=3.9.1
\ No newline at end of file
+apscheduler>=3.9.1
+flask-caching
+brotli
\ No newline at end of file
diff --git a/server.py b/server.py
index 2d65d2d..0fe2288 100644
--- a/server.py
+++ b/server.py
@@ -33,14 +33,28 @@ import sys
 from apscheduler.schedulers.background import BackgroundScheduler
 from apscheduler.triggers.interval import IntervalTrigger
 from apscheduler.events import EVENT_JOB_ERROR, EVENT_JOB_EXECUTED
+from flask_caching import Cache
+import functools
+import io
+import brotli
+from io import BytesIO
 
-# Set up logging
+# Set up logging early so startup problems are captured
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 logger = logging.getLogger(__name__)
 
 dotenv.load_dotenv()
 
+# Configure caching
+cache_config = {
+    'CACHE_TYPE': 'SimpleCache',  # In-memory cache
+    'CACHE_DEFAULT_TIMEOUT': 300,  # 5 minutes default
+    'CACHE_THRESHOLD': 500  # Maximum number of items the cache will store
+}
+
 app = Flask(__name__)
+app.config.from_mapping(cache_config)
+cache = Cache(app)
 
 # Configure scheduler - use coalescing to prevent missed jobs piling up
 scheduler = BackgroundScheduler(daemon=True, job_defaults={'coalesce': True, 'max_instances': 1})
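SimpleCache is per-process, so under a multi-worker server (e.g. gunicorn) each worker keeps its own copy and invalidation in one worker is not visible to the others. A minimal sketch of a shared backend, assuming Flask-Caching's Redis support and a locally running Redis (the URL is illustrative):

cache_config = {
    'CACHE_TYPE': 'RedisCache',                     # shared across worker processes
    'CACHE_REDIS_URL': 'redis://localhost:6379/0',  # illustrative URL, adjust as needed
    'CACHE_DEFAULT_TIMEOUT': 300,
}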
@@ -89,6 +103,9 @@ if (os.getenv("NODES")):
 
 print(f"Log directory: {log_dir}", flush=True)
 
+# In-memory cache for node status to reduce file I/O
+_node_status_cache = None
+_node_status_cache_time = None
 
 def find(name, path):
     for root, dirs, files in os.walk(path):
@@ -96,8 +113,30 @@ def find(name, path):
             return os.path.join(root, name)
 
 
-# Assets routes
+# Add a cache control decorator for static assets
+def add_cache_headers(max_age=3600):
+    def decorator(view_func):
+        @functools.wraps(view_func)
+        def wrapper(*args, **kwargs):
+            response = view_func(*args, **kwargs)
+            if isinstance(response, tuple):
+                response_obj = response[0]
+            else:
+                response_obj = response
+                
+            if hasattr(response_obj, 'cache_control'):
+                response_obj.cache_control.max_age = max_age
+                response_obj.cache_control.public = True
+                # Also set Expires header
+                response_obj.expires = int(time.time() + max_age)
+            
+            return response
+        return wrapper
+    return decorator
+
+# Assets routes with caching
 @app.route("/assets/<path:path>")
+@add_cache_headers(max_age=86400)  # Cache static assets for 1 day
 def send_report(path):
     if path.endswith(".json"):
         return send_from_directory(
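A quick sanity check of the new cache headers, sketched with Flask's test client (the asset path is hypothetical):

with app.test_client() as client:
    resp = client.get("/assets/css/style.css")  # hypothetical asset path
    # expect something like "public, max-age=86400" plus an Expires header
    print(resp.headers.get("Cache-Control"), resp.headers.get("Expires"))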
@@ -528,9 +567,27 @@ def check_nodes() -> list:
     return node_status
 
 
+# Optimize check_nodes_from_log function with in-memory caching
 def check_nodes_from_log() -> list:
-    global last_log
-    # Load the last log
+    global last_log, _node_status_cache, _node_status_cache_time
+    
+    # Check if we have a valid cache
+    current_time = datetime.now()
+    staleness_threshold_str = os.getenv("STALENESS_THRESHOLD_MINUTES", "15")
+    
+    try:
+        staleness_threshold = int(staleness_threshold_str)
+    except ValueError:
+        logger.warning(f"Invalid STALENESS_THRESHOLD_MINUTES value: {staleness_threshold_str}, using default of 15")
+        staleness_threshold = 15
+    
+    # Reuse the in-memory cache while it is younger than half the staleness threshold
+    if (_node_status_cache is not None and _node_status_cache_time is not None and 
+            current_time < _node_status_cache_time + relativedelta.relativedelta(minutes=staleness_threshold/2)):
+        logger.info(f"Using in-memory cache from {format_last_check(_node_status_cache_time)}")
+        return _node_status_cache
+    
+    # Otherwise load from disk or run a new check
     try:
         with open(f"{log_dir}/node_status.json", "r") as file:
             data = json.load(file)
@@ -547,24 +604,25 @@ def check_nodes_from_log() -> list:
         
         node_status = newest["nodes"]
         
-        # Get check staleness threshold from environment variable or use default (15 minutes)
-        staleness_threshold_str = os.getenv("STALENESS_THRESHOLD_MINUTES", "15")
-        try:
-            staleness_threshold = int(staleness_threshold_str)
-        except ValueError:
-            logger.warning(f"Invalid STALENESS_THRESHOLD_MINUTES value: {staleness_threshold_str}, using default of 15")
-            staleness_threshold = 15
-        
-        if datetime.now() > newest["date"] + relativedelta.relativedelta(minutes=staleness_threshold):
+        if current_time > newest["date"] + relativedelta.relativedelta(minutes=staleness_threshold):
             logger.warning(f"Data is stale (older than {staleness_threshold} minutes), triggering immediate check")
             node_status = check_nodes()
         else:
             last_log = newest["date"]
             logger.info(f"Using cached node status from {format_last_check(last_log)}")
+        
+        # Update the in-memory cache
+        _node_status_cache = node_status
+        _node_status_cache_time = current_time
+        
     except (FileNotFoundError, json.JSONDecodeError) as e:
         logger.error(f"Error reading node status file: {e}")
         logger.info("Running initial node check")
         node_status = check_nodes()
+        
+        # Update the in-memory cache
+        _node_status_cache = node_status
+        _node_status_cache_time = current_time
     
     return node_status
 
@@ -702,6 +760,7 @@ def log_status(node_status: list):
 
 # endregion
 # region History functions
+@cache.memoize(timeout=300)  # Cache for 5 minutes
 def get_history(days: int) -> list:
     log_files = [
         f
@@ -721,17 +780,29 @@ def get_history(days: int) -> list:
     return history
 
 
+# Create factory functions to replace lambdas that can't be pickled
+def create_default_node_dict():
+    return {
+        "name": "",
+        "location": "",
+        "ip": "",
+        "plain_dns": {"last_down": "Never", "percentage": 0},
+        "doh": {"last_down": "Never", "percentage": 0},
+        "dot": {"last_down": "Never", "percentage": 0},
+    }
+
+def create_default_counts_dict():
+    return {
+        "plain_dns": {"down": 0, "total": 0},
+        "doh": {"down": 0, "total": 0},
+        "dot": {"down": 0, "total": 0},
+    }
+
+@cache.memoize(timeout=600)  # Cache for 10 minutes
 def summarize_history(history: list) -> dict:
-    nodes_status = defaultdict(
-        lambda: {
-            "name": "",
-            "location": "",
-            "ip": "",
-            "plain_dns": {"last_down": "Never", "percentage": 0},
-            "doh": {"last_down": "Never", "percentage": 0},
-            "dot": {"last_down": "Never", "percentage": 0},
-        }
-    )
+    # Replace lambda with named function
+    nodes_status = defaultdict(create_default_node_dict)
+    
     overall_status = {
         "plain_dns": {"last_down": "Never", "percentage": 0},
         "doh": {"last_down": "Never", "percentage": 0},
@@ -739,13 +810,8 @@ def summarize_history(history: list) -> dict:
     }
 
     # Collect data
-    total_counts = defaultdict(
-        lambda: {
-            "plain_dns": {"down": 0, "total": 0},
-            "doh": {"down": 0, "total": 0},
-            "dot": {"down": 0, "total": 0},
-        }
-    )
+    # Replace lambda with named function
+    total_counts = defaultdict(create_default_counts_dict)
 
     for entry in history:
         date = datetime.strptime(entry["date"], "%Y-%m-%d %H:%M:%S")
@@ -882,13 +948,17 @@ def api_index():
         return render_template("api.html",endpoints=endpoints)
 
 
+# Cache node status for API requests
 @app.route("/api/nodes")
+@cache.cached(timeout=60)  # Cache for 1 minute
 def api_nodes():
     node_status = check_nodes_from_log()
     return jsonify(node_status)
 
 
+# Cache history API responses
 @app.route("/api/history")
+@cache.cached(timeout=300, query_string=True)  # Cache for 5 minutes, respect query params
 def api_history():
     history_days = 7
     if "days" in request.args:
@@ -902,6 +972,7 @@ def api_history():
 
 
 @app.route("/api/history/<int:days>")
+@cache.memoize(timeout=300)  # Cache for 5 minutes
 def api_history_days(days: int):
     history = get_history(days)
     history_summary = summarize_history(history)
@@ -909,6 +980,7 @@ def api_history_days(days: int):
 
 
 @app.route("/api/full")
+@cache.cached(timeout=300, query_string=True)  # Cache for 5 minutes, respect query params
 def api_all():
     history_days = 7
     if "history" in request.args:
@@ -931,6 +1003,7 @@ def api_refresh():
     return jsonify(node_status)
 
 @app.route("/api/latest")
+@cache.cached(timeout=60)  # Cache for 1 minute
 def api_errors():
     node_status = check_nodes_from_log()
 
@@ -1036,7 +1109,9 @@ def api_errors():
 
 
 # region Main routes
+# Cache the main page rendering
 @app.route("/")
+@cache.cached(timeout=60, query_string=True)  # Cache for 1 minute, respect query params 
 def index():
     node_status = check_nodes_from_log()
 
@@ -1140,7 +1215,9 @@ def index():
     )
 
 
+# Add cache headers to manifest.json
 @app.route("/manifest.json")
+@add_cache_headers(max_age=86400)  # Cache for 1 day
 def manifest():
     with open("templates/manifest.json", "r") as file:
         manifest = json.load(file)
@@ -1188,10 +1265,20 @@ def scheduled_node_check():
     try:
         logger.info("Running scheduled node check")
         # Get fresh node list on each check to pick up DNS changes
-        global nodes
+        global nodes, _node_status_cache, _node_status_cache_time
         nodes = []  # Reset node list to force refresh
-        check_nodes()
-        logger.info("Completed scheduled node check")
+        
+        # Run the check and update in-memory cache
+        node_status = check_nodes()
+        _node_status_cache = node_status
+        _node_status_cache_time = datetime.now()
+        
+        # api_nodes, api_errors and index use @cache.cached rather than
+        # @cache.memoize, so delete_memoized() would be a silent no-op;
+        # clear the whole cache so these views pick up the fresh node data
+        cache.clear()
+        
+        logger.info("Completed scheduled node check and updated caches")
     except Exception as e:
         logger.error(f"Error in scheduled node check: {e}")
 
@@ -1249,6 +1336,56 @@ with app.app_context():
     # Run an immediate check
     scheduled_node_check()
 
+# Custom Brotli compression for responses
+@app.after_request
+def add_compression(response):
+    # Skip compression for responses that are:
+    # 1. Already compressed
+    # 2. Too small (< 500 bytes)
+    # 3. In direct passthrough mode (like static files)
+    # 4. Not a compressible MIME type
+    if (response.content_length is None or 
+            response.content_length < 500 or 
+            'Content-Encoding' in response.headers or
+            response.direct_passthrough):
+        return response
+    
+    # Only compress specific MIME types
+    content_type = response.headers.get('Content-Type', '')
+    compressible_types = [
+        'text/html', 
+        'text/css', 
+        'text/plain', 
+        'application/javascript', 
+        'application/json',
+        'application/xml',
+        'text/xml'
+    ]
+    
+    if not any(t in content_type for t in compressible_types):
+        return response
+    
+    accept_encoding = request.headers.get('Accept-Encoding', '')
+    
+    if 'br' in accept_encoding:
+        try:
+            # Get the response content
+            response_data = response.get_data()
+            
+            # Compress with Brotli
+            compressed_data = brotli.compress(response_data, quality=6)
+            
+            # Only apply Brotli if it results in smaller size
+            if len(compressed_data) < len(response_data):
+                response.set_data(compressed_data)
+                response.headers['Content-Encoding'] = 'br'
+                response.headers['Content-Length'] = len(compressed_data)
+        except Exception as e:
+            logger.warning(f"Brotli compression failed: {e}")
+            # If compression fails, we just return the uncompressed response
+    
+    return response
+
 if __name__ == "__main__":
     # The scheduler is already started in the app context above
     # Run the Flask app
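To confirm the Brotli negotiation end to end, a sketch with Flask's test client (responses under 500 bytes are intentionally left uncompressed):

with app.test_client() as client:
    resp = client.get("/api/nodes", headers={"Accept-Encoding": "br"})
    body = resp.get_data()
    if resp.headers.get("Content-Encoding") == "br":
        body = brotli.decompress(body)
    print(resp.status_code, len(body), "bytes after decoding")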