conftest.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. import json
  2. import os
  3. import random
  4. import re
  5. import string
  6. import time
  7. import warnings
  8. from datetime import datetime
  9. from json import JSONDecodeError
  10. from typing import Optional, Tuple, Iterable
  11. import dns
  12. import dns.name
  13. import dns.query
  14. import dns.rdtypes.svcbbase
  15. import dns.zone
  16. import pytest
  17. import requests
  18. from requests.exceptions import SSLError
  19. from urllib3.exceptions import InsecureRequestWarning
  20. def tsprint(s, *args, **kwargs):
  21. print(f"{datetime.now().strftime('%d-%b (%H:%M:%S)')} {s}", *args, **kwargs)
  22. def _strip_quotes_decorator(func):
  23. return lambda *args, **kwargs: func(*args, **kwargs)[1:-1]
  24. # Ensure that dnspython agrees with pdns' expectations for SVCB / HTTPS parameters.
  25. # WARNING: This is a global side-effect. It can't be done by extending a class, because dnspython hardcodes the use of
  26. # their dns.rdtypes.svcbbase.*Param classes in the global dns.rdtypes.svcbbase._class_for_key dictionary. We either have
  27. # to globally mess with that dict and insert our custom class, or we just mess with their classes directly.
  28. dns.rdtypes.svcbbase.ALPNParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.ALPNParam.to_text)
  29. dns.rdtypes.svcbbase.IPv4HintParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.IPv4HintParam.to_text)
  30. dns.rdtypes.svcbbase.IPv6HintParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.IPv6HintParam.to_text)
  31. dns.rdtypes.svcbbase.MandatoryParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.MandatoryParam.to_text)
  32. dns.rdtypes.svcbbase.PortParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.PortParam.to_text)
  33. def random_mixed_case_string(n):
  34. k = random.randint(1, n-1)
  35. s = random.choices(string.ascii_lowercase, k=k) + random.choices(string.ascii_uppercase, k=n-k)
  36. random.shuffle(s)
  37. return ''.join(s)
  38. def random_email() -> str:
  39. return f'{random_mixed_case_string(10)}@{random_mixed_case_string(10)}.desec.test'
  40. def random_password() -> str:
  41. return "".join(random.choice(string.ascii_letters) for _ in range(16))
  42. def random_domainname() -> str:
  43. return (
  44. "".join(random.choice(string.ascii_lowercase) for _ in range(16))
  45. + ".test"
  46. )
  47. def random_local_public_suffix_domainname() -> str:
  48. return (
  49. "".join(random.choice(string.ascii_lowercase) for _ in range(16))
  50. + ".dedyn."
  51. + os.environ['DESECSTACK_DOMAIN']
  52. )
  53. class DeSECAPIV1Client:
  54. base_url = "https://desec." + os.environ["DESECSTACK_DOMAIN"] + "/api/v1"
  55. headers = {
  56. "Accept": "application/json",
  57. "Content-Type": "application/json",
  58. "User-Agent": "e2e2",
  59. }
  60. def __init__(self) -> None:
  61. super().__init__()
  62. self.email = None
  63. self.password = None
  64. self.domains = {}
  65. # We support two certificate verification methods
  66. # (1) against self-signed certificates, if /autocert path is present
  67. # (this is usually the case when run inside a docker container)
  68. # (2) against the default certificate store, if /autocert is not available
  69. # (this is usually the case when run outside a docker container)
  70. self.verify = True
  71. self.verify_alt = [
  72. f'/autocert/desec.{os.environ["DESECSTACK_DOMAIN"]}.cer',
  73. f'/autocert/get.desec.{os.environ["DESECSTACK_DOMAIN"]}.cer',
  74. ]
  75. @staticmethod
  76. def _filter_response_output(output: dict) -> dict:
  77. try:
  78. output['challenge'] = output['challenge'][:10] + '...'
  79. except (KeyError, TypeError):
  80. pass
  81. return output
  82. @property
  83. def domain(self):
  84. try:
  85. return next(iter(self.domains))
  86. except StopIteration:
  87. return None
  88. def _do_request(self, *args, **kwargs):
  89. verify_list = [self.verify] + self.verify_alt
  90. # do not verify SSL if we're in faketime (cert will be expired!?)
  91. if faketime_get() != '+0d':
  92. verify_list = [False]
  93. exc = None
  94. for verify in verify_list:
  95. try:
  96. with warnings.catch_warnings():
  97. if verify_list == [False]:
  98. # Supress insecurity warning if we do not want to verify
  99. warnings.filterwarnings('ignore', category=InsecureRequestWarning)
  100. reply = requests.request(*args, **kwargs, verify=verify)
  101. except SSLError as e:
  102. tsprint(f'API <<< SSL could not verify against "{verify}"')
  103. exc = e
  104. else:
  105. # note verification preference for next time
  106. self.verify = verify
  107. self.verify_alt = verify_list
  108. self.verify_alt.remove(self.verify)
  109. return reply
  110. tsprint(f'API <<< SSL could not be verified against any verification method')
  111. raise exc
  112. def _request(self, method: str, *, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
  113. if data is not None:
  114. data = json.dumps(data)
  115. url = self.base_url + path if re.match(r'^https?://', path) is None else path
  116. tsprint(f"API >>> {method} {url}")
  117. if data:
  118. tsprint(f"API >>> {type(data)}: {self._shorten(data)}")
  119. response = self._do_request(
  120. method,
  121. url,
  122. data=data,
  123. headers=self.headers,
  124. **kwargs,
  125. )
  126. tsprint(f"API <<< {response.status_code}")
  127. if response.text:
  128. try:
  129. tsprint(f"API <<< {self._shorten(str(self._filter_response_output(response.json())))}")
  130. except JSONDecodeError:
  131. tsprint(f"API <<< {response.text}")
  132. return response
  133. @staticmethod
  134. def _shorten(s: str):
  135. if len(s) < 200:
  136. return s
  137. else:
  138. return s[:50] + '...' + s[-50:]
  139. def get(self, path: str, **kwargs) -> requests.Response:
  140. return self._request("GET", path=path, **kwargs)
  141. def post(self, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
  142. return self._request("POST", path=path, data=data, **kwargs)
  143. def patch(self, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
  144. return self._request("PATCH", path=path, data=data, **kwargs)
  145. def delete(self, path: str, **kwargs) -> requests.Response:
  146. return self._request("DELETE", path=path, **kwargs)
  147. def register(self, email: str, password: str) -> Tuple[requests.Response, requests.Response]:
  148. self.email = email
  149. self.password = password
  150. captcha = self.post("/captcha/")
  151. return captcha, self.post(
  152. "/auth/",
  153. data={
  154. "email": email,
  155. "password": password,
  156. "captcha": {
  157. "id": captcha.json()["id"],
  158. "solution": captcha.json()[
  159. "content"
  160. ], # available via e2e configuration magic
  161. },
  162. },
  163. )
  164. def login(self, email: str, password: str) -> requests.Response:
  165. response = self.post(
  166. "/auth/login/", data={"email": email, "password": password}
  167. )
  168. token = response.json().get('token')
  169. if token is not None:
  170. self.headers["Authorization"] = f'Token {response.json()["token"]}'
  171. self.patch( # make token last forever
  172. f"/auth/tokens/{response.json().get('id')}/",
  173. data={'max_unused_period': None, 'max_age': None}
  174. )
  175. return response
  176. def domain_list(self) -> requests.Response:
  177. return self.get("/domains/").json()
  178. def domain_create(self, name) -> requests.Response:
  179. if name in self.domains:
  180. raise ValueError
  181. response = self.post("/domains/", data={"name": name})
  182. self.domains[name] = response.json()
  183. return response
  184. def domain_destroy(self, name) -> requests.Response:
  185. if name not in self.domains:
  186. raise ValueError
  187. response = self.delete(f"/domains/{name}/")
  188. self.domains.pop(name)
  189. return response
  190. def rr_set_create(self, domain_name: str, rr_type: str, records: Iterable[str], subname: str = '',
  191. ttl: int = 3600) -> requests.Response:
  192. return self.post(
  193. f"/domains/{domain_name}/rrsets/",
  194. data={
  195. "subname": subname,
  196. "type": rr_type,
  197. "ttl": ttl,
  198. "records": records,
  199. }
  200. )
  201. def rr_set_create_bulk(self, domain_name: str, data: list) -> requests.Response:
  202. return self.patch(f"/domains/{domain_name}/rrsets/", data=data)
  203. def rr_set_delete(self, domain_name: str, rr_type: str, subname: str = '') -> requests.Response:
  204. return self.delete(f"/domains/{domain_name}/rrsets/{subname}.../{rr_type}/")
  205. def get_key_params(self, domain_name: str, rr_type: str) -> list:
  206. keys = self.domains[domain_name]['keys']
  207. if rr_type in ('CDNSKEY', 'DNSKEY'):
  208. params = {key['dnskey'] for key in keys}
  209. elif rr_type == 'CDS':
  210. params = {ds for key in keys for ds in key['ds']}
  211. else:
  212. raise ValueError
  213. # Split into four fields and remove additional spaces
  214. params = [map(lambda x: x.replace(' ', ''), param.split(' ', 3)) for param in params]
  215. # For (C)DNSKEY, add spaces every 32 characters
  216. if rr_type in ('CDNSKEY', 'DNSKEY'):
  217. params = [[a, b, c, ' '.join(d[i:i + 32] for i in range(0, len(d), 32))] for a, b, c, d in params]
  218. # Join again
  219. return {' '.join(param) for param in params}
  220. @pytest.fixture
  221. def api_anon() -> DeSECAPIV1Client:
  222. """
  223. Anonymous access to the API.
  224. """
  225. return DeSECAPIV1Client()
  226. @pytest.fixture()
  227. def api_user() -> DeSECAPIV1Client:
  228. """
  229. Access to the API with a fresh user account (zero domains, one token). Authorization header
  230. is preconfigured, email address and password are randomly chosen.
  231. """
  232. api = DeSECAPIV1Client()
  233. email = random_email()
  234. password = random_password()
  235. api.register(email, password)
  236. api.login(email, password)
  237. return api
  238. @pytest.fixture()
  239. def api_user_domain(api_user) -> DeSECAPIV1Client:
  240. """
  241. Access to the API with a fresh user account that owns a domain with random name. The domain has
  242. no records other than the default ones.
  243. """
  244. api_user.domain_create(random_domainname())
  245. return api_user
  246. class NSClient:
  247. where = None
  248. @classmethod
  249. def query(cls, qname: str, qtype: str):
  250. tsprint(f'DNS >>> {qname}/{qtype} @{cls.where}')
  251. qname = dns.name.from_text(qname)
  252. qtype = dns.rdatatype.from_text(qtype)
  253. answer = dns.query.tcp(
  254. q=dns.message.make_query(qname, qtype),
  255. where=cls.where,
  256. timeout=2
  257. )
  258. try:
  259. section = dns.message.AUTHORITY if qtype == dns.rdatatype.from_text('NS') else dns.message.ANSWER
  260. response = answer.find_rrset(section, qname, dns.rdataclass.IN, qtype)
  261. tsprint(f'DNS <<< {response}')
  262. return {i.to_text() for i in response.items}
  263. except KeyError:
  264. tsprint('DNS <<< !!! not found !!! Complete Answer below:\n' + answer.to_text())
  265. return {}
  266. class NSLordClient(NSClient):
  267. where = os.environ["DESECSTACK_IPV4_REAR_PREFIX16"] + '.0.129'
  268. def query_replication(zone: str, qname: str, qtype: str, covers: str = None):
  269. if qtype == 'RRSIG':
  270. assert covers, 'If querying RRSIG, covers parameter must be set to a RR type, e.g. SOA.'
  271. else:
  272. assert not covers
  273. covers = dns.rdatatype.NONE
  274. zonefile = os.path.join('/zones', zone + '.zone')
  275. zone = dns.name.from_text(zone, origin=dns.name.root)
  276. qname = dns.name.from_text(qname, origin=zone)
  277. if not os.path.exists(zonefile):
  278. tsprint(f'RPL <<< Zone file for {zone} not found '
  279. f'(number of zones: {len(list(filter(lambda f: f.endswith(".zone"), os.listdir("/zones"))))})')
  280. return None
  281. try:
  282. tsprint(f'RPL >>> {qname}/{qtype} in {zone}')
  283. z = dns.zone.from_file(f=zonefile, origin=zone, relativize=False)
  284. v = {i.to_text() for i in z.find_rrset(qname, qtype, covers=covers).items}
  285. tsprint(f'RPL <<< {v}')
  286. return v
  287. except KeyError:
  288. tsprint(f'RPL <<< RR Set {qname}/{qtype} not found')
  289. return {}
  290. except dns.zone.NoSOA:
  291. tsprint(f'RPL <<< Zone {zone} not found')
  292. return None
  293. def return_eventually(expression: callable, min_pause: float = .1, max_pause: float = 2, timeout: float = 5,
  294. retry_on: Tuple[type] = (Exception,)):
  295. if not callable(expression):
  296. raise ValueError('Expression given not callable. Did you forget "lambda:"?')
  297. wait = min_pause
  298. started = datetime.now()
  299. while True:
  300. try:
  301. return expression()
  302. except retry_on as e:
  303. if (datetime.now() - started).total_seconds() > timeout:
  304. tsprint(f'{expression.__code__} failed with {e}, no more retries')
  305. raise e
  306. time.sleep(wait)
  307. wait = min(2 * wait, max_pause)
  308. def assert_eventually(assertion: callable, min_pause: float = .1, max_pause: float = 2, timeout: float = 5) -> None:
  309. def _assert():
  310. assert assertion()
  311. return_eventually(_assert, min_pause, max_pause, timeout, retry_on=(AssertionError,))
  312. def faketime(t: str):
  313. print('FAKETIME', t)
  314. with open('/etc/faketime/faketime.rc', 'w') as f:
  315. f.write(t + '\n')
  316. def faketime_get():
  317. try:
  318. with open('/etc/faketime/faketime.rc', 'r') as f:
  319. return f.readline().strip()
  320. except FileNotFoundError:
  321. return '+0d'
  322. def faketime_add(days: int):
  323. assert days >= 0
  324. current_faketime = faketime_get()
  325. assert current_faketime[0] == '+'
  326. assert current_faketime[-1] == 'd'
  327. current_days = int(current_faketime[1:-1])
  328. faketime(f'+{current_days + days:n}d')