123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406 |
- import json
- import os
- import random
- import re
- import string
- import time
- import warnings
- from datetime import datetime
- from json import JSONDecodeError
- from typing import Optional, Tuple, Iterable
- import dns
- import dns.name
- import dns.query
- import dns.rdtypes.svcbbase
- import dns.zone
- import pytest
- import requests
- from requests.exceptions import SSLError
- from urllib3.exceptions import InsecureRequestWarning
def tsprint(s, *args, **kwargs):
    """``print`` *s* prefixed by the current local time formatted as ``%d-%b (%H:%M:%S)``.

    Any further positional/keyword arguments are forwarded to ``print`` unchanged.
    """
    timestamp = datetime.now().strftime('%d-%b (%H:%M:%S)')
    message = f"{timestamp} {s}"
    print(message, *args, **kwargs)
- def _strip_quotes_decorator(func):
- return lambda *args, **kwargs: func(*args, **kwargs)[1:-1]
# Ensure that dnspython agrees with pdns' expectations for SVCB / HTTPS parameters: the to_text() output of the
# parameter classes below is passed through _strip_quotes_decorator.
# WARNING: This is a global side-effect. It can't be done by extending a class, because dnspython hardcodes the use of
# their dns.rdtypes.svcbbase.*Param classes in the global dns.rdtypes.svcbbase._class_for_key dictionary. We either have
# to globally mess with that dict and insert our custom class, or we just mess with their classes directly (done here).
for _param_class in (
    dns.rdtypes.svcbbase.ALPNParam,
    dns.rdtypes.svcbbase.IPv4HintParam,
    dns.rdtypes.svcbbase.IPv6HintParam,
    dns.rdtypes.svcbbase.MandatoryParam,
    dns.rdtypes.svcbbase.PortParam,
):
    _param_class.to_text = _strip_quotes_decorator(_param_class.to_text)
del _param_class  # keep the loop variable out of the module namespace
def random_mixed_case_string(n):
    """Return a random string of *n* ASCII letters with at least one lowercase and one uppercase letter.

    *n* must be at least 2 (``random.randint(1, n - 1)`` needs a non-empty range).
    """
    n_lower = random.randint(1, n - 1)
    letters = [
        *random.choices(string.ascii_lowercase, k=n_lower),
        *random.choices(string.ascii_uppercase, k=n - n_lower),
    ]
    random.shuffle(letters)
    return ''.join(letters)
def random_email() -> str:
    """Return a random address ``<10 mixed-case letters>@<10 mixed-case letters>.desec.test``."""
    local_part, domain_label = random_mixed_case_string(10), random_mixed_case_string(10)
    return f'{local_part}@{domain_label}.desec.test'
def random_password() -> str:
    """Return a 16-character password of random ASCII letters (``random``-based, test use only)."""
    letters = [random.choice(string.ascii_letters) for _ in range(16)]
    return "".join(letters)
def random_domainname() -> str:
    """Return a random name made of 16 lowercase ASCII letters followed by ``.test``."""
    label = "".join(random.choice(string.ascii_lowercase) for _ in range(16))
    return f"{label}.test"
def random_local_public_suffix_domainname() -> str:
    """Return ``<16 random lowercase letters>.dedyn.<$DESECSTACK_DOMAIN>``.

    Raises KeyError if the DESECSTACK_DOMAIN environment variable is unset.
    """
    label = "".join(random.choice(string.ascii_lowercase) for _ in range(16))
    return f"{label}.dedyn.{os.environ['DESECSTACK_DOMAIN']}"
