import binascii
import datetime
import hashlib
import random
import re
import subprocess
import tempfile
import time

import dns.resolver
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from dns import resolver, dnssec, name, exception

# Shared stub resolver used by every lookup in this module.
# NOTE: this rebinds the module-level name ``resolver`` (bound to the
# ``dns.resolver`` submodule by ``from dns import resolver, ...``) to a
# Resolver *instance*; code below relies on the instance.
resolver = dns.resolver.Resolver()
# Replace whatever nameservers Resolver() picked up by default with these.
# validate_dnssec() also picks one of these at random as the ``delv`` server.
resolver.nameservers = ["194.50.5.28","194.50.5.27","194.50.5.26"]
resolver.port = 53


def _hostname_matches(pattern: str, hostname: str) -> bool:
    """Return True if the certificate name *pattern* covers *hostname*.

    The comparison is case-insensitive and ignores a trailing root dot.
    A wildcard such as ``*.example.com`` matches exactly one left-most label
    (RFC 6125 section 6.4.3). It matches neither ``example.com`` itself nor
    ``a.b.example.com``.
    """
    pattern = pattern.rstrip(".").lower()
    hostname = hostname.rstrip(".").lower()
    if pattern == hostname:
        return True
    if pattern.startswith("*."):
        suffix = pattern[1:]  # keeps the leading dot, e.g. ".example.com"
        if hostname.endswith(suffix):
            label = hostname[: -len(suffix)]
            return bool(label) and "." not in label
    return False


# One PEM certificate block, from its BEGIN line through its END line.
_PEM_CERT_RE = re.compile(
    r"-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----", re.DOTALL
)


def _extract_pem_certificates(output: bytes):
    """Return the PEM certificates found in ``openssl s_client`` output, in order.

    Each entry runs from ``-----BEGIN CERTIFICATE-----`` to
    ``-----END CERTIFICATE-----`` followed by a newline. Undecodable bytes
    in the surrounding output are replaced rather than raising.
    """
    text = output.decode("utf-8", errors="replace")
    return [block + "\n" for block in _PEM_CERT_RE.findall(text)]


