Files
hnsdoh-status/hnsdoh_status/checks.py

251 lines
7.9 KiB
Python

from __future__ import annotations
import socket
import ssl
import time
from datetime import datetime, timezone
import dns.message
import dns.query
import dns.rcode
import dns.rdatatype
import dns.resolver
from hnsdoh_status.models import CheckResult, NodeSnapshot, ProtocolName, Snapshot
def utcnow() -> datetime:
    """Return the current instant as a timezone-aware ``datetime`` in UTC."""
    current = datetime.now(tz=timezone.utc)
    return current
def discover_nodes(domain: str) -> tuple[list[str], str]:
    """Look up *domain*'s A records to find the node addresses to probe.

    Returns:
        ``(addresses, error)``. On success, *addresses* holds the unique
        record addresses sorted as strings and *error* is ``""``. If the
        lookup raises, *addresses* is empty and *error* is the exception text.
        Errors raised while constructing the resolver are not caught.
    """
    resolver = dns.resolver.Resolver()
    try:
        unique_addresses = {rr.address for rr in resolver.resolve(domain, "A")}
    except Exception as exc:  # noqa: BLE001
        return [], str(exc)
    return sorted(unique_addresses), ""
def _check_dns_udp(ip: str, timeout: float) -> CheckResult:
    """Probe *ip* with an A query for ``hnsdoh.com`` over UDP port 53.

    Returns:
        A ``"dns_udp"`` result. It is ok when the reply has at least one
        answer RRset, with latency in milliseconds measured from just before
        the query was built. Any exception yields a failed result with no
        latency and the exception text as its reason.
    """
    t0 = time.perf_counter()
    when = utcnow()
    probe = dns.message.make_query("hnsdoh.com", dns.rdatatype.A)
    try:
        reply = dns.query.udp(probe, ip, timeout=timeout, port=53)
        elapsed_ms = 1000 * (time.perf_counter() - t0)
        answered = bool(reply.answer)
        return CheckResult(
            protocol="dns_udp",
            ok=answered,
            latency_ms=elapsed_ms,
            checked_at=when,
            reason="ok" if answered else "empty answer",
        )
    except Exception as exc:  # noqa: BLE001
        return CheckResult("dns_udp", False, None, when, str(exc))
def _check_dns_tcp(ip: str, timeout: float) -> CheckResult:
    """Probe *ip* with an A query for ``hnsdoh.com`` over TCP port 53.

    Returns:
        A ``"dns_tcp"`` result. It is ok when the reply has at least one
        answer RRset, with latency in milliseconds measured from just before
        the query was built. Any exception yields a failed result with no
        latency and the exception text as its reason.
    """
    t_start = time.perf_counter()
    stamp = utcnow()
    message = dns.message.make_query("hnsdoh.com", dns.rdatatype.A)
    try:
        answer_msg = dns.query.tcp(message, ip, timeout=timeout, port=53)
        duration_ms = (time.perf_counter() - t_start) * 1000
        if answer_msg.answer:
            outcome, verdict = True, "ok"
        else:
            outcome, verdict = False, "empty answer"
        return CheckResult(
            protocol="dns_tcp",
            ok=outcome,
            latency_ms=duration_ms,
            checked_at=stamp,
            reason=verdict,
        )
    except Exception as exc:  # noqa: BLE001
        return CheckResult("dns_tcp", False, None, stamp, str(exc))
def _tls_connection(ip: str, port: int, hostname: str, timeout: float) -> ssl.SSLSocket:
    """Connect to ``ip:port`` and complete a TLS handshake for *hostname*.

    The default SSL context is used, so the server certificate is verified
    and *hostname* is sent as SNI. *timeout* is set on the TCP socket before
    connecting. If the handshake fails, the TCP socket is closed before the
    exception propagates. On success the caller owns the returned socket and
    must close it.
    """
    tls_context = ssl.create_default_context()
    tcp_sock = socket.create_connection((ip, port), timeout=timeout)
    try:
        return tls_context.wrap_socket(tcp_sock, server_hostname=hostname)
    except Exception:
        tcp_sock.close()
        raise
