import binascii
import datetime
import hashlib
import random
import re
import socket
import subprocess
import tempfile
from urllib.parse import urljoin, urlparse

import dns.resolver
import requests
import requests_doh
import urllib3
from bs4 import BeautifulSoup
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from dns import resolver

# Resolver used for the module's DNS lookups (A and TLSA records in
# check_ssl); validate_dnssec also reuses its first nameserver for delv.
# NOTE(review): these nameservers are presumably Handshake-aware (HNSDoH)
# resolvers -- confirm before changing them.
# configure=False: the nameservers are set explicitly, so there is no reason to
# read /etc/resolv.conf (or the Windows registry). Reading it makes module
# import fail with NoResolverConfiguration on hosts that lack one.
resolver = dns.resolver.Resolver(configure=False)
resolver.nameservers = ["194.50.5.28", "194.50.5.27", "194.50.5.26"]
resolver.port = 53
# Register the DNS-over-HTTPS endpoint under the provider name "HNSDoH",
# which proxy() passes to requests_doh.DNSOverHTTPSSession.
requests_doh.add_dns_provider("HNSDoH", "https://hnsdoh.com/dns-query")

# proxy() deliberately fetches with verify=False; silence only the resulting
# InsecureRequestWarning instead of every urllib3 warning.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)


_S_CLIENT_TIMEOUT = 30  # seconds; keeps check_ssl from hanging on an unresponsive host


def _hostname_matches(domain: str, name: str) -> bool:
    """Return True if certificate DNS name *name* covers *domain*.

    Comparison is case-insensitive and ignores a trailing root dot. A wildcard
    name (``*.example.com``) covers exactly one extra leftmost label
    (RFC 6125 section 6.4.3). It matches ``www.example.com`` but neither
    ``example.com`` nor ``a.b.example.com``.
    """
    domain = domain.lower().rstrip(".")
    name = name.lower().rstrip(".")
    if name == domain:
        return True
    if name.startswith("*."):
        label, _, parent = domain.partition(".")
        return bool(label) and "." + parent == name[1:]
    return False


def _fetch_pem_certificates(ip: str, domain: str) -> list:
    """Return the PEM certificates ``ip:443`` presents for SNI *domain*, in order.

    Runs ``openssl s_client -showcerts`` as an argument list (no shell) and
    extracts every BEGIN/END CERTIFICATE block from its stdout. Each returned
    string ends with a newline.
    """
    completed = subprocess.run(
        ["openssl", "s_client", "-showcerts", "-connect", f"{ip}:443", "-servername", domain],
        input=b"\n",  # send a newline and close stdin so s_client does not wait for input
        capture_output=True,
        timeout=_S_CLIENT_TIMEOUT,
    )
    # Diagnostic lines may contain non-UTF-8 bytes; the PEM blocks themselves are ASCII.
    text = completed.stdout.decode("utf-8", errors="replace")
    blocks = re.findall(
        r"-----BEGIN CERTIFICATE-----.*?-----END CERTIFICATE-----", text, re.DOTALL
    )
    return [block + "\n" for block in blocks]


def _certificate_names(cert_obj) -> list:
    """Return the DNS names a certificate is issued for.

    Subject Alternative Name DNS entries come first. The subject Common Name is
    appended when it is not already among them.
    """
    names = []
    for ext in cert_obj.extensions:
        if ext.oid == x509.ExtensionOID.SUBJECT_ALTERNATIVE_NAME:
            names.extend(ext.value.get_values_for_type(x509.DNSName))
    common_name = cert_obj.subject.get_attributes_for_oid(x509.NameOID.COMMON_NAME)
    if common_name and common_name[0].value not in names:
        names.append(common_name[0].value)
    return names


def _tlsa_3_1_1(cert_obj) -> str:
    """Return the TLSA ``3 1 1`` record data for a certificate.

    Usage 3 (DANE-EE), selector 1 (SubjectPublicKeyInfo), matching type 1
    (SHA-256). The value is the lowercase hex SHA-256 of the DER-encoded
    public key, equivalent to piping ``openssl x509 -pubkey`` through
    ``openssl pkey -pubin -outform der`` and ``openssl dgst -sha256``.
    """
    spki = cert_obj.public_key().public_bytes(
        serialization.Encoding.DER,
        serialization.PublicFormat.SubjectPublicKeyInfo,
    )
    return "3 1 1 " + hashlib.sha256(spki).hexdigest()


