feat: Add SPV features to fix accoutn balances
All checks were successful
Build Docker / Build Image (push) Successful in 2m9s

This commit is contained in:
2025-08-28 16:42:12 +10:00
parent 26c5b4a4fa
commit f2cda461ba

View File

@@ -11,6 +11,9 @@ import subprocess
import atexit import atexit
import signal import signal
import sys import sys
import threading
import sqlite3
from functools import wraps
dotenv.load_dotenv() dotenv.load_dotenv()
@@ -46,6 +49,7 @@ if SHOW_EXPIRED is None:
SHOW_EXPIRED = False SHOW_EXPIRED = False
HSD_PROCESS = None HSD_PROCESS = None
SPV_MODE = None
# Get hsdconfig.json # Get hsdconfig.json
HSD_CONFIG = { HSD_CONFIG = {
@@ -59,6 +63,9 @@ HSD_CONFIG = {
"--agent=FireWallet" "--agent=FireWallet"
] ]
} }
CACHE_TTL = int(os.getenv("CACHE_TTL",90))
if not os.path.exists('hsdconfig.json'): if not os.path.exists('hsdconfig.json'):
with open('hsdconfig.json', 'w') as f: with open('hsdconfig.json', 'w') as f:
f.write(json.dumps(HSD_CONFIG, indent=4)) f.write(json.dumps(HSD_CONFIG, indent=4))
@@ -89,6 +96,13 @@ def hsdVersion(format=True):
info = hsd.getInfo() info = hsd.getInfo()
if 'error' in info: if 'error' in info:
return -1 return -1
# Check if SPV mode is enabled
global SPV_MODE
if info.get('chain',{}).get('options',{}).get('spv',False):
SPV_MODE = True
else:
SPV_MODE = False
if format: if format:
return float('.'.join(info['version'].split(".")[:2])) return float('.'.join(info['version'].split(".")[:2]))
else: else:
@@ -216,6 +230,124 @@ def selectWallet(account: str):
} }
} }
def init_domain_db():
"""Initialize the SQLite database for domain cache."""
os.makedirs('cache', exist_ok=True)
db_path = os.path.join('cache', 'domains.db')
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
# Create the domains table if it doesn't exist
cursor.execute('''
CREATE TABLE IF NOT EXISTS domains (
name TEXT PRIMARY KEY,
info TEXT,
last_updated INTEGER
)
''')
conn.commit()
conn.close()
def getCachedDomains():
"""Get cached domain information from SQLite database."""
init_domain_db() # Ensure DB exists
db_path = os.path.join('cache', 'domains.db')
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row # This allows accessing columns by name
cursor = conn.cursor()
# Get all domains from the database
cursor.execute('SELECT name, info, last_updated FROM domains')
rows = cursor.fetchall()
# Convert to dictionary format
domain_cache = {}
for row in rows:
try:
domain_cache[row['name']] = json.loads(row['info'])
domain_cache[row['name']]['last_updated'] = row['last_updated']
except json.JSONDecodeError:
print(f"Error parsing cached data for domain {row['name']}")
conn.close()
return domain_cache
ACTIVE_DOMAIN_UPDATES = set() # Track domains being updated
DOMAIN_UPDATE_LOCK = threading.Lock() # For thread-safe access to ACTIVE_DOMAIN_UPDATES
def update_domain_cache(domain_names: list):
"""Fetch domain info and update the SQLite cache."""
if not domain_names:
return
# Filter out domains that are already being updated
domains_to_update = []
with DOMAIN_UPDATE_LOCK:
for domain in domain_names:
if domain not in ACTIVE_DOMAIN_UPDATES:
ACTIVE_DOMAIN_UPDATES.add(domain)
domains_to_update.append(domain)
if not domains_to_update:
# All requested domains are already being updated
return
try:
# Initialize database
init_domain_db()
db_path = os.path.join('cache', 'domains.db')
conn = sqlite3.connect(db_path)
cursor = conn.cursor()
for domain_name in domains_to_update:
try:
# Get domain info from node
domain_info = getDomain(domain_name)
if 'error' in domain_info or not domain_info.get('info'):
print(f"Failed to get info for domain {domain_name}: {domain_info.get('error', 'Unknown error')}", flush=True)
continue
# Update or insert into database
now = int(time.time())
serialized_info = json.dumps(domain_info)
cursor.execute(
'INSERT OR REPLACE INTO domains (name, info, last_updated) VALUES (?, ?, ?)',
(domain_name, serialized_info, now)
)
print(f"Updated cache for domain {domain_name}")
except Exception as e:
print(f"Error updating cache for domain {domain_name}: {str(e)}")
finally:
# Always remove from active set, even if there was an error
with DOMAIN_UPDATE_LOCK:
if domain_name in ACTIVE_DOMAIN_UPDATES:
ACTIVE_DOMAIN_UPDATES.remove(domain_name)
# Commit all changes at once
conn.commit()
conn.close()
except Exception as e:
print(f"Error updating domain cache: {str(e)}", flush=True)
# Make sure to clean up the active set on any exception
with DOMAIN_UPDATE_LOCK:
for domain in domains_to_update:
if domain in ACTIVE_DOMAIN_UPDATES:
ACTIVE_DOMAIN_UPDATES.remove(domain)
print("Updated cache for domains")
def getBalance(account: str): def getBalance(account: str):
# Get the total balance # Get the total balance
info = hsw.getBalance('default', account) info = hsw.getBalance('default', account)
@@ -232,9 +364,66 @@ def getBalance(account: str):
domains = getDomains(account) domains = getDomains(account)
domainValue = 0 domainValue = 0
domains_to_update = [] # Track domains that need cache updates
if isSPV():
# Initialize database if needed
init_domain_db()
# Connect to the database directly for efficient querying
db_path = os.path.join('cache', 'domains.db')
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
now = int(time.time())
cache_cutoff = now - (CACHE_TTL * 86400) # Cache TTL in days
for domain in domains:
domain_name = domain['name']
# Check if domain is in cache and still fresh
cursor.execute(
'SELECT info, last_updated FROM domains WHERE name = ?',
(domain_name,)
)
row = cursor.fetchone()
# Only add domain for update if:
# 1. Not in cache or stale
# 2. Not currently being updated by another thread
with DOMAIN_UPDATE_LOCK:
if (not row or row['last_updated'] < cache_cutoff) and domain_name not in ACTIVE_DOMAIN_UPDATES:
domains_to_update.append(domain_name)
continue
# Use the cached info
try:
if row: # Make sure we have data
domain_info = json.loads(row['info'])
if domain_info.get('info', {}).get('state', "") == "CLOSED":
domainValue += domain_info.get('info', {}).get('value', 0)
except json.JSONDecodeError:
# Only add for update if not already being updated
with DOMAIN_UPDATE_LOCK:
if domain_name not in ACTIVE_DOMAIN_UPDATES:
domains_to_update.append(domain_name)
conn.close()
else:
for domain in domains: for domain in domains:
if domain['state'] == "CLOSED": if domain['state'] == "CLOSED":
domainValue += domain['value'] domainValue += domain['value']
# Start background thread to update cache for missing domains
if domains_to_update:
thread = threading.Thread(
target=update_domain_cache,
args=(domains_to_update,),
daemon=True
)
thread.start()
total = total - (domainValue/1000000) total = total - (domainValue/1000000)
locked = locked - (domainValue/1000000) locked = locked - (domainValue/1000000)
@@ -533,17 +722,7 @@ def isOwnPrevout(account, prevout: dict):
def getDomain(domain: str): def getDomain(domain: str):
# Get the domain if isSPV():
response = hsd.rpc_getNameInfo(domain)
if response['error'] is not None:
return {
"error": {
"message": response['error']['message']
}
}
# If info is None grab from hsd.hns.au
if response['result'] is None or response['result'].get('info') is None:
response = requests.get(f"https://hsd.hns.au/api/v1/name/{domain}").json() response = requests.get(f"https://hsd.hns.au/api/v1/name/{domain}").json()
if 'error' in response: if 'error' in response:
return { return {
@@ -553,6 +732,15 @@ def getDomain(domain: str):
} }
return response return response
# Get the domain
response = hsd.rpc_getNameInfo(domain)
if response['error'] is not None:
return {
"error": {
"message": response['error']['message']
}
}
return response['result'] return response['result']
def isKnownDomain(domain: str) -> bool: def isKnownDomain(domain: str) -> bool:
@@ -561,7 +749,6 @@ def isKnownDomain(domain: str) -> bool:
if response['error'] is not None: if response['error'] is not None:
return False return False
# If info is None grab from hsd.hns.au
if response['result'] is None or response['result'].get('info') is None: if response['result'] is None or response['result'].get('info') is None:
return False return False
return True return True
@@ -569,9 +756,6 @@ def isKnownDomain(domain: str) -> bool:
def getAddressFromCoin(coinhash: str, coinindex = 0): def getAddressFromCoin(coinhash: str, coinindex = 0):
# Get the address from the hash # Get the address from the hash
response = requests.get(get_node_api_url(f"coin/{coinhash}/{coinindex}")) response = requests.get(get_node_api_url(f"coin/{coinhash}/{coinindex}"))
if response.status_code != 200:
# Try to get coin from hsd.hns.au
response = requests.get(f"https://hsd.hns.au/api/v1/coin/{coinhash}/{coinindex}")
if response.status_code != 200: if response.status_code != 200:
print(f"Error getting address from coin") print(f"Error getting address from coin")
return "No Owner" return "No Owner"
@@ -1501,10 +1685,20 @@ def generateReport(account, format="{name},{expiry},{value},{maxBid}"):
def convertHNS(value: int): def convertHNS(value: int):
return value/1000000 return value/1000000
SPV_EXTERNAL_ROUTES = [
"name",
"coin",
"tx",
"block"
]
def get_node_api_url(path=''): def get_node_api_url(path=''):
"""Construct a URL for the HSD node API.""" """Construct a URL for the HSD node API."""
base_url = f"http://x:{HSD_API}@{HSD_IP}:{HSD_NODE_PORT}" base_url = f"http://x:{HSD_API}@{HSD_IP}:{HSD_NODE_PORT}"
if isSPV() and any(path.startswith(route) for route in SPV_EXTERNAL_ROUTES):
# If in SPV mode and the path is one of the external routes, use the external API
base_url = f"https://hsd.hns.au/api/v1"
if path: if path:
# Ensure path starts with a slash if it's not empty # Ensure path starts with a slash if it's not empty
if not path.startswith('/'): if not path.startswith('/'):
@@ -1522,6 +1716,19 @@ def get_wallet_api_url(path=''):
return f"{base_url}{path}" return f"{base_url}{path}"
return base_url return base_url
def isSPV() -> bool:
global SPV_MODE
if SPV_MODE is None:
info = hsd.getInfo()
if 'error' in info:
return False
# Check if SPV mode is enabled
if info.get('chain',{}).get('options',{}).get('spv',False):
SPV_MODE = True
else:
SPV_MODE = False
return SPV_MODE
# region HSD Internal Node # region HSD Internal Node
@@ -1610,6 +1817,7 @@ def hsdInit():
def hsdStart(): def hsdStart():
global HSD_PROCESS global HSD_PROCESS
global SPV_MODE
if not HSD_INTERNAL_NODE: if not HSD_INTERNAL_NODE:
return return
@@ -1650,6 +1858,7 @@ def hsdStart():
cmd.append(f"--chain-migrate={chain_migrate}") cmd.append(f"--chain-migrate={chain_migrate}")
if wallet_migrate: if wallet_migrate:
cmd.append(f"--wallet-migrate={wallet_migrate}") cmd.append(f"--wallet-migrate={wallet_migrate}")
SPV_MODE = spv
if spv: if spv:
cmd.append("--spv") cmd.append("--spv")