def _decode_chunked_body(data: bytes) -> bytes:
output = bytearray()
cursor = 0
while True:
line_end = data.find(b"\r\n", cursor)
if line_end < 0:
raise ValueError("invalid chunk framing")
size_token = data[cursor:line_end].split(b";", maxsplit=1)[0].strip()
size = int(size_token or b"0", 16)
cursor = line_end + 2
if size == 0:
break
next_cursor = cursor + size
if next_cursor + 2 > len(data):
raise ValueError("truncated chunk payload")
output.extend(data[cursor:next_cursor])
if data[next_cursor : next_cursor + 2] != b"\r\n":
raise ValueError("invalid chunk terminator")
cursor = next_cursor + 2
return bytes(output)
def _parse_http_response(response: bytes) -> tuple[str, dict[str, str], bytes]:
head, separator, body = response.partition(b"\r\n\r\n")
if not separator:
raise ValueError("invalid HTTP response")
lines = head.split(b"\r\n")
status_line = lines[0].decode("latin-1", errors="replace")
headers: dict[str, str] = {}
for line in lines[1:]:
if b":" not in line:
continue
key, value = line.split(b":", maxsplit=1)
headers[key.decode("latin-1", errors="replace").lower()] = value.decode(
"latin-1", errors="replace"
).strip()
transfer_encoding = headers.get("transfer-encoding", "").lower()
if "chunked" in transfer_encoding:
body = _decode_chunked_body(body)
return status_line, headers, body
def _assess_doh_reply(status_line: str, body: bytes) -> tuple[bool, str]:
    """Judge a DoH HTTP reply and return ``(ok, reason)``.

    ``ok`` requires HTTP status 200 and a body that parses as a DNS message
    with rcode NOERROR and at least one answer RRset.
    """
    # The status line is "HTTP-version SP status-code SP reason-phrase".
    # Compare the code field exactly: a substring search for " 200 " misses
    # "HTTP/1.1 200" (no reason phrase) and can match "200" in a phrase.
    fields = status_line.split(" ", 2)
    if len(fields) < 2 or fields[1] != "200":
        return False, f"http status failed: {status_line}"
    if not body:
        return False, "empty http body"
    try:
        reply = dns.message.from_wire(body)
    except Exception:  # noqa: BLE001
        return False, "invalid dns wireformat payload"
    rcode = reply.rcode()
    if rcode != dns.rcode.NOERROR:
        return False, f"dns rcode {dns.rcode.to_text(rcode)}"
    if not reply.answer:
        return False, "empty answer"
    return True, "ok"


def _check_doh(ip: str, hostname: str, path: str, timeout: float) -> CheckResult:
    """Probe a node's DNS-over-HTTPS endpoint with a single POST request.

    Opens TLS to ``ip:443`` verified against *hostname*, then POSTs a
    wire-format A query for *hostname* to *path*. The response is read until
    the server closes the connection, and both the HTTP status and the DNS
    payload are validated.

    Args:
        ip: Node address to connect to.
        hostname: Used for SNI and certificate checks, the ``Host`` header,
            and as the query name.
        path: Request path of the DoH endpoint.
        timeout: Seconds allowed for connecting and for each socket operation.

    Returns:
        A ``"doh"`` result. It is ok only for HTTP 200 with a NOERROR DNS
        reply carrying answers. Once a full reply has been read, latency is
        reported even if the reply is rejected. Transport, TLS or framing
        errors yield a failed result with no latency and are never raised.
    """
    # One DoH reply is a single DNS message (at most 64 KiB) plus headers.
    # Cap the read so a server that keeps sending cannot stall the check.
    max_response_bytes = 1 << 20
    started = time.perf_counter()
    checked_at = utcnow()
    query_wire = dns.message.make_query(hostname, dns.rdatatype.A).to_wire()
    request = (
        f"POST {path} HTTP/1.1\r\n"
        f"Host: {hostname}\r\n"
        "Accept: application/dns-message\r\n"
        "Content-Type: application/dns-message\r\n"
        f"Content-Length: {len(query_wire)}\r\n"
        "Connection: close\r\n\r\n"
    ).encode("ascii") + query_wire
    try:
        with _tls_connection(ip, 443, hostname, timeout) as conn:
            conn.settimeout(timeout)
            conn.sendall(request)
            # "Connection: close" makes the server end the reply by closing
            # the socket, so read until EOF instead of honouring Content-Length.
            chunks: list[bytes] = []
            received = 0
            while True:
                chunk = conn.recv(4096)
                if not chunk:
                    break
                received += len(chunk)
                if received > max_response_bytes:
                    raise ValueError("http response too large")
                chunks.append(chunk)
            latency = (time.perf_counter() - started) * 1000
        status_line, _, body = _parse_http_response(b"".join(chunks))
        ok, reason = _assess_doh_reply(status_line, body)
        return CheckResult("doh", ok, latency, checked_at, reason)
    except ssl.SSLCertVerificationError as exc:
        return CheckResult("doh", False, None, checked_at, f"tls verify failed: {exc}")
    except Exception as exc:  # noqa: BLE001
        return CheckResult("doh", False, None, checked_at, str(exc))
