feat: Add new status page
This commit is contained in:
250
hnsdoh_status/checks.py
Normal file
250
hnsdoh_status/checks.py
Normal file
@@ -0,0 +1,250 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import socket
|
||||
import ssl
|
||||
import time
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import dns.message
|
||||
import dns.query
|
||||
import dns.rcode
|
||||
import dns.rdatatype
|
||||
import dns.resolver
|
||||
|
||||
from hnsdoh_status.models import CheckResult, NodeSnapshot, ProtocolName, Snapshot
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
def discover_nodes(domain: str) -> tuple[list[str], str]:
|
||||
resolver = dns.resolver.Resolver()
|
||||
try:
|
||||
answer = resolver.resolve(domain, "A")
|
||||
return sorted({record.address for record in answer}), ""
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return [], str(exc)
|
||||
|
||||
|
||||
def _check_dns_udp(ip: str, timeout: float) -> CheckResult:
|
||||
started = time.perf_counter()
|
||||
checked_at = utcnow()
|
||||
query = dns.message.make_query("hnsdoh.com", dns.rdatatype.A)
|
||||
try:
|
||||
response = dns.query.udp(query, ip, timeout=timeout, port=53)
|
||||
latency = (time.perf_counter() - started) * 1000
|
||||
return CheckResult(
|
||||
protocol="dns_udp",
|
||||
ok=bool(response.answer),
|
||||
latency_ms=latency,
|
||||
checked_at=checked_at,
|
||||
reason="ok" if response.answer else "empty answer",
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return CheckResult("dns_udp", False, None, checked_at, str(exc))
|
||||
|
||||
|
||||
def _check_dns_tcp(ip: str, timeout: float) -> CheckResult:
|
||||
started = time.perf_counter()
|
||||
checked_at = utcnow()
|
||||
query = dns.message.make_query("hnsdoh.com", dns.rdatatype.A)
|
||||
try:
|
||||
response = dns.query.tcp(query, ip, timeout=timeout, port=53)
|
||||
latency = (time.perf_counter() - started) * 1000
|
||||
return CheckResult(
|
||||
protocol="dns_tcp",
|
||||
ok=bool(response.answer),
|
||||
latency_ms=latency,
|
||||
checked_at=checked_at,
|
||||
reason="ok" if response.answer else "empty answer",
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return CheckResult("dns_tcp", False, None, checked_at, str(exc))
|
||||
|
||||
|
||||
def _tls_connection(ip: str, port: int, hostname: str, timeout: float) -> ssl.SSLSocket:
|
||||
context = ssl.create_default_context()
|
||||
raw = socket.create_connection((ip, port), timeout=timeout)
|
||||
try:
|
||||
tls_socket = context.wrap_socket(raw, server_hostname=hostname)
|
||||
return tls_socket
|
||||
except Exception:
|
||||
raw.close()
|
||||
raise
|
||||
|
||||
|
||||
def _decode_chunked_body(data: bytes) -> bytes:
|
||||
output = bytearray()
|
||||
cursor = 0
|
||||
|
||||
while True:
|
||||
line_end = data.find(b"\r\n", cursor)
|
||||
if line_end < 0:
|
||||
raise ValueError("invalid chunk framing")
|
||||
|
||||
size_token = data[cursor:line_end].split(b";", maxsplit=1)[0].strip()
|
||||
size = int(size_token or b"0", 16)
|
||||
cursor = line_end + 2
|
||||
|
||||
if size == 0:
|
||||
break
|
||||
|
||||
next_cursor = cursor + size
|
||||
if next_cursor + 2 > len(data):
|
||||
raise ValueError("truncated chunk payload")
|
||||
output.extend(data[cursor:next_cursor])
|
||||
|
||||
if data[next_cursor : next_cursor + 2] != b"\r\n":
|
||||
raise ValueError("invalid chunk terminator")
|
||||
cursor = next_cursor + 2
|
||||
|
||||
return bytes(output)
|
||||
|
||||
|
||||
def _parse_http_response(response: bytes) -> tuple[str, dict[str, str], bytes]:
|
||||
head, separator, body = response.partition(b"\r\n\r\n")
|
||||
if not separator:
|
||||
raise ValueError("invalid HTTP response")
|
||||
|
||||
lines = head.split(b"\r\n")
|
||||
status_line = lines[0].decode("latin-1", errors="replace")
|
||||
headers: dict[str, str] = {}
|
||||
|
||||
for line in lines[1:]:
|
||||
if b":" not in line:
|
||||
continue
|
||||
key, value = line.split(b":", maxsplit=1)
|
||||
headers[key.decode("latin-1", errors="replace").lower()] = value.decode(
|
||||
"latin-1", errors="replace"
|
||||
).strip()
|
||||
|
||||
transfer_encoding = headers.get("transfer-encoding", "").lower()
|
||||
if "chunked" in transfer_encoding:
|
||||
body = _decode_chunked_body(body)
|
||||
|
||||
return status_line, headers, body
|
||||
|
||||
|
||||
def _check_doh(ip: str, hostname: str, path: str, timeout: float) -> CheckResult:
|
||||
started = time.perf_counter()
|
||||
checked_at = utcnow()
|
||||
query = dns.message.make_query(hostname, dns.rdatatype.A)
|
||||
query_wire = query.to_wire()
|
||||
request = (
|
||||
f"POST {path} HTTP/1.1\r\n"
|
||||
f"Host: {hostname}\r\n"
|
||||
"Accept: application/dns-message\r\n"
|
||||
"Content-Type: application/dns-message\r\n"
|
||||
f"Content-Length: {len(query_wire)}\r\n"
|
||||
"Connection: close\r\n\r\n"
|
||||
).encode("ascii") + query_wire
|
||||
|
||||
try:
|
||||
with _tls_connection(ip, 443, hostname, timeout) as conn:
|
||||
conn.settimeout(timeout)
|
||||
conn.sendall(request)
|
||||
response = b""
|
||||
while True:
|
||||
chunk = conn.recv(4096)
|
||||
if not chunk:
|
||||
break
|
||||
response += chunk
|
||||
|
||||
latency = (time.perf_counter() - started) * 1000
|
||||
status_line, _, body = _parse_http_response(response)
|
||||
status_ok = " 200 " in status_line
|
||||
|
||||
payload_ok = False
|
||||
reason = ""
|
||||
if status_ok and body:
|
||||
try:
|
||||
parsed_dns = dns.message.from_wire(body)
|
||||
payload_ok = parsed_dns.rcode() == dns.rcode.NOERROR and bool(
|
||||
parsed_dns.answer
|
||||
)
|
||||
if payload_ok:
|
||||
reason = "ok"
|
||||
elif parsed_dns.rcode() != dns.rcode.NOERROR:
|
||||
reason = f"dns rcode {dns.rcode.to_text(parsed_dns.rcode())}"
|
||||
else:
|
||||
reason = "empty answer"
|
||||
except Exception: # noqa: BLE001
|
||||
reason = "invalid dns wireformat payload"
|
||||
|
||||
ok = status_ok and payload_ok
|
||||
if not reason:
|
||||
reason = f"http status failed: {status_line}"
|
||||
|
||||
return CheckResult("doh", ok, latency, checked_at, reason)
|
||||
except ssl.SSLCertVerificationError as exc:
|
||||
return CheckResult("doh", False, None, checked_at, f"tls verify failed: {exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return CheckResult("doh", False, None, checked_at, str(exc))
|
||||
|
||||
|
||||
def _check_dot(ip: str, hostname: str, timeout: float) -> CheckResult:
|
||||
started = time.perf_counter()
|
||||
checked_at = utcnow()
|
||||
query = dns.message.make_query("hnsdoh.com", dns.rdatatype.A)
|
||||
context = ssl.create_default_context()
|
||||
try:
|
||||
response = dns.query.tls(
|
||||
query,
|
||||
where=ip,
|
||||
timeout=timeout,
|
||||
port=853,
|
||||
ssl_context=context,
|
||||
server_hostname=hostname,
|
||||
)
|
||||
latency = (time.perf_counter() - started) * 1000
|
||||
return CheckResult(
|
||||
protocol="dot",
|
||||
ok=bool(response.answer),
|
||||
latency_ms=latency,
|
||||
checked_at=checked_at,
|
||||
reason="ok" if response.answer else "empty answer",
|
||||
)
|
||||
except ssl.SSLCertVerificationError as exc:
|
||||
return CheckResult("dot", False, None, checked_at, f"tls verify failed: {exc}")
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return CheckResult("dot", False, None, checked_at, str(exc))
|
||||
|
||||
|
||||
def check_node(
|
||||
ip: str,
|
||||
hostname: str,
|
||||
doh_path: str,
|
||||
dns_timeout: float,
|
||||
doh_timeout: float,
|
||||
dot_timeout: float,
|
||||
) -> NodeSnapshot:
|
||||
results: dict[ProtocolName, CheckResult] = {
|
||||
"dns_udp": _check_dns_udp(ip, dns_timeout),
|
||||
"dns_tcp": _check_dns_tcp(ip, dns_timeout),
|
||||
"doh": _check_doh(ip, hostname, doh_path, doh_timeout),
|
||||
"dot": _check_dot(ip, hostname, dot_timeout),
|
||||
}
|
||||
return NodeSnapshot(ip=ip, results=results)
|
||||
|
||||
|
||||
def run_full_check(
|
||||
domain: str,
|
||||
doh_path: str,
|
||||
dns_timeout: float,
|
||||
doh_timeout: float,
|
||||
dot_timeout: float,
|
||||
) -> Snapshot:
|
||||
checked_at = utcnow()
|
||||
nodes, discovery_error = discover_nodes(domain)
|
||||
snapshots = [
|
||||
check_node(ip, domain, doh_path, dns_timeout, doh_timeout, dot_timeout)
|
||||
for ip in nodes
|
||||
]
|
||||
return Snapshot(
|
||||
domain=domain,
|
||||
checked_at=checked_at,
|
||||
node_count=len(snapshots),
|
||||
nodes=snapshots,
|
||||
discovery_error=discovery_error,
|
||||
)
|
||||
Reference in New Issue
Block a user