conftest.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507
  1. import glob
  2. import json
  3. import os
  4. import random
  5. import re
  6. import string
  7. import time
  8. import warnings
  9. from datetime import datetime
  10. from json import JSONDecodeError
  11. from typing import Optional, Tuple, Iterable
  12. import dns
  13. import dns.name
  14. import dns.query
  15. import dns.rdtypes.svcbbase
  16. import dns.zone
  17. import pytest
  18. import requests
  19. from requests.exceptions import SSLError
  20. from urllib3.exceptions import InsecureRequestWarning
  21. def pytest_addoption(parser):
  22. parser.addoption(
  23. "--skip-performance-tests", action="store_true", default=False, help="skip expensive performance tests"
  24. )
  25. def pytest_configure(config):
  26. config.addinivalue_line("markers", "performance: mark test as expensive performance test")
  27. def pytest_collection_modifyitems(config, items):
  28. if config.getoption("--skip-performance-tests"):
  29. skip_mark = pytest.mark.skip(reason="need --runslow option to run")
  30. for item in items:
  31. if "performance" in item.keywords:
  32. item.add_marker(skip_mark)
  33. def tsprint(s, *args, **kwargs):
  34. print(f"{datetime.now().strftime('%d-%b (%H:%M:%S)')} {s}", *args, **kwargs)
  35. def _strip_quotes_decorator(func):
  36. return lambda *args, **kwargs: func(*args, **kwargs)[1:-1]
  37. # Ensure that dnspython agrees with pdns' expectations for SVCB / HTTPS parameters.
  38. # WARNING: This is a global side-effect. It can't be done by extending a class, because dnspython hardcodes the use of
  39. # their dns.rdtypes.svcbbase.*Param classes in the global dns.rdtypes.svcbbase._class_for_key dictionary. We either have
  40. # to globally mess with that dict and insert our custom class, or we just mess with their classes directly.
  41. dns.rdtypes.svcbbase.ALPNParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.ALPNParam.to_text)
  42. dns.rdtypes.svcbbase.IPv4HintParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.IPv4HintParam.to_text)
  43. dns.rdtypes.svcbbase.IPv6HintParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.IPv6HintParam.to_text)
  44. dns.rdtypes.svcbbase.MandatoryParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.MandatoryParam.to_text)
  45. dns.rdtypes.svcbbase.PortParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.PortParam.to_text)
  46. def random_mixed_case_string(n):
  47. k = random.randint(1, n-1)
  48. s = random.choices(string.ascii_lowercase, k=k) + random.choices(string.ascii_uppercase, k=n-k)
  49. random.shuffle(s)
  50. return ''.join(s)
  51. def random_email() -> str:
  52. return f'{random_mixed_case_string(10)}@{random_mixed_case_string(10)}.desec.test'
  53. def random_password() -> str:
  54. return "".join(random.choice(string.ascii_letters) for _ in range(16))
  55. def random_domainname(suffix='test') -> str:
  56. return "".join(random.choice(string.ascii_lowercase) for _ in range(16)) + f'.{suffix}'
  57. class DeSECAPIV1Client:
  58. base_url = "https://desec." + os.environ["DESECSTACK_DOMAIN"] + "/api/v1"
  59. def __init__(self) -> None:
  60. super().__init__()
  61. self.headers = { # instance-local
  62. "Accept": "application/json",
  63. "Content-Type": "application/json",
  64. "User-Agent": "e2e2",
  65. }
  66. self.email = None
  67. self.password = None
  68. self.domains = {}
  69. # We support two certificate verification methods
  70. # (1) against self-signed certificates, if /autocert path is present
  71. # (this is usually the case when run inside a docker container)
  72. # (2) against the default certificate store, if /autocert is not available
  73. # (this is usually the case when run outside a docker container)
  74. self.verify = True
  75. self.verify_alt = glob.glob('/autocert/*.cer')
  76. @staticmethod
  77. def _filter_response_output(output: dict) -> dict:
  78. try:
  79. output['challenge'] = output['challenge'][:10] + '...'
  80. except (KeyError, TypeError):
  81. pass
  82. return output
  83. @property
  84. def domain(self):
  85. try:
  86. return next(iter(self.domains))
  87. except StopIteration:
  88. return None
  89. def _do_request(self, *args, **kwargs):
  90. if 'verify' in kwargs:
  91. verify_list = [kwargs.pop('verify')]
  92. elif faketime_get() != '+0d':
  93. # do not verify SSL if we're in faketime (cert will be expired!?)
  94. verify_list = [False]
  95. else:
  96. verify_list = [self.verify] + self.verify_alt
  97. exc = None
  98. for verify in verify_list:
  99. try:
  100. with warnings.catch_warnings():
  101. if not verify:
  102. # Suppress insecurity warning if we do not want to verify
  103. warnings.filterwarnings('ignore', category=InsecureRequestWarning)
  104. reply = requests.request(*args, **kwargs, verify=verify)
  105. except SSLError as e:
  106. tsprint(f'API <<< SSL could not verify against "{verify}"')
  107. exc = e
  108. else:
  109. # note verification preference for next time
  110. self.verify = verify
  111. self.verify_alt = verify_list
  112. self.verify_alt.remove(self.verify)
  113. return reply
  114. tsprint(f'API <<< SSL could not be verified against any verification method')
  115. raise exc
  116. def _request(self, method: str, *, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
  117. if data is not None:
  118. data = json.dumps(data)
  119. url = self.base_url + path if re.match(r'^https?://', path) is None else path
  120. tsprint(f"API >>> {method} {url}")
  121. if data:
  122. tsprint(f"API >>> {type(data)}: {self._shorten(data)}")
  123. response = self._do_request(
  124. method,
  125. url,
  126. data=data,
  127. headers=self.headers,
  128. **kwargs,
  129. )
  130. tsprint(f"API <<< {response.status_code}")
  131. if response.text:
  132. try:
  133. tsprint(f"API <<< {self._shorten(str(self._filter_response_output(response.json())))}")
  134. except JSONDecodeError:
  135. tsprint(f"API <<< {response.text}")
  136. return response
  137. @staticmethod
  138. def _shorten(s: str):
  139. if len(s) < 200:
  140. return s
  141. else:
  142. return s[:50] + '...' + s[-50:]
  143. def get(self, path: str, **kwargs) -> requests.Response:
  144. return self._request("GET", path=path, **kwargs)
  145. def options(self, path: str, **kwargs) -> requests.Response:
  146. return self._request("OPTIONS", path=path, **kwargs)
  147. def post(self, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
  148. return self._request("POST", path=path, data=data, **kwargs)
  149. def patch(self, path: str, data: Optional[dict] = None, **kwargs) -> requests.Response:
  150. return self._request("PATCH", path=path, data=data, **kwargs)
  151. def delete(self, path: str, **kwargs) -> requests.Response:
  152. return self._request("DELETE", path=path, **kwargs)
  153. def register(self, email: str, password: str) -> Tuple[requests.Response, requests.Response]:
  154. self.email = email
  155. self.password = password
  156. captcha = self.post("/captcha/")
  157. return captcha, self.post(
  158. "/auth/",
  159. data={
  160. "email": email,
  161. "password": password,
  162. "captcha": {
  163. "id": captcha.json()["id"],
  164. "solution": captcha.json()[
  165. "content"
  166. ], # available via e2e configuration magic
  167. },
  168. },
  169. )
  170. def login(self, email: str, password: str) -> requests.Response:
  171. response = self.post(
  172. "/auth/login/", data={"email": email, "password": password}
  173. )
  174. self.token = response.json().get('token')
  175. if self.token is not None:
  176. self.headers["Authorization"] = f'Token {self.token}'
  177. self.patch( # make token last forever
  178. f"/auth/tokens/{response.json().get('id')}/",
  179. data={'max_unused_period': None, 'max_age': None}
  180. )
  181. return response
  182. def domain_list(self) -> requests.Response:
  183. return self.get("/domains/").json()
  184. def domain_create(self, name) -> requests.Response:
  185. if name in self.domains:
  186. raise ValueError
  187. response = self.post("/domains/", data={"name": name})
  188. self.domains[name] = response.json()
  189. return response
  190. def domain_destroy(self, name) -> requests.Response:
  191. if name not in self.domains:
  192. raise ValueError
  193. response = self.delete(f"/domains/{name}/")
  194. self.domains.pop(name)
  195. return response
  196. def rr_set_create(self, domain_name: str, rr_type: str, records: Iterable[str], subname: str = '',
  197. ttl: int = 3600) -> requests.Response:
  198. return self.post(
  199. f"/domains/{domain_name}/rrsets/",
  200. data={
  201. "subname": subname,
  202. "type": rr_type,
  203. "ttl": ttl,
  204. "records": records,
  205. }
  206. )
  207. def rr_set_create_bulk(self, domain_name: str, data: list) -> requests.Response:
  208. return self.patch(f"/domains/{domain_name}/rrsets/", data=data)
  209. def rr_set_delete(self, domain_name: str, rr_type: str, subname: str = '') -> requests.Response:
  210. return self.delete(f"/domains/{domain_name}/rrsets/{subname}.../{rr_type}/")
  211. def get_key_params(self, domain_name: str, rr_type: str) -> list:
  212. keys = self.domains[domain_name]['keys']
  213. if rr_type in ('CDNSKEY', 'DNSKEY'):
  214. params = {key['dnskey'] for key in keys}
  215. elif rr_type == 'CDS':
  216. params = {ds for key in keys for ds in key['ds']}
  217. else:
  218. raise ValueError
  219. # Split into four fields and remove additional spaces
  220. params = [map(lambda x: x.replace(' ', ''), param.split(' ', 3)) for param in params]
  221. # For (C)DNSKEY, add spaces every 32 characters
  222. if rr_type in ('CDNSKEY', 'DNSKEY'):
  223. params = [[a, b, c, ' '.join(d[i:i + 32] for i in range(0, len(d), 32))] for a, b, c, d in params]
  224. # Join again
  225. return {' '.join(param) for param in params}
  226. class DeSECAPIV2Client(DeSECAPIV1Client):
  227. base_url = "https://desec." + os.environ["DESECSTACK_DOMAIN"] + "/api/v2"
  228. @pytest.fixture
  229. def api_anon() -> DeSECAPIV1Client:
  230. """
  231. Anonymous access to the API.
  232. """
  233. return DeSECAPIV1Client()
  234. @pytest.fixture
  235. def api_anon_v2() -> DeSECAPIV2Client:
  236. """
  237. Anonymous access to the API.
  238. """
  239. return DeSECAPIV2Client()
  240. @pytest.fixture()
  241. def api_user() -> DeSECAPIV1Client:
  242. """
  243. Access to the API with a fresh user account (zero domains, one token). Authorization header
  244. is preconfigured, email address and password are randomly chosen.
  245. """
  246. api = DeSECAPIV1Client()
  247. email = random_email()
  248. password = random_password()
  249. api.register(email, password)
  250. api.login(email, password)
  251. return api
  252. @pytest.fixture()
  253. def api_user_domain(api_user) -> DeSECAPIV1Client:
  254. """
  255. Access to the API with a fresh user account that owns a domain with random name. The domain has
  256. no records other than the default ones.
  257. """
  258. api_user.domain_create(random_domainname())
  259. return api_user
  260. @pytest.fixture()
  261. def api_user_domain_rrsets(api_user_domain, init_rrsets: dict) -> DeSECAPIV1Client:
  262. """
  263. Access to the API with a fresh user account that owns a domain with random name. The domain is
  264. equipped with RRsets from init_rrsets.
  265. """
  266. def _normalize_rrset(rrset, qtype):
  267. if qtype not in ('CDS', 'CDNSKEY', 'DNSKEY'):
  268. return rrset
  269. ttl, records = rrset
  270. return ttl, {' '.join(map(lambda x: x.replace(' ', ''), record.split(' ', 3))) for record in records}
  271. def _assert_rrsets(self, rrsets):
  272. rrsets_api = {
  273. (rrset['subname'], rrset['type']): (rrset['ttl'], set(rrset['records']))
  274. for rrset in self.get(f'/domains/{self.domain}/rrsets/').json()
  275. }
  276. rrsets_dns = {
  277. (subname, qtype): _normalize_rrset(NSLordClient.query(f'{subname}.{self.domain}'.lstrip('.'), qtype), qtype)
  278. for subname, qtype in rrsets.keys()
  279. }
  280. for k, v in rrsets.items():
  281. v = v or init_rrsets[k] # if None, check against init_rrsets
  282. if not v[1]:
  283. assert k not in rrsets_api
  284. assert not rrsets_dns[k][1]
  285. else:
  286. assert rrsets_api[k] == v
  287. assert rrsets_dns[k] == v
  288. api_user_domain.assert_rrsets = _assert_rrsets.__get__(api_user_domain) # very hacky way of adding a method
  289. api_user_domain.post(f"/domains/{api_user_domain.domain}/rrsets/", data=[
  290. {"subname": k[0], "type": k[1], "ttl": v[0], "records": list(v[1])}
  291. for k, v in init_rrsets.items()
  292. ])
  293. api_user_domain.assert_rrsets(init_rrsets)
  294. return api_user_domain
# Expose the api_user fixture under a second name; the lps fixture below depends on it.
# NOTE(review): presumably pytest registers this as a separate fixture named "api_user_lps", so
# api_user_lps_domain (which requests both api_user and lps) gets two different accounts -- confirm
# against the pytest version in use, since fixture naming of aliased functions has changed over time.
api_user_lps = api_user
  296. @pytest.fixture()
  297. def lps(api_user_lps) -> DeSECAPIV1Client:
  298. """
  299. Access to the API with a fresh user account that owns a local public suffix.
  300. """
  301. lps = "dedyn." + os.environ['DESECSTACK_DOMAIN']
  302. api_user_lps.domain_create(lps) # may return 400 if exists, but that's ok
  303. return lps
  304. @pytest.fixture()
  305. def api_user_lps_domain(api_user, lps) -> DeSECAPIV1Client:
  306. """
  307. Access to the API with a fresh user account that owns a domain with random name under a local public suffix.
  308. The domain has no records other than the default ones.
  309. """
  310. api_user.domain_create(random_domainname(suffix=lps))
  311. return api_user
  312. class NSClient:
  313. where = None
  314. @classmethod
  315. def query(cls, qname: str, qtype: str):
  316. tsprint(f'DNS >>> {qname}/{qtype} @{cls.where}')
  317. qname = dns.name.from_text(qname)
  318. qtype = dns.rdatatype.from_text(qtype)
  319. answer = dns.query.tcp(
  320. q=dns.message.make_query(qname, qtype),
  321. where=cls.where,
  322. timeout=2
  323. )
  324. try:
  325. section = dns.message.AUTHORITY if qtype == dns.rdatatype.from_text('NS') else dns.message.ANSWER
  326. response = answer.find_rrset(section, qname, dns.rdataclass.IN, qtype)
  327. tsprint(f'DNS <<< {response}')
  328. return response.ttl, {i.to_text() for i in response.items}
  329. except KeyError:
  330. tsprint('DNS <<< !!! not found !!! Complete Answer below:\n' + answer.to_text())
  331. return None, set()
  332. class NSLordClient(NSClient):
  333. where = os.environ["DESECSTACK_IPV4_REAR_PREFIX16"] + '.0.129'
  334. def query_replication(zone: str, qname: str, qtype: str, covers: str = None):
  335. if qtype == 'RRSIG':
  336. assert covers, 'If querying RRSIG, covers parameter must be set to a RR type, e.g. SOA.'
  337. else:
  338. assert not covers
  339. covers = dns.rdatatype.NONE
  340. zonefile = os.path.join('/zones', zone + '.zone')
  341. zone = dns.name.from_text(zone, origin=dns.name.root)
  342. qname = dns.name.from_text(qname, origin=zone)
  343. if not os.path.exists(zonefile):
  344. tsprint(f'RPL <<< Zone file for {zone} not found '
  345. f'(number of zones: {len(list(filter(lambda f: f.endswith(".zone"), os.listdir("/zones"))))})')
  346. return None
  347. try:
  348. tsprint(f'RPL >>> {qname}/{qtype} in {zone}')
  349. z = dns.zone.from_file(f=zonefile, origin=zone, relativize=False)
  350. v = {i.to_text() for i in z.find_rrset(qname, qtype, covers=covers).items}
  351. tsprint(f'RPL <<< {v}')
  352. return v
  353. except KeyError:
  354. tsprint(f'RPL <<< RR Set {qname}/{qtype} not found')
  355. return {}
  356. except dns.zone.NoSOA:
  357. tsprint(f'RPL <<< Zone {zone} not found')
  358. return None
  359. def return_eventually(expression: callable, min_pause: float = .1, max_pause: float = 2, timeout: float = 5,
  360. retry_on: Tuple[type] = (Exception,)):
  361. if not callable(expression):
  362. raise ValueError('Expression given not callable. Did you forget "lambda:"?')
  363. wait = min_pause
  364. started = datetime.now()
  365. while True:
  366. try:
  367. return expression()
  368. except retry_on as e:
  369. if (datetime.now() - started).total_seconds() > timeout:
  370. tsprint(f'{expression.__code__} failed with {e}, no more retries')
  371. raise e
  372. time.sleep(wait)
  373. wait = min(2 * wait, max_pause)
  374. except Exception as e:
  375. tsprint(f'{expression.__code__} raised unexpected exception {e}')
  376. raise
  377. def assert_eventually(assertion: callable, min_pause: float = .1, max_pause: float = 2, timeout: float = 5,
  378. retry_on: Tuple[type] = (AssertionError,)) -> None:
  379. def _assert():
  380. assert assertion()
  381. return_eventually(_assert, min_pause, max_pause, timeout, retry_on=retry_on)
  382. def faketime(t: str):
  383. print('FAKETIME', t)
  384. with open('/etc/faketime/faketime.rc', 'w') as f:
  385. f.write(t + '\n')
  386. def faketime_get():
  387. try:
  388. with open('/etc/faketime/faketime.rc', 'r') as f:
  389. return f.readline().strip()
  390. except FileNotFoundError:
  391. return '+0d'
  392. class FaketimeShift:
  393. # Note: Time is rolled back when exiting the context. Conceivably, this may cause weird side
  394. # effects. So: Watch out for side effects. If any are found, let's rethink this.
  395. def __init__(self, days: int):
  396. assert days >= 0
  397. self.days = days
  398. def __enter__(self):
  399. self._faketime = faketime_get()
  400. assert self._faketime[0] == '+'
  401. assert self._faketime[-1] == 'd'
  402. current_days = int(self._faketime[1:-1])
  403. faketime(f'+{current_days + self.days:n}d')
  404. def __exit__(self, type, value, traceback):
  405. faketime(self._faketime)