From f2cda461ba433a3d43225427ec4f1e46cbc45922 Mon Sep 17 00:00:00 2001 From: Nathan Woodburn Date: Thu, 28 Aug 2025 16:42:12 +1000 Subject: [PATCH] feat: Add SPV features to fix accoutn balances --- account.py | 249 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 229 insertions(+), 20 deletions(-) diff --git a/account.py b/account.py index d427f8c..8c6df1d 100644 --- a/account.py +++ b/account.py @@ -11,6 +11,9 @@ import subprocess import atexit import signal import sys +import threading +import sqlite3 +from functools import wraps dotenv.load_dotenv() @@ -46,6 +49,7 @@ if SHOW_EXPIRED is None: SHOW_EXPIRED = False HSD_PROCESS = None +SPV_MODE = None # Get hsdconfig.json HSD_CONFIG = { @@ -59,6 +63,9 @@ HSD_CONFIG = { "--agent=FireWallet" ] } + +CACHE_TTL = int(os.getenv("CACHE_TTL",90)) + if not os.path.exists('hsdconfig.json'): with open('hsdconfig.json', 'w') as f: f.write(json.dumps(HSD_CONFIG, indent=4)) @@ -89,6 +96,13 @@ def hsdVersion(format=True): info = hsd.getInfo() if 'error' in info: return -1 + + # Check if SPV mode is enabled + global SPV_MODE + if info.get('chain',{}).get('options',{}).get('spv',False): + SPV_MODE = True + else: + SPV_MODE = False if format: return float('.'.join(info['version'].split(".")[:2])) else: @@ -215,6 +229,124 @@ def selectWallet(account: str): "message": response['error']['message'] } } + + +def init_domain_db(): + """Initialize the SQLite database for domain cache.""" + os.makedirs('cache', exist_ok=True) + db_path = os.path.join('cache', 'domains.db') + + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Create the domains table if it doesn't exist + cursor.execute(''' + CREATE TABLE IF NOT EXISTS domains ( + name TEXT PRIMARY KEY, + info TEXT, + last_updated INTEGER + ) + ''') + + conn.commit() + conn.close() + + +def getCachedDomains(): + """Get cached domain information from SQLite database.""" + init_domain_db() # Ensure DB exists + + db_path = os.path.join('cache', 'domains.db') + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row # This allows accessing columns by name + cursor = conn.cursor() + + # Get all domains from the database + cursor.execute('SELECT name, info, last_updated FROM domains') + rows = cursor.fetchall() + + # Convert to dictionary format + domain_cache = {} + for row in rows: + try: + domain_cache[row['name']] = json.loads(row['info']) + domain_cache[row['name']]['last_updated'] = row['last_updated'] + except json.JSONDecodeError: + print(f"Error parsing cached data for domain {row['name']}") + + conn.close() + return domain_cache + + +ACTIVE_DOMAIN_UPDATES = set() # Track domains being updated +DOMAIN_UPDATE_LOCK = threading.Lock() # For thread-safe access to ACTIVE_DOMAIN_UPDATES + +def update_domain_cache(domain_names: list): + """Fetch domain info and update the SQLite cache.""" + if not domain_names: + return + + # Filter out domains that are already being updated + domains_to_update = [] + with DOMAIN_UPDATE_LOCK: + for domain in domain_names: + if domain not in ACTIVE_DOMAIN_UPDATES: + ACTIVE_DOMAIN_UPDATES.add(domain) + domains_to_update.append(domain) + + if not domains_to_update: + # All requested domains are already being updated + return + + try: + # Initialize database + init_domain_db() + + db_path = os.path.join('cache', 'domains.db') + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + for domain_name in domains_to_update: + try: + # Get domain info from node + domain_info = getDomain(domain_name) + + if 'error' in domain_info or not domain_info.get('info'): + print(f"Failed to get info for domain {domain_name}: {domain_info.get('error', 'Unknown error')}", flush=True) + continue + + # Update or insert into database + now = int(time.time()) + serialized_info = json.dumps(domain_info) + + cursor.execute( + 'INSERT OR REPLACE INTO domains (name, info, last_updated) VALUES (?, ?, ?)', + (domain_name, serialized_info, now) + ) + + print(f"Updated cache for domain {domain_name}") + except Exception as e: + print(f"Error updating cache for domain {domain_name}: {str(e)}") + finally: + # Always remove from active set, even if there was an error + with DOMAIN_UPDATE_LOCK: + if domain_name in ACTIVE_DOMAIN_UPDATES: + ACTIVE_DOMAIN_UPDATES.remove(domain_name) + + # Commit all changes at once + conn.commit() + conn.close() + + except Exception as e: + print(f"Error updating domain cache: {str(e)}", flush=True) + # Make sure to clean up the active set on any exception + with DOMAIN_UPDATE_LOCK: + for domain in domains_to_update: + if domain in ACTIVE_DOMAIN_UPDATES: + ACTIVE_DOMAIN_UPDATES.remove(domain) + + print("Updated cache for domains") + def getBalance(account: str): # Get the total balance @@ -232,9 +364,66 @@ def getBalance(account: str): domains = getDomains(account) domainValue = 0 - for domain in domains: - if domain['state'] == "CLOSED": - domainValue += domain['value'] + domains_to_update = [] # Track domains that need cache updates + + if isSPV(): + # Initialize database if needed + init_domain_db() + + # Connect to the database directly for efficient querying + db_path = os.path.join('cache', 'domains.db') + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + cursor = conn.cursor() + + now = int(time.time()) + cache_cutoff = now - (CACHE_TTL * 86400) # Cache TTL in days + + for domain in domains: + domain_name = domain['name'] + + # Check if domain is in cache and still fresh + cursor.execute( + 'SELECT info, last_updated FROM domains WHERE name = ?', + (domain_name,) + ) + row = cursor.fetchone() + + # Only add domain for update if: + # 1. Not in cache or stale + # 2. Not currently being updated by another thread + with DOMAIN_UPDATE_LOCK: + if (not row or row['last_updated'] < cache_cutoff) and domain_name not in ACTIVE_DOMAIN_UPDATES: + domains_to_update.append(domain_name) + continue + + # Use the cached info + try: + if row: # Make sure we have data + domain_info = json.loads(row['info']) + if domain_info.get('info', {}).get('state', "") == "CLOSED": + domainValue += domain_info.get('info', {}).get('value', 0) + except json.JSONDecodeError: + # Only add for update if not already being updated + with DOMAIN_UPDATE_LOCK: + if domain_name not in ACTIVE_DOMAIN_UPDATES: + domains_to_update.append(domain_name) + + conn.close() + else: + for domain in domains: + if domain['state'] == "CLOSED": + domainValue += domain['value'] + + # Start background thread to update cache for missing domains + if domains_to_update: + thread = threading.Thread( + target=update_domain_cache, + args=(domains_to_update,), + daemon=True + ) + thread.start() + total = total - (domainValue/1000000) locked = locked - (domainValue/1000000) @@ -533,17 +722,7 @@ def isOwnPrevout(account, prevout: dict): def getDomain(domain: str): - # Get the domain - response = hsd.rpc_getNameInfo(domain) - if response['error'] is not None: - return { - "error": { - "message": response['error']['message'] - } - } - - # If info is None grab from hsd.hns.au - if response['result'] is None or response['result'].get('info') is None: + if isSPV(): response = requests.get(f"https://hsd.hns.au/api/v1/name/{domain}").json() if 'error' in response: return { @@ -553,6 +732,15 @@ def getDomain(domain: str): } return response + # Get the domain + response = hsd.rpc_getNameInfo(domain) + if response['error'] is not None: + return { + "error": { + "message": response['error']['message'] + } + } + return response['result'] def isKnownDomain(domain: str) -> bool: @@ -561,7 +749,6 @@ def isKnownDomain(domain: str) -> bool: if response['error'] is not None: return False - # If info is None grab from hsd.hns.au if response['result'] is None or response['result'].get('info') is None: return False return True @@ -570,11 +757,8 @@ def getAddressFromCoin(coinhash: str, coinindex = 0): # Get the address from the hash response = requests.get(get_node_api_url(f"coin/{coinhash}/{coinindex}")) if response.status_code != 200: - # Try to get coin from hsd.hns.au - response = requests.get(f"https://hsd.hns.au/api/v1/coin/{coinhash}/{coinindex}") - if response.status_code != 200: - print(f"Error getting address from coin") - return "No Owner" + print(f"Error getting address from coin") + return "No Owner" data = response.json() if 'address' not in data: print(json.dumps(data, indent=4)) @@ -1501,10 +1685,20 @@ def generateReport(account, format="{name},{expiry},{value},{maxBid}"): def convertHNS(value: int): return value/1000000 +SPV_EXTERNAL_ROUTES = [ + "name", + "coin", + "tx", + "block" +] def get_node_api_url(path=''): """Construct a URL for the HSD node API.""" base_url = f"http://x:{HSD_API}@{HSD_IP}:{HSD_NODE_PORT}" + if isSPV() and any(path.startswith(route) for route in SPV_EXTERNAL_ROUTES): + # If in SPV mode and the path is one of the external routes, use the external API + base_url = f"https://hsd.hns.au/api/v1" + if path: # Ensure path starts with a slash if it's not empty if not path.startswith('/'): @@ -1522,6 +1716,19 @@ def get_wallet_api_url(path=''): return f"{base_url}{path}" return base_url +def isSPV() -> bool: + global SPV_MODE + if SPV_MODE is None: + info = hsd.getInfo() + if 'error' in info: + return False + + # Check if SPV mode is enabled + if info.get('chain',{}).get('options',{}).get('spv',False): + SPV_MODE = True + else: + SPV_MODE = False + return SPV_MODE # region HSD Internal Node @@ -1610,6 +1817,7 @@ def hsdInit(): def hsdStart(): global HSD_PROCESS + global SPV_MODE if not HSD_INTERNAL_NODE: return @@ -1650,6 +1858,7 @@ def hsdStart(): cmd.append(f"--chain-migrate={chain_migrate}") if wallet_migrate: cmd.append(f"--wallet-migrate={wallet_migrate}") + SPV_MODE = spv if spv: cmd.append("--spv")