class DeSECAPIV1Client:
    """HTTP client for the deSEC API v1 as used by the end-to-end tests.

    Every request and its (shortened) response are logged via ``tsprint``. Domains created through
    :meth:`domain_create` are tracked in ``self.domains`` (name -> decoded JSON of the creation response).
    """

    base_url = "https://desec." + os.environ["DESECSTACK_DOMAIN"] + "/api/v1"
    # Default request headers. __init__ gives every instance its own copy.
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json",
        "User-Agent": "e2e2",
    }

    def __init__(self) -> None:
        super().__init__()
        # Per-instance copy: login() writes an Authorization header into self.headers. Mutating the
        # class-level dict instead would leak the token into every other client (e.g. anonymous ones).
        self.headers = dict(self.headers)
        self.email = None
        self.password = None
        self.domains = {}

        # We support two certificate verification methods
        # (1) against self-signed certificates, if /autocert path is present
        #     (this is usually the case when run inside a docker container)
        # (2) against the default certificate store, if /autocert is not available
        #     (this is usually the case when run outside a docker container)
        self.verify = True
        self.verify_alt = [
            f'/autocert/desec.{os.environ["DESECSTACK_DOMAIN"]}.cer',
            f'/autocert/get.desec.{os.environ["DESECSTACK_DOMAIN"]}.cer',
        ]

    @staticmethod
    def _filter_response_output(output: dict) -> dict:
        """Shorten a ``challenge`` value to its first 10 characters + '...' for logging.

        Mutates and returns *output*; if it has no sliceable ``challenge`` (or is not a dict) it is returned as-is.
        """
        try:
            output['challenge'] = output['challenge'][:10] + '...'
        except (KeyError, TypeError):
            pass
        return output

    @property
    def domain(self):
        """Name of the first tracked domain (in insertion order), or None if no domain is tracked."""
        return next(iter(self.domains), None)

    def _do_request(self, *args, **kwargs):
        """Call ``requests.request`` and try each TLS verification option until one does not raise SSLError.

        The option that worked is remembered and tried first next time. While faketime is active
        (``faketime_get() != '+0d'``), verification is disabled for this call only and the remembered
        preferences are left untouched, so verification is back on once faketime is reset.

        Raises the last SSLError if no option works.
        """
        # do not verify SSL if we're in faketime (cert will be expired!?)
        faked = faketime_get() != '+0d'
        verify_list = [False] if faked else [self.verify] + self.verify_alt
        exc = None
        for verify in verify_list:
            try:
                with warnings.catch_warnings():
                    if verify is False:
                        # Suppress insecurity warning if we do not want to verify
                        warnings.filterwarnings('ignore', category=InsecureRequestWarning)
                    reply = requests.request(*args, **kwargs, verify=verify)
            except SSLError as e:
                tsprint(f'API <<< SSL could not verify against "{verify}"')
                exc = e
            else:
                if not faked:
                    # note verification preference for next time
                    remaining = list(verify_list)
                    remaining.remove(verify)
                    self.verify = verify
                    self.verify_alt = remaining
                return reply
        tsprint('API <<< SSL could not be verified against any verification method')
        raise exc

    def _request(self, method: str, *, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
        """Send an API request and log request and response.

        *path* is appended to ``base_url`` unless it already is an absolute http(s) URL. *data*, if not None,
        is JSON-encoded and sent as the body. Remaining keyword arguments are passed to ``requests.request``.
        """
        payload_type = type(data)  # captured before json.dumps turns data into a str
        if data is not None:
            data = json.dumps(data)
        url = self.base_url + path if re.match(r'^https?://', path) is None else path
        tsprint(f"API >>> {method} {url}")
        if data:
            tsprint(f"API >>> {payload_type}: {self._shorten(data)}")
        response = self._do_request(
            method,
            url,
            data=data,
            headers=self.headers,
            **kwargs,
        )
        tsprint(f"API <<< {response.status_code}")
        if response.text:
            try:
                tsprint(f"API <<< {self._shorten(str(self._filter_response_output(response.json())))}")
            except ValueError:  # json's and requests' JSONDecodeError variants are all ValueError subclasses
                tsprint(f"API <<< {response.text}")
        return response

    @staticmethod
    def _shorten(s: str):
        """Return *s* if shorter than 200 characters, else its first and last 50 characters joined by '...'."""
        if len(s) < 200:
            return s
        return s[:50] + '...' + s[-50:]

    def get(self, path: str, **kwargs) -> requests.Response:
        return self._request("GET", path=path, **kwargs)

    def post(self, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
        return self._request("POST", path=path, data=data, **kwargs)

    def patch(self, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
        return self._request("PATCH", path=path, data=data, **kwargs)

    def delete(self, path: str, **kwargs) -> requests.Response:
        return self._request("DELETE", path=path, **kwargs)

    def register(self, email: str, password: str) -> Tuple[requests.Response, requests.Response]:
        """Fetch a captcha, then register *email*/*password*; returns (captcha response, registration response)."""
        self.email = email
        self.password = password
        captcha = self.post("/captcha/")
        challenge = captcha.json()
        return captcha, self.post(
            "/auth/",
            data={
                "email": email,
                "password": password,
                "captcha": {
                    "id": challenge["id"],
                    "solution": challenge["content"],  # available via e2e configuration magic
                },
            },
        )

    def login(self, email: str, password: str) -> requests.Response:
        """Log in; if the response carries a token, use it for this client's further requests.

        The token is then PATCHed with ``max_unused_period``/``max_age`` set to None (to make it last forever).
        Returns the login response.
        """
        response = self.post("/auth/login/", data={"email": email, "password": password})
        body = response.json()
        token = body.get('token')
        if token is not None:
            self.headers["Authorization"] = f'Token {token}'
            self.patch(  # make token last forever
                f"/auth/tokens/{body.get('id')}/",
                data={'max_unused_period': None, 'max_age': None},
            )
        return response

    def domain_list(self) -> list:
        """Return the decoded JSON body of ``GET /domains/`` (not the Response object)."""
        return self.get("/domains/").json()

    def domain_create(self, name) -> requests.Response:
        """Create domain *name* and track it with the decoded response JSON.

        Raises ValueError if *name* is already tracked. The domain is tracked even if the API rejected it.
        """
        if name in self.domains:
            raise ValueError
        response = self.post("/domains/", data={"name": name})
        self.domains[name] = response.json()
        return response

    def domain_destroy(self, name) -> requests.Response:
        """Delete tracked domain *name* and stop tracking it; raises ValueError if it is not tracked."""
        if name not in self.domains:
            raise ValueError
        response = self.delete(f"/domains/{name}/")
        self.domains.pop(name)
        return response

    def rr_set_create(self, domain_name: str, rr_type: str, records: Iterable[str], subname: str = '',
                      ttl: int = 3600) -> requests.Response:
        """POST a new RRset. *records* is sent as-is through json.dumps, so it must be JSON-serializable (e.g. a list)."""
        return self.post(
            f"/domains/{domain_name}/rrsets/",
            data={
                "subname": subname,
                "type": rr_type,
                "ttl": ttl,
                "records": records,
            },
        )

    def rr_set_create_bulk(self, domain_name: str, data: list) -> requests.Response:
        """PATCH the domain's RRset collection with a list of RRset objects."""
        return self.patch(f"/domains/{domain_name}/rrsets/", data=data)

    def rr_set_delete(self, domain_name: str, rr_type: str, subname: str = '') -> requests.Response:
        """DELETE the RRset ``<subname>.../<rr_type>``; the '...' form also addresses an empty subname."""
        return self.delete(f"/domains/{domain_name}/rrsets/{subname}.../{rr_type}/")

    def get_key_params(self, domain_name: str, rr_type: str) -> set:
        """Expected record texts of *rr_type* ('CDNSKEY', 'DNSKEY' or 'CDS') for a tracked domain.

        Taken from the ``keys`` list of the domain's tracked JSON: each key's ``dnskey`` for (C)DNSKEY, every
        ``ds`` entry for CDS. Each value is split into four fields with spaces removed inside a field; for
        (C)DNSKEY the last field is regrouped into space-separated 32-character chunks.

        Raises KeyError if the domain is not tracked, ValueError for any other *rr_type*.
        """
        keys = self.domains[domain_name]['keys']
        if rr_type in ('CDNSKEY', 'DNSKEY'):
            params = {key['dnskey'] for key in keys}
        elif rr_type == 'CDS':
            params = {ds for key in keys for ds in key['ds']}
        else:
            raise ValueError

        # Split into four fields and remove additional spaces
        fields = [[part.replace(' ', '') for part in param.split(' ', 3)] for param in params]

        # For (C)DNSKEY, add spaces every 32 characters
        if rr_type in ('CDNSKEY', 'DNSKEY'):
            fields = [[a, b, c, ' '.join(d[i:i + 32] for i in range(0, len(d), 32))] for a, b, c, d in fields]

        # Join again
        return {' '.join(param) for param in fields}
@pytest.fixture
def api_anon() -> DeSECAPIV1Client:
    """
    Provide a freshly constructed API client that has not registered or logged in (anonymous access).
    """
    client = DeSECAPIV1Client()
    return client
@pytest.fixture()
def api_user() -> DeSECAPIV1Client:
    """
    Provide an API client for a freshly registered account with a random email address and password.

    The client is logged in, so its Authorization header carries the account's token. The fixture
    creates no domains.
    """
    api = DeSECAPIV1Client()
    credentials = (random_email(), random_password())
    api.register(*credentials)
    api.login(*credentials)
    return api
@pytest.fixture()
def api_user_domain(api_user) -> DeSECAPIV1Client:
    """
    Extend the ``api_user`` fixture with one newly created domain with a random ``.test`` name.

    Apart from that, nothing is created, so the domain holds only its default records.
    """
    name = random_domainname()
    api_user.domain_create(name)
    return api_user
class NSClient:
    """Minimal DNS-over-TCP query helper; subclasses set ``where`` to the server address to query."""

    where = None

    @classmethod
    def query(cls, qname: str, qtype: str):
        """Ask ``cls.where`` (TCP, 2 s timeout) for *qname*/*qtype* in class IN.

        For NS queries the AUTHORITY section is inspected, otherwise the ANSWER section. Returns the set of
        record texts; if the RRset is missing, the full answer is logged and ``{}`` (an empty dict, not a
        set) is returned.
        """
        tsprint(f'DNS >>> {qname}/{qtype} @{cls.where}')
        name = dns.name.from_text(qname)
        rdtype = dns.rdatatype.from_text(qtype)
        message = dns.query.tcp(
            q=dns.message.make_query(name, rdtype),
            where=cls.where,
            timeout=2,
        )
        if rdtype == dns.rdatatype.from_text('NS'):
            section = dns.message.AUTHORITY
        else:
            section = dns.message.ANSWER
        try:
            rrset = message.find_rrset(section, name, dns.rdataclass.IN, rdtype)
        except KeyError:
            tsprint('DNS <<< !!! not found !!! Complete Answer below:\n' + message.to_text())
            return {}
        tsprint(f'DNS <<< {rrset}')
        return {rdata.to_text() for rdata in rrset.items}
class NSLordClient(NSClient):
    """NSClient that queries host ``.0.129`` in the network given by $DESECSTACK_IPV4_REAR_PREFIX16."""

    where = f'{os.environ["DESECSTACK_IPV4_REAR_PREFIX16"]}.0.129'
def query_replication(zone: str, qname: str, qtype: str, covers: str = None):
    """Look up an RRset in the zone file ``/zones/<zone>.zone``.

    *qname* is interpreted relative to *zone*. *covers* must be given exactly when *qtype* is 'RRSIG'
    (e.g. 'SOA'), and must be falsy otherwise (checked with assert).

    Returns the set of record texts; ``{}`` (an empty dict) if the zone file has no such RRset; ``None`` if
    the zone file does not exist or has no SOA.
    """
    if qtype != 'RRSIG':
        assert not covers
        covers = dns.rdatatype.NONE
    else:
        assert covers, 'If querying RRSIG, covers parameter must be set to a RR type, e.g. SOA.'

    zonefile = os.path.join('/zones', zone + '.zone')
    origin = dns.name.from_text(zone, origin=dns.name.root)
    name = dns.name.from_text(qname, origin=origin)

    if not os.path.exists(zonefile):
        zone_count = sum(1 for entry in os.listdir("/zones") if entry.endswith(".zone"))
        tsprint(f'RPL <<< Zone file for {origin} not found (number of zones: {zone_count})')
        return None

    tsprint(f'RPL >>> {name}/{qtype} in {origin}')
    try:
        parsed = dns.zone.from_file(f=zonefile, origin=origin, relativize=False)
        records = {rdata.to_text() for rdata in parsed.find_rrset(name, qtype, covers=covers).items}
    except KeyError:
        tsprint(f'RPL <<< RR Set {name}/{qtype} not found')
        return {}
    except dns.zone.NoSOA:
        tsprint(f'RPL <<< Zone {origin} not found')
        return None
    tsprint(f'RPL <<< {records}')
    return records
def return_eventually(expression: callable, min_pause: float = .1, max_pause: float = 2, timeout: float = 5,
                      retry_on: Tuple[type, ...] = (Exception,)):
    """Call ``expression()`` until it does not raise one of *retry_on*, and return its result.

    Between attempts it sleeps, starting at *min_pause* seconds and doubling up to *max_pause*. A failure that
    occurs more than *timeout* seconds after the first attempt is logged and re-raised. Exceptions not listed
    in *retry_on* propagate immediately.

    Raises ValueError if *expression* is not callable.
    """
    if not callable(expression):
        raise ValueError('Expression given not callable. Did you forget "lambda:"?')
    wait = min_pause
    # Monotonic clock: the elapsed-time check must not be distorted by changes to the wall clock.
    started = time.monotonic()
    while True:
        try:
            return expression()
        except retry_on as e:
            if time.monotonic() - started > timeout:
                # Not every callable has __code__ (e.g. functools.partial, builtins); don't mask the real error.
                tsprint(f'{getattr(expression, "__code__", expression)} failed with {e}, no more retries')
                raise
            time.sleep(wait)
            wait = min(2 * wait, max_pause)
def assert_eventually(assertion: callable, min_pause: float = .1, max_pause: float = 2, timeout: float = 5) -> None:
    """Retry *assertion* via ``return_eventually`` until it returns a truthy value.

    Only AssertionError triggers a retry. Pauses and timeout are forwarded unchanged; once the timeout has
    passed, the next AssertionError propagates to the caller.
    """
    def _assert():
        assert assertion()

    return_eventually(
        _assert,
        min_pause=min_pause,
        max_pause=max_pause,
        timeout=timeout,
        retry_on=(AssertionError,),
    )
def faketime(t: str):
    """Echo *t* to stdout, then write it as the only line of /etc/faketime/faketime.rc.

    NOTE(review): presumably this file is read by libfaketime in the stack's containers; not visible here.
    """
    print('FAKETIME', t)
    rc_path = '/etc/faketime/faketime.rc'
    with open(rc_path, 'w') as rc_file:
        rc_file.write(t + '\n')
def faketime_get():
    """Return the stripped first line of /etc/faketime/faketime.rc, or '+0d' if that file does not exist."""
    rc_path = '/etc/faketime/faketime.rc'
    try:
        with open(rc_path) as rc_file:
            first_line = rc_file.readline()
    except FileNotFoundError:
        return '+0d'
    return first_line.strip()
def faketime_add(days: int):
    """Advance the faketime setting by *days* (>= 0) days.

    The current setting (via ``faketime_get``) must have the relative form ``+<N>d``; the new value
    ``+<N + days>d`` is written via ``faketime``.
    """
    assert days >= 0
    current_faketime = faketime_get()
    # startswith/endswith instead of indexing: an empty setting fails the assertion instead of raising IndexError
    assert current_faketime.startswith('+') and current_faketime.endswith('d'), current_faketime
    current_days = int(current_faketime[1:-1])
    # ':d' rather than the locale-aware ':n', which may insert digit-grouping separators (e.g. '1.000')
    faketime(f'+{current_days + days:d}d')
|