def _tlsa_3_1_1(cert_obj) -> str:
    """Return the ``3 1 1`` TLSA presentation string for *cert_obj*.

    The format is usage 3, selector 1 (SubjectPublicKeyInfo) and matching
    type 1 (SHA-256): the lowercase hex SHA-256 of the DER-encoded public
    key. This is the same value as
    ``openssl x509 -pubkey | openssl pkey -pubin -outform der | openssl dgst -sha256``.
    """
    spki = cert_obj.public_key().public_bytes(
        serialization.Encoding.DER,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return "3 1 1 " + hashlib.sha256(spki).hexdigest()


def check_ssl(domain: str, *, timeout: float = 30.0):
    """Inspect the TLS certificate served for *domain* and its DANE/DNSSEC status.

    Steps:

    1. Resolve *domain*'s A record with the module-level ``resolver``.
    2. Run ``validate_dnssec(domain)``.
    3. Fetch the certificate chain from the first address on port 443 with
       ``openssl s_client`` (SNI set to *domain*).
    4. Compare a ``3 1 1`` TLSA record computed from the first certificate
       with the ``_443._tcp.<domain>`` TLSA records published in DNS.

    Args:
        domain: Hostname to check.
        timeout: Seconds to allow ``openssl s_client`` before giving up.

    Returns:
        A dict that always holds ``success`` and ``valid``. Both stay False
        unless every step completes. Depending on how far the check gets, it
        also holds:

        - ``ip`` and ``other_ips``
        - ``dnssec``: the validate_dnssec() result
        - ``cert``: ``cert`` (PEM text), ``domains``, ``domain`` (name
          matched), ``expired``, ``valid`` and ``expiry_date``
        - ``tlsa``: ``server``, ``nameserver`` and ``match``
        - ``message``: set on an early exit or error

        ``valid`` is True only when the certificate is valid for *domain*
        right now, a published TLSA record matches, and DNSSEC validated.
        Exceptions are caught and reported through ``message``.
    """
    returns = {"success": False, "valid": False}
    try:
        answer = resolver.resolve(domain, "A")
        records = [str(record) for record in answer]
        if not records:
            returns["message"] = "No A record found for " + domain
            return returns

        ip = records[0]
        returns["ip"] = ip
        if len(records) > 1:
            returns["other_ips"] = records[1:]

        returns["dnssec"] = validate_dnssec(domain)

        # stdin receives a newline and is then closed, so s_client does not
        # wait for interactive input. The timeout stops an unresponsive host
        # from hanging the caller; TimeoutExpired is reported below.
        s_client = subprocess.run(
            [
                "openssl", "s_client", "-showcerts",
                "-connect", f"{ip}:443",
                "-servername", domain,
            ],
            input=b"\n",
            capture_output=True,
            timeout=timeout,
        )
        certificates = _extract_pem_certificates(s_client.stdout)
        if not certificates:
            returns["message"] = "No certificate found on remote webserver"
            return returns

        # s_client -showcerts prints the peer (leaf) certificate first.
        cert = certificates[0]
        returns["cert"] = {"cert": cert}
        cert_obj = x509.load_pem_x509_certificate(cert.encode("utf-8"), default_backend())

        tlsa_server = _tlsa_3_1_1(cert_obj)
        returns["tlsa"] = {
            "server": tlsa_server,
            "nameserver": "",
            "match": False,
        }

        # Names covered by the certificate: SAN DNS entries plus the subject CN.
        cert_domains = []
        for ext in cert_obj.extensions:
            if ext.oid == x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
                cert_domains.extend(ext.value.get_values_for_type(x509.DNSName))
        common_names = cert_obj.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
        if common_names and common_names[0].value not in cert_domains:
            cert_domains.append(common_names[0].value)

        domain_check = any(_hostname_matches(cn, domain) for cn in cert_domains)
        returns["cert"]["domains"] = cert_domains
        returns["cert"]["domain"] = domain_check

        now = datetime.datetime.now(datetime.timezone.utc)
        expiry_date = cert_obj.not_valid_after_utc
        expired = expiry_date < now
        not_yet_valid = cert_obj.not_valid_before_utc > now
        returns["cert"]["expired"] = expired
        returns["cert"]["valid"] = domain_check and not expired and not not_yet_valid
        returns["cert"]["expiry_date"] = expiry_date.strftime("%d %B %Y %H:%M:%S")

        try:
            tlsa_answer = resolver.resolve("_443._tcp." + domain, "TLSA")
        except dns.exception.DNSException:
            # NXDOMAIN, NoAnswer, timeouts, ...: nothing to compare against.
            returns["message"] = "No TLSA record found on DNS"
            return returns

        tlsa_records = [str(record) for record in tlsa_answer]
        if not tlsa_records:
            returns["message"] = "No TLSA record found on DNS"
            return returns

        # A name may publish several TLSA records (e.g. during a key rollover).
        # The server matches if any one of them equals the computed record.
        matched = tlsa_server in tlsa_records
        returns["tlsa"]["nameserver"] = tlsa_server if matched else tlsa_records[0]
        returns["tlsa"]["match"] = matched

        returns["valid"] = bool(
            returns["cert"]["valid"] and matched and returns["dnssec"]["valid"]
        )
        returns["success"] = True
        return returns

    # Top-level boundary: report any failure to the caller instead of raising.
    except Exception as e:
        returns["message"] = f"An error occurred: {e}"
        return returns
    

def validate_dnssec(domain, *, timeout: float = 30.0):
    """Report whether ``delv`` fully DNSSEC-validates *domain*'s A record.

    Runs ``delv`` against one of the module resolver's nameservers, chosen
    at random, using trust anchors from the ``hsd-ksk`` file. The lookup
    counts as valid when delv's stdout contains ``"; fully validated"``.

    Args:
        domain: Name whose A record is looked up.
        timeout: Seconds to allow ``delv`` to run.

    Returns:
        A dict with ``valid``, ``message``, ``output`` (delv stdout) and
        ``errors`` (delv stderr). If ``delv`` cannot be started or times out,
        the dict has ``valid=False`` and the reason in ``errors``; no
        exception is raised.
    """
    # Pick one of the configured resolvers at random.
    resolver_ip = random.choice(resolver.nameservers)
    # Equivalent to:
    #   delv @194.50.5.28 -p 53 -a hsd-ksk -q nathan.woodburn -t A +rtrace +vtrace
    # The argv list is passed without a shell, so *domain* is never parsed by
    # a shell. -q/-t stop delv from reading a name such as "ns" or "ch" as a
    # record type or class.
    # NOTE(review): "hsd-ksk" is an anchor-file path relative to the working
    # directory; confirm the process always runs from the expected directory.
    command = [
        "delv",
        f"@{resolver_ip}",
        "-p", str(resolver.port),
        "-a", "hsd-ksk",
        "-q", domain,
        "-t", "A",
        "+rtrace",
        "+vtrace",
    ]
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=timeout)
    except subprocess.TimeoutExpired as err:
        return {"valid": False, "message": "DNSSEC validation timed out", "output": "", "errors": str(err)}
    except OSError as err:
        # e.g. delv is not installed. Keep the original non-raising contract.
        return {"valid": False, "message": "DNSSEC is not valid", "output": "", "errors": str(err)}

    if "; fully validated" in result.stdout:
        return {"valid": True, "message": "DNSSEC is valid", "output": result.stdout, "errors": result.stderr}
    return {"valid": False, "message": "DNSSEC is not valid", "output": result.stdout, "errors": result.stderr}