dns.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import re
  2. import struct
  3. from ipaddress import IPv6Address
  4. import dns
  5. import dns.name
  6. import dns.rdtypes.txtbase, dns.rdtypes.svcbbase
  7. import dns.rdtypes.ANY.CERT, dns.rdtypes.ANY.MX, dns.rdtypes.ANY.NS
  8. import dns.rdtypes.IN.AAAA, dns.rdtypes.IN.SRV
  9. def _strip_quotes_decorator(func):
  10. return lambda *args, **kwargs: func(*args, **kwargs)[1:-1]
  11. # Ensure that dnspython agrees with pdns' expectations for SVCB / HTTPS parameters.
  12. # WARNING: This is a global side-effect. It can't be done by extending a class, because dnspython hardcodes the use of
  13. # their dns.rdtypes.svcbbase.*Param classes in the global dns.rdtypes.svcbbase._class_for_key dictionary. We either have
  14. # to globally mess with that dict and insert our custom class, or we just mess with their classes directly.
  15. dns.rdtypes.svcbbase.ALPNParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.ALPNParam.to_text)
  16. dns.rdtypes.svcbbase.IPv4HintParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.IPv4HintParam.to_text)
  17. dns.rdtypes.svcbbase.IPv6HintParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.IPv6HintParam.to_text)
  18. dns.rdtypes.svcbbase.MandatoryParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.MandatoryParam.to_text)
  19. dns.rdtypes.svcbbase.PortParam.to_text = _strip_quotes_decorator(dns.rdtypes.svcbbase.PortParam.to_text)
  20. @dns.immutable.immutable
  21. class CERT(dns.rdtypes.ANY.CERT.CERT):
  22. def to_text(self, origin=None, relativize=True, **kw):
  23. certificate_type = str(self.certificate_type) # upstream implementation calls _ctype_to_text
  24. return "%s %d %s %s" % (certificate_type, self.key_tag,
  25. dns.dnssec.algorithm_to_text(self.algorithm),
  26. dns.rdata._base64ify(self.certificate, **kw))
  27. @dns.immutable.immutable
  28. class AAAA(dns.rdtypes.IN.AAAA.AAAA):
  29. def to_text(self, origin=None, relativize=True, **kw):
  30. address = super().to_text(origin, relativize, **kw)
  31. return IPv6Address(address).compressed
  32. @dns.immutable.immutable
  33. class LongQuotedTXT(dns.rdtypes.txtbase.TXTBase):
  34. """
  35. A TXT record like RFC 1035, but
  36. - allows arbitrarily long tokens, and
  37. - all tokens must be quoted.
  38. """
  39. def __init__(self, rdclass, rdtype, strings):
  40. # Same as in parent class, but with max_length=None. Note that we are calling __init__ from the grandparent.
  41. super(dns.rdtypes.txtbase.TXTBase, self).__init__(rdclass, rdtype)
  42. self.strings = self._as_tuple(strings,
  43. lambda x: self._as_bytes(x, True, max_length=None))
  44. @classmethod
  45. def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True):
  46. strings = []
  47. for token in tok.get_remaining():
  48. token = token.unescape_to_bytes()
  49. # The 'if' below is always true in the current code, but we
  50. # are leaving this check in in case things change some day.
  51. if not token.is_quoted_string():
  52. raise dns.exception.SyntaxError("Content must be quoted.")
  53. strings.append(token.value)
  54. if len(strings) == 0:
  55. raise dns.exception.UnexpectedEnd
  56. return cls(rdclass, rdtype, strings)
  57. def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
  58. for long_s in self.strings:
  59. for s in [long_s[i:i+255] for i in range(0, max(len(long_s), 1), 255)]:
  60. l = len(s)
  61. assert l < 256
  62. file.write(struct.pack('!B', l))
  63. file.write(s)
  64. def _HostnameMixin(name_field, *, allow_root):
  65. # Taken from https://github.com/PowerDNS/pdns/blob/4646277d05f293777a3d2423a3b188ccdf42c6bc/pdns/dnsname.cc#L419
  66. hostname_re = re.compile(r'^(([A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)\.)+$')
  67. class Mixin:
  68. def to_text(self, origin=None, relativize=True, **kw):
  69. name = getattr(self, name_field)
  70. if not (allow_root and name == dns.name.root) and hostname_re.match(str(name)) is None:
  71. raise ValueError(f'invalid {name_field}: {name}')
  72. return super().to_text(origin, relativize, **kw)
  73. return Mixin
  74. @dns.immutable.immutable
  75. class MX(_HostnameMixin('exchange', allow_root=True), dns.rdtypes.ANY.MX.MX):
  76. pass
  77. @dns.immutable.immutable
  78. class NS(_HostnameMixin('target', allow_root=False), dns.rdtypes.ANY.NS.NS):
  79. pass
  80. @dns.immutable.immutable
  81. class SRV(_HostnameMixin('target', allow_root=True), dns.rdtypes.IN.SRV.SRV):
  82. pass