feat: Add SPV features to fix accoutn balances
All checks were successful
Build Docker / Build Image (push) Successful in 2m9s
All checks were successful
Build Docker / Build Image (push) Successful in 2m9s
This commit is contained in:
249
account.py
249
account.py
@@ -11,6 +11,9 @@ import subprocess
|
||||
import atexit
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
import sqlite3
|
||||
from functools import wraps
|
||||
|
||||
|
||||
dotenv.load_dotenv()
|
||||
@@ -46,6 +49,7 @@ if SHOW_EXPIRED is None:
|
||||
SHOW_EXPIRED = False
|
||||
|
||||
HSD_PROCESS = None
|
||||
SPV_MODE = None
|
||||
|
||||
# Get hsdconfig.json
|
||||
HSD_CONFIG = {
|
||||
@@ -59,6 +63,9 @@ HSD_CONFIG = {
|
||||
"--agent=FireWallet"
|
||||
]
|
||||
}
|
||||
|
||||
CACHE_TTL = int(os.getenv("CACHE_TTL",90))
|
||||
|
||||
if not os.path.exists('hsdconfig.json'):
|
||||
with open('hsdconfig.json', 'w') as f:
|
||||
f.write(json.dumps(HSD_CONFIG, indent=4))
|
||||
@@ -89,6 +96,13 @@ def hsdVersion(format=True):
|
||||
info = hsd.getInfo()
|
||||
if 'error' in info:
|
||||
return -1
|
||||
|
||||
# Check if SPV mode is enabled
|
||||
global SPV_MODE
|
||||
if info.get('chain',{}).get('options',{}).get('spv',False):
|
||||
SPV_MODE = True
|
||||
else:
|
||||
SPV_MODE = False
|
||||
if format:
|
||||
return float('.'.join(info['version'].split(".")[:2]))
|
||||
else:
|
||||
@@ -215,6 +229,124 @@ def selectWallet(account: str):
|
||||
"message": response['error']['message']
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def init_domain_db():
|
||||
"""Initialize the SQLite database for domain cache."""
|
||||
os.makedirs('cache', exist_ok=True)
|
||||
db_path = os.path.join('cache', 'domains.db')
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create the domains table if it doesn't exist
|
||||
cursor.execute('''
|
||||
CREATE TABLE IF NOT EXISTS domains (
|
||||
name TEXT PRIMARY KEY,
|
||||
info TEXT,
|
||||
last_updated INTEGER
|
||||
)
|
||||
''')
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
|
||||
def getCachedDomains():
|
||||
"""Get cached domain information from SQLite database."""
|
||||
init_domain_db() # Ensure DB exists
|
||||
|
||||
db_path = os.path.join('cache', 'domains.db')
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row # This allows accessing columns by name
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get all domains from the database
|
||||
cursor.execute('SELECT name, info, last_updated FROM domains')
|
||||
rows = cursor.fetchall()
|
||||
|
||||
# Convert to dictionary format
|
||||
domain_cache = {}
|
||||
for row in rows:
|
||||
try:
|
||||
domain_cache[row['name']] = json.loads(row['info'])
|
||||
domain_cache[row['name']]['last_updated'] = row['last_updated']
|
||||
except json.JSONDecodeError:
|
||||
print(f"Error parsing cached data for domain {row['name']}")
|
||||
|
||||
conn.close()
|
||||
return domain_cache
|
||||
|
||||
|
||||
ACTIVE_DOMAIN_UPDATES = set() # Track domains being updated
|
||||
DOMAIN_UPDATE_LOCK = threading.Lock() # For thread-safe access to ACTIVE_DOMAIN_UPDATES
|
||||
|
||||
def update_domain_cache(domain_names: list):
|
||||
"""Fetch domain info and update the SQLite cache."""
|
||||
if not domain_names:
|
||||
return
|
||||
|
||||
# Filter out domains that are already being updated
|
||||
domains_to_update = []
|
||||
with DOMAIN_UPDATE_LOCK:
|
||||
for domain in domain_names:
|
||||
if domain not in ACTIVE_DOMAIN_UPDATES:
|
||||
ACTIVE_DOMAIN_UPDATES.add(domain)
|
||||
domains_to_update.append(domain)
|
||||
|
||||
if not domains_to_update:
|
||||
# All requested domains are already being updated
|
||||
return
|
||||
|
||||
try:
|
||||
# Initialize database
|
||||
init_domain_db()
|
||||
|
||||
db_path = os.path.join('cache', 'domains.db')
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
for domain_name in domains_to_update:
|
||||
try:
|
||||
# Get domain info from node
|
||||
domain_info = getDomain(domain_name)
|
||||
|
||||
if 'error' in domain_info or not domain_info.get('info'):
|
||||
print(f"Failed to get info for domain {domain_name}: {domain_info.get('error', 'Unknown error')}", flush=True)
|
||||
continue
|
||||
|
||||
# Update or insert into database
|
||||
now = int(time.time())
|
||||
serialized_info = json.dumps(domain_info)
|
||||
|
||||
cursor.execute(
|
||||
'INSERT OR REPLACE INTO domains (name, info, last_updated) VALUES (?, ?, ?)',
|
||||
(domain_name, serialized_info, now)
|
||||
)
|
||||
|
||||
print(f"Updated cache for domain {domain_name}")
|
||||
except Exception as e:
|
||||
print(f"Error updating cache for domain {domain_name}: {str(e)}")
|
||||
finally:
|
||||
# Always remove from active set, even if there was an error
|
||||
with DOMAIN_UPDATE_LOCK:
|
||||
if domain_name in ACTIVE_DOMAIN_UPDATES:
|
||||
ACTIVE_DOMAIN_UPDATES.remove(domain_name)
|
||||
|
||||
# Commit all changes at once
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating domain cache: {str(e)}", flush=True)
|
||||
# Make sure to clean up the active set on any exception
|
||||
with DOMAIN_UPDATE_LOCK:
|
||||
for domain in domains_to_update:
|
||||
if domain in ACTIVE_DOMAIN_UPDATES:
|
||||
ACTIVE_DOMAIN_UPDATES.remove(domain)
|
||||
|
||||
print("Updated cache for domains")
|
||||
|
||||
|
||||
def getBalance(account: str):
|
||||
# Get the total balance
|
||||
@@ -232,9 +364,66 @@ def getBalance(account: str):
|
||||
|
||||
domains = getDomains(account)
|
||||
domainValue = 0
|
||||
for domain in domains:
|
||||
if domain['state'] == "CLOSED":
|
||||
domainValue += domain['value']
|
||||
domains_to_update = [] # Track domains that need cache updates
|
||||
|
||||
if isSPV():
|
||||
# Initialize database if needed
|
||||
init_domain_db()
|
||||
|
||||
# Connect to the database directly for efficient querying
|
||||
db_path = os.path.join('cache', 'domains.db')
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
cursor = conn.cursor()
|
||||
|
||||
now = int(time.time())
|
||||
cache_cutoff = now - (CACHE_TTL * 86400) # Cache TTL in days
|
||||
|
||||
for domain in domains:
|
||||
domain_name = domain['name']
|
||||
|
||||
# Check if domain is in cache and still fresh
|
||||
cursor.execute(
|
||||
'SELECT info, last_updated FROM domains WHERE name = ?',
|
||||
(domain_name,)
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
|
||||
# Only add domain for update if:
|
||||
# 1. Not in cache or stale
|
||||
# 2. Not currently being updated by another thread
|
||||
with DOMAIN_UPDATE_LOCK:
|
||||
if (not row or row['last_updated'] < cache_cutoff) and domain_name not in ACTIVE_DOMAIN_UPDATES:
|
||||
domains_to_update.append(domain_name)
|
||||
continue
|
||||
|
||||
# Use the cached info
|
||||
try:
|
||||
if row: # Make sure we have data
|
||||
domain_info = json.loads(row['info'])
|
||||
if domain_info.get('info', {}).get('state', "") == "CLOSED":
|
||||
domainValue += domain_info.get('info', {}).get('value', 0)
|
||||
except json.JSONDecodeError:
|
||||
# Only add for update if not already being updated
|
||||
with DOMAIN_UPDATE_LOCK:
|
||||
if domain_name not in ACTIVE_DOMAIN_UPDATES:
|
||||
domains_to_update.append(domain_name)
|
||||
|
||||
conn.close()
|
||||
else:
|
||||
for domain in domains:
|
||||
if domain['state'] == "CLOSED":
|
||||
domainValue += domain['value']
|
||||
|
||||
# Start background thread to update cache for missing domains
|
||||
if domains_to_update:
|
||||
thread = threading.Thread(
|
||||
target=update_domain_cache,
|
||||
args=(domains_to_update,),
|
||||
daemon=True
|
||||
)
|
||||
thread.start()
|
||||
|
||||
total = total - (domainValue/1000000)
|
||||
locked = locked - (domainValue/1000000)
|
||||
|
||||
@@ -533,17 +722,7 @@ def isOwnPrevout(account, prevout: dict):
|
||||
|
||||
|
||||
def getDomain(domain: str):
|
||||
# Get the domain
|
||||
response = hsd.rpc_getNameInfo(domain)
|
||||
if response['error'] is not None:
|
||||
return {
|
||||
"error": {
|
||||
"message": response['error']['message']
|
||||
}
|
||||
}
|
||||
|
||||
# If info is None grab from hsd.hns.au
|
||||
if response['result'] is None or response['result'].get('info') is None:
|
||||
if isSPV():
|
||||
response = requests.get(f"https://hsd.hns.au/api/v1/name/{domain}").json()
|
||||
if 'error' in response:
|
||||
return {
|
||||
@@ -553,6 +732,15 @@ def getDomain(domain: str):
|
||||
}
|
||||
return response
|
||||
|
||||
# Get the domain
|
||||
response = hsd.rpc_getNameInfo(domain)
|
||||
if response['error'] is not None:
|
||||
return {
|
||||
"error": {
|
||||
"message": response['error']['message']
|
||||
}
|
||||
}
|
||||
|
||||
return response['result']
|
||||
|
||||
def isKnownDomain(domain: str) -> bool:
|
||||
@@ -561,7 +749,6 @@ def isKnownDomain(domain: str) -> bool:
|
||||
if response['error'] is not None:
|
||||
return False
|
||||
|
||||
# If info is None grab from hsd.hns.au
|
||||
if response['result'] is None or response['result'].get('info') is None:
|
||||
return False
|
||||
return True
|
||||
@@ -570,11 +757,8 @@ def getAddressFromCoin(coinhash: str, coinindex = 0):
|
||||
# Get the address from the hash
|
||||
response = requests.get(get_node_api_url(f"coin/{coinhash}/{coinindex}"))
|
||||
if response.status_code != 200:
|
||||
# Try to get coin from hsd.hns.au
|
||||
response = requests.get(f"https://hsd.hns.au/api/v1/coin/{coinhash}/{coinindex}")
|
||||
if response.status_code != 200:
|
||||
print(f"Error getting address from coin")
|
||||
return "No Owner"
|
||||
print(f"Error getting address from coin")
|
||||
return "No Owner"
|
||||
data = response.json()
|
||||
if 'address' not in data:
|
||||
print(json.dumps(data, indent=4))
|
||||
@@ -1501,10 +1685,20 @@ def generateReport(account, format="{name},{expiry},{value},{maxBid}"):
|
||||
def convertHNS(value: int):
|
||||
return value/1000000
|
||||
|
||||
SPV_EXTERNAL_ROUTES = [
|
||||
"name",
|
||||
"coin",
|
||||
"tx",
|
||||
"block"
|
||||
]
|
||||
|
||||
def get_node_api_url(path=''):
|
||||
"""Construct a URL for the HSD node API."""
|
||||
base_url = f"http://x:{HSD_API}@{HSD_IP}:{HSD_NODE_PORT}"
|
||||
if isSPV() and any(path.startswith(route) for route in SPV_EXTERNAL_ROUTES):
|
||||
# If in SPV mode and the path is one of the external routes, use the external API
|
||||
base_url = f"https://hsd.hns.au/api/v1"
|
||||
|
||||
if path:
|
||||
# Ensure path starts with a slash if it's not empty
|
||||
if not path.startswith('/'):
|
||||
@@ -1522,6 +1716,19 @@ def get_wallet_api_url(path=''):
|
||||
return f"{base_url}{path}"
|
||||
return base_url
|
||||
|
||||
def isSPV() -> bool:
|
||||
global SPV_MODE
|
||||
if SPV_MODE is None:
|
||||
info = hsd.getInfo()
|
||||
if 'error' in info:
|
||||
return False
|
||||
|
||||
# Check if SPV mode is enabled
|
||||
if info.get('chain',{}).get('options',{}).get('spv',False):
|
||||
SPV_MODE = True
|
||||
else:
|
||||
SPV_MODE = False
|
||||
return SPV_MODE
|
||||
|
||||
# region HSD Internal Node
|
||||
|
||||
@@ -1610,6 +1817,7 @@ def hsdInit():
|
||||
|
||||
def hsdStart():
|
||||
global HSD_PROCESS
|
||||
global SPV_MODE
|
||||
if not HSD_INTERNAL_NODE:
|
||||
return
|
||||
|
||||
@@ -1650,6 +1858,7 @@ def hsdStart():
|
||||
cmd.append(f"--chain-migrate={chain_migrate}")
|
||||
if wallet_migrate:
|
||||
cmd.append(f"--wallet-migrate={wallet_migrate}")
|
||||
SPV_MODE = spv
|
||||
if spv:
|
||||
cmd.append("--spv")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user