def _check_dot(ip: str, hostname: str, timeout: float) -> CheckResult:
    """Probe *ip* with an A query for ``hnsdoh.com`` over DNS-over-TLS.

    Connects on port 853 using the default verified SSL context, with
    *hostname* as the TLS server name.

    Returns:
        A ``"dot"`` result. It is ok when the reply has at least one answer
        RRset. Certificate-verification failures are reported as
        ``"tls verify failed: ..."``. Any other exception yields a failed
        result with the exception text. Failures carry no latency.
    """
    t0 = time.perf_counter()
    when = utcnow()
    probe = dns.message.make_query("hnsdoh.com", dns.rdatatype.A)
    tls_context = ssl.create_default_context()
    try:
        reply = dns.query.tls(
            probe,
            where=ip,
            timeout=timeout,
            port=853,
            ssl_context=tls_context,
            server_hostname=hostname,
        )
        elapsed_ms = 1000 * (time.perf_counter() - t0)
        answered = bool(reply.answer)
        return CheckResult(
            protocol="dot",
            ok=answered,
            latency_ms=elapsed_ms,
            checked_at=when,
            reason="ok" if answered else "empty answer",
        )
    except ssl.SSLCertVerificationError as exc:
        return CheckResult("dot", False, None, when, f"tls verify failed: {exc}")
    except Exception as exc:  # noqa: BLE001
        return CheckResult("dot", False, None, when, str(exc))
def check_node(
    ip: str,
    hostname: str,
    doh_path: str,
    dns_timeout: float,
    doh_timeout: float,
    dot_timeout: float,
) -> NodeSnapshot:
    """Run every protocol probe against one node and bundle the results.

    Probes run one after another in the order plain UDP, plain TCP, DoH,
    then DoT. Both plain-DNS probes share *dns_timeout*, and *hostname* is
    passed to the DoH and DoT probes.
    """
    results: dict[ProtocolName, CheckResult] = {}
    results["dns_udp"] = _check_dns_udp(ip, dns_timeout)
    results["dns_tcp"] = _check_dns_tcp(ip, dns_timeout)
    results["doh"] = _check_doh(ip, hostname, doh_path, doh_timeout)
    results["dot"] = _check_dot(ip, hostname, dot_timeout)
    return NodeSnapshot(ip=ip, results=results)
def run_full_check(
    domain: str,
    doh_path: str,
    dns_timeout: float,
    doh_timeout: float,
    dot_timeout: float,
) -> Snapshot:
    """Discover *domain*'s nodes and probe each one with every protocol check.

    The snapshot timestamp is taken before node discovery. Nodes are checked
    sequentially, and *domain* is also used as the hostname for the DoH and
    DoT probes. Any discovery error text is stored in ``discovery_error``.
    """
    snapshot_time = utcnow()
    node_ips, discovery_error = discover_nodes(domain)
    node_snapshots = [
        check_node(
            node_ip,
            domain,
            doh_path,
            dns_timeout,
            doh_timeout,
            dot_timeout,
        )
        for node_ip in node_ips
    ]
    return Snapshot(
        domain=domain,
        nodes=node_snapshots,
        node_count=len(node_snapshots),
        checked_at=snapshot_time,
        discovery_error=discovery_error,
    )