def check_ssl(domain: str):
    """Inspect *domain*'s TLS certificate, TLSA record and DNSSEC status.

    Resolves the A record with the module-level ``resolver`` and fetches the
    certificate served on port 443 of the first address. Its public key is then
    compared against the ``_443._tcp.<domain>`` TLSA record.

    Returns:
        A dict with these keys:

        - ``success``: the checks ran to completion.
        - ``valid``: True only if the certificate covers *domain* and is
          unexpired, a TLSA record matches, and DNSSEC validated.
        - ``ip`` / ``other_ips``: the resolved A-record addresses.
        - ``dnssec``: the result of validate_dnssec.
        - ``cert`` and ``tlsa``: per-check details.

        On failure ``success`` is False and ``message`` describes the error.
    """
    returns = {"success": False, "valid": False}
    try:
        records = [str(record) for record in resolver.resolve(domain, "A")]
        if not records:
            return {"success": False, "message": "No A record found for " + domain}

        returns["ip"] = records[0]
        if len(records) > 1:
            returns["other_ips"] = records[1:]

        returns["dnssec"] = validate_dnssec(domain)

        certificates = _fetch_pem_certificates(records[0], domain)

        # Stays empty when no certificate was served, so it can never match a TLSA record.
        tlsa_server = ""
        if not certificates:
            returns["cert"] = {
                "cert": "",
                "domains": [],
                "domain": False,
                "expired": False,
                "valid": False,
                "expiry_date": "",
            }
            returns["tlsa"] = {"server": "", "nameserver": "", "match": False}
        else:
            # The first certificate s_client prints is the server's own (leaf).
            cert = certificates[0]
            returns["cert"] = {"cert": cert}
            cert_obj = x509.load_pem_x509_certificate(cert.encode("utf-8"), default_backend())

            tlsa_server = _tlsa_3_1_1(cert_obj)
            returns["tlsa"] = {"server": tlsa_server, "nameserver": "", "match": False}

            cert_domains = _certificate_names(cert_obj)
            domain_check = any(_hostname_matches(domain, name) for name in cert_domains)
            expiry_date = cert_obj.not_valid_after_utc
            expired = expiry_date < datetime.datetime.now(datetime.timezone.utc)

            returns["cert"]["domains"] = cert_domains
            returns["cert"]["domain"] = domain_check
            returns["cert"]["expired"] = expired
            returns["cert"]["valid"] = domain_check and not expired
            returns["cert"]["expiry_date"] = expiry_date.strftime("%d %B %Y %H:%M:%S")

        try:
            response = resolver.resolve("_443._tcp." + domain, "TLSA")
        except dns.exception.DNSException:
            returns["tlsa"]["error"] = "No TLSA record found on DNS"
        else:
            tlsa_records = [str(record) for record in response]
            if tlsa_records:
                returns["tlsa"]["nameserver"] = tlsa_records[0]
                # Any published record may match (e.g. during a key rollover).
                returns["tlsa"]["match"] = bool(tlsa_server) and tlsa_server in tlsa_records

        returns["valid"] = bool(
            returns["cert"]["valid"]
            and returns["tlsa"]["match"]
            and returns["dnssec"]["valid"]
        )
        returns["success"] = True
        return returns

    except Exception as e:
        # Boundary of a user-facing check: report any failure rather than raising.
        returns["message"] = f"An error occurred: {e}"
        return returns
    

def validate_dnssec(domain: str):
    """Check whether *domain*'s A record is DNSSEC-validated, using ``delv``.

    Queries the first nameserver of the module-level ``resolver`` with
    ``+rtrace +vtrace``. The trust anchor comes from ``-a hsd-ksk``, delv's
    anchor-file option; the path is relative to the working directory.

    Returns:
        ``{"valid": bool, "message": str, "output": str}``. ``output`` is
        delv's stderr followed by its stdout, or the reason delv was not run.
    """
    if not domain or domain[0] in "-+":
        # delv parses a leading '-' or '+' as an option, not as a query name.
        return {
            "valid": False,
            "message": "DNSSEC is not valid",
            "output": f"Refusing to query invalid name {domain!r}",
        }

    # Uses the first configured nameserver (deterministic, not random).
    resolver_ip = resolver.nameservers[0]
    # Equivalent to: delv @194.50.5.28 -a hsd-ksk <domain> A +rtrace +vtrace
    # Passed as an argument list (no shell) so the caller-supplied domain
    # cannot inject shell syntax.
    command = ["delv", f"@{resolver_ip}", "-a", "hsd-ksk", domain, "A", "+rtrace", "+vtrace"]
    try:
        result = subprocess.run(command, capture_output=True, text=True)
    except OSError as e:
        # e.g. delv is not installed: report "not validated" rather than raising.
        return {"valid": False, "message": "DNSSEC is not valid", "output": str(e)}

    output = result.stderr + result.stdout
    if "; fully validated" in result.stdout or "; negative response, fully validated" in result.stdout:
        return {"valid": True, "message": "DNSSEC is valid", "output": output}
    return {"valid": False, "message": "DNSSEC is not valid", "output": output}
    


def curl(url: str):
    """Fetch *url* with the ``curl`` CLI, resolving hostnames through HNSDoH.

    A URL without an ``http`` prefix is fetched over plain ``http://``.
    Certificate verification is disabled (``--insecure``) and the call is
    limited to 10 seconds. Undecodable output bytes are replaced.

    Returns:
        One of these dicts:

        - ``{"success": True, "result": body}`` on success.
        - ``{"success": False, "error": stderr}`` when curl exits non-zero.
        - ``{"success": False, "error": "An error occurred", "message": ...}``
          when curl could not be run or timed out.
    """
    if not url.startswith("http"):
        url = "http://" + url
    # Argument list, no shell: *url* is caller-supplied and must not be able to
    # inject shell syntax. It always starts with "http", so curl never parses it
    # as an option. --show-error keeps the failure reason in stderr, which
    # --silent alone would suppress.
    command = [
        "curl",
        "--doh-url", "https://hnsdoh.com/dns-query",
        url,
        "--insecure",
        "--silent",
        "--show-error",
    ]
    try:
        response = subprocess.run(
            command, capture_output=True, text=True, errors="replace", timeout=10
        )
    except Exception as e:
        return {"success": False, "error": "An error occurred", "message": str(e)}

    if response.returncode != 0:
        return {"success": False, "error": response.stderr}
    return {"success": True, "result": response.stdout}


class ProxyError(Exception):
    """Error returned by proxy() in place of a ``requests.Response``.

    It carries the attributes proxy()'s callers read from a response:
    ``ok`` is False, ``status_code`` is 500, and ``text`` holds the error
    message. A failed fetch can therefore be handled like a server-error
    response.
    """

    def __init__(self, message):
        super().__init__(message)
        self.message = self.text = message
        self.ok, self.status_code = False, 500


def proxy(url: str) -> "requests.Response | ProxyError":
    """GET *url*, resolving hostnames through the "HNSDoH" DoH provider.

    TLS verification is disabled (``verify=False``) and the request times out
    after 30 seconds.

    Returns:
        The ``requests.Response``. On any failure it returns (does not raise)
        a ProxyError, which mimics a failed response (``ok`` False,
        ``status_code`` 500).
    """
    try:
        # Close the session (and its connection pool) once the body is read;
        # the default non-streaming get() has already loaded the content.
        with requests_doh.DNSOverHTTPSSession("HNSDoH") as session:
            return session.get(url, verify=False, timeout=30)
    except Exception as e:
        return ProxyError(str(e))

def cleanProxyContent(htmlContent: str, url: str, proxyHost: str):
    """Rewrite links in *htmlContent* so the origin's resources load via the proxy.

    Rewrites ``src``/``href`` attributes of ``<link>``, ``<img>``, ``<script>``
    and ``<a>`` tags that point at the proxied origin to
    ``{proxyHost}proxy/{scheme}://{netloc}/...``. Links to other hosts and
    non-navigational values (``data:``, ``mailto:``, ``#fragment``, ...) are
    left alone.

    Args:
        htmlContent: HTML of the page fetched from *url*.
        url: Absolute URL of that page; relative links are resolved against it.
        proxyHost: Base URL of this proxy, expected to end in "/". This is
            presumably the request's host URL -- confirm with the caller.

    Returns:
        The rewritten document as prettified HTML.
    """
    # Serve proxied links over HTTPS except on the local dev server (port 5000).
    # Only the scheme prefix is upgraded. A blanket replace("http", "https")
    # turns "https://" into "httpss://" and rewrites hostnames containing "http".
    if ":5000" not in proxyHost and proxyHost.startswith("http://"):
        proxyHost = "https://" + proxyHost[len("http://"):]

    parsed = urlparse(url)
    hostUrl = f"{parsed.scheme}://{parsed.netloc}"
    proxyUrl = f"{proxyHost}proxy/{hostUrl}"

    # Values that must not be proxied (matched case-insensitively).
    unproxied_prefixes = ("data:", "mailto:", "tel:", "javascript:", "blob:", "#")

    soup = BeautifulSoup(htmlContent, 'html.parser')

    for link in soup.find_all(['link', 'img', 'script', 'a']):
        for attrib in ('src', 'href'):
            if not link.has_attr(attrib):
                continue
            value = str(link[attrib])
            if value.startswith('/') and not value.startswith('//'):
                # Root-relative path on the proxied origin.
                link.attrs[attrib] = proxyUrl + value
            elif value.startswith(('http://', 'https://')):
                # Absolute URL: only links to the proxied origin are rerouted.
                link.attrs[attrib] = value.replace(hostUrl, proxyUrl)
            elif not value.lower().startswith(unproxied_prefixes):
                # Page-relative ("img/a.png", "../x", "?q=1") or protocol-relative
                # ("//cdn.example/x") link. Resolve it against the page URL first
                # so the page's directory is kept, then reroute it if it lands
                # on the proxied origin.
                link.attrs[attrib] = urljoin(url, value).replace(hostUrl, proxyUrl)

    # NOTE(review): has_attr() tests for HTML *attributes* named "text" or
    # "contents", which <script> tags do not normally carry. In practice inline
    # scripts are therefore left unmodified here. Enabling this would also
    # route inline JS through proxyCleanJS's blanket "location" substitutions,
    # so review that function first.
    for script in soup.find_all('script'):
        if script.has_attr("text"):
            script.attrs["text"] = proxyCleanJS(script.text, url, proxyHost)
            continue
        if script.has_attr("contents") and script.contents:
            newScript = soup.new_tag("script")
            for content in script.contents:
                newScript.append(proxyCleanJS(content, url, proxyHost))
            script.replace_with(newScript)

    return soup.prettify()

def proxyCleanJS(jsContent: str, url: str, proxyHost: str):
    """Rewrite JavaScript fetched from *url* so it refers to the proxy.

    When *url* contains "dprofile", only two targeted substitutions are made:

    - ``window.location.hostname`` becomes the quoted origin host.
    - ``src="img`` becomes a proxied ``img`` path.

    Otherwise, every occurrence of the origin (``scheme://netloc``) is replaced
    by the proxied origin ``{proxyHost}proxy/{scheme}://{netloc}``. The tokens
    ``window.location.href``, ``window.location``, ``location.href`` and
    ``location`` are replaced by that same URL.

    Args:
        jsContent: JavaScript source text.
        url: Absolute URL the script (or its page) was fetched from.
        proxyHost: Base URL of this proxy, expected to end in "/".

    Returns:
        The rewritten JavaScript source.
    """
    # Serve proxied links over HTTPS except on the local dev server (port 5000).
    # Only the scheme prefix is upgraded. A blanket replace("http", "https")
    # turns "https://" into "httpss://" and rewrites hostnames containing "http".
    if ":5000" not in proxyHost and proxyHost.startswith("http://"):
        proxyHost = "https://" + proxyHost[len("http://"):]

    parsed = urlparse(url)
    hostUrl = f"{parsed.scheme}://{parsed.netloc}"
    proxyUrl = f"{proxyHost}proxy/{hostUrl}"

    if "dprofile" in url:
        jsContent = jsContent.replace("window.location.hostname", f"\"{parsed.netloc}\"")
        jsContent = jsContent.replace("src=\"img", f"src=\"{proxyUrl}/img")
        return jsContent

    jsContent = jsContent.replace(hostUrl, proxyUrl)
    # NOTE(review): these substitutions insert an *unquoted* URL and also match
    # inside longer identifiers (e.g. "allocation"), so the result is often
    # invalid JavaScript. Kept as-is pending a real JS rewrite.
    for locator in ["window.location.href", "window.location", "location.href", "location"]:
        jsContent = jsContent.replace(locator, proxyUrl)

    return jsContent
    


# if __name__ == "__main__":
#     print(curl("https://dso.dprofile"))