dns.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. import re
  2. import struct
  3. from ipaddress import IPv6Address
  4. import dns
  5. import dns.name
  6. import dns.rdtypes.txtbase, dns.rdtypes.svcbbase
  7. import dns.rdtypes.ANY.CERT, dns.rdtypes.ANY.MX, dns.rdtypes.ANY.NS
  8. import dns.rdtypes.IN.AAAA, dns.rdtypes.IN.SRV
  9. def _strip_quotes_decorator(func):
  10. return lambda *args, **kwargs: func(*args, **kwargs)[1:-1]
  11. # Ensure that dnspython agrees with pdns' expectations for SVCB / HTTPS parameters.
  12. # WARNING: This is a global side-effect. It can't be done by extending a class, because dnspython hardcodes the use of
  13. # their dns.rdtypes.svcbbase.*Param classes in the global dns.rdtypes.svcbbase._class_for_key dictionary. We either have
  14. # to globally mess with that dict and insert our custom class, or we just mess with their classes directly.
  15. dns.rdtypes.svcbbase.ALPNParam.to_text = _strip_quotes_decorator(
  16. dns.rdtypes.svcbbase.ALPNParam.to_text
  17. )
  18. dns.rdtypes.svcbbase.IPv4HintParam.to_text = _strip_quotes_decorator(
  19. dns.rdtypes.svcbbase.IPv4HintParam.to_text
  20. )
  21. dns.rdtypes.svcbbase.IPv6HintParam.to_text = _strip_quotes_decorator(
  22. dns.rdtypes.svcbbase.IPv6HintParam.to_text
  23. )
  24. dns.rdtypes.svcbbase.MandatoryParam.to_text = _strip_quotes_decorator(
  25. dns.rdtypes.svcbbase.MandatoryParam.to_text
  26. )
  27. dns.rdtypes.svcbbase.PortParam.to_text = _strip_quotes_decorator(
  28. dns.rdtypes.svcbbase.PortParam.to_text
  29. )
  30. @dns.immutable.immutable
  31. class CERT(dns.rdtypes.ANY.CERT.CERT):
  32. def to_text(self, origin=None, relativize=True, **kw):
  33. certificate_type = str(
  34. self.certificate_type
  35. ) # upstream implementation calls _ctype_to_text
  36. return "%s %d %s %s" % (
  37. certificate_type,
  38. self.key_tag,
  39. dns.dnssec.algorithm_to_text(self.algorithm),
  40. dns.rdata._base64ify(self.certificate, **kw),
  41. )
  42. @dns.immutable.immutable
  43. class AAAA(dns.rdtypes.IN.AAAA.AAAA):
  44. def to_text(self, origin=None, relativize=True, **kw):
  45. address = super().to_text(origin, relativize, **kw)
  46. return IPv6Address(address).compressed
  47. @dns.immutable.immutable
  48. class LongQuotedTXT(dns.rdtypes.txtbase.TXTBase):
  49. """
  50. A TXT record like RFC 1035, but
  51. - allows arbitrarily long tokens, and
  52. - all tokens must be quoted.
  53. """
  54. def __init__(self, rdclass, rdtype, strings):
  55. # Same as in parent class, but with max_length=None. Note that we are calling __init__ from the grandparent.
  56. super(dns.rdtypes.txtbase.TXTBase, self).__init__(rdclass, rdtype)
  57. self.strings = self._as_tuple(
  58. strings, lambda x: self._as_bytes(x, True, max_length=None)
  59. )
  60. @classmethod
  61. def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True):
  62. strings = []
  63. for token in tok.get_remaining():
  64. token = token.unescape_to_bytes()
  65. # The 'if' below is always true in the current code, but we
  66. # are leaving this check in in case things change some day.
  67. if not token.is_quoted_string():
  68. raise dns.exception.SyntaxError("Content must be quoted.")
  69. strings.append(token.value)
  70. if len(strings) == 0:
  71. raise dns.exception.UnexpectedEnd
  72. return cls(rdclass, rdtype, strings)
  73. def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
  74. for long_s in self.strings:
  75. for s in [long_s[i : i + 255] for i in range(0, max(len(long_s), 1), 255)]:
  76. l = len(s)
  77. assert l < 256
  78. file.write(struct.pack("!B", l))
  79. file.write(s)
  80. def _HostnameMixin(name_field, *, allow_root):
  81. # Taken from https://github.com/PowerDNS/pdns/blob/4646277d05f293777a3d2423a3b188ccdf42c6bc/pdns/dnsname.cc#L419
  82. hostname_re = re.compile(r"^(([A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)\.)+$")
  83. class Mixin:
  84. def to_text(self, origin=None, relativize=True, **kw):
  85. name = getattr(self, name_field)
  86. if (
  87. not (allow_root and name == dns.name.root)
  88. and hostname_re.match(str(name)) is None
  89. ):
  90. raise ValueError(f"invalid {name_field}: {name}")
  91. return super().to_text(origin, relativize, **kw)
  92. return Mixin
  93. @dns.immutable.immutable
  94. class MX(_HostnameMixin("exchange", allow_root=True), dns.rdtypes.ANY.MX.MX):
  95. pass
  96. @dns.immutable.immutable
  97. class NS(_HostnameMixin("target", allow_root=False), dns.rdtypes.ANY.NS.NS):
  98. pass
  99. @dns.immutable.immutable
  100. class SRV(_HostnameMixin("target", allow_root=True), dns.rdtypes.IN.SRV.SRV):
  101. pass