dns.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144
  1. import re
  2. import struct
  3. from ipaddress import IPv6Address
  4. import dns
  5. import dns.dnssec
  6. import dns.name
  7. import dns.rdtypes.txtbase, dns.rdtypes.svcbbase
  8. import dns.rdtypes.ANY.CERT, dns.rdtypes.ANY.CNAME, dns.rdtypes.ANY.MX, dns.rdtypes.ANY.NS
  9. import dns.rdtypes.IN.AAAA, dns.rdtypes.IN.SRV
  10. def _strip_quotes_decorator(func):
  11. return lambda *args, **kwargs: func(*args, **kwargs)[1:-1]
  12. # Ensure that dnspython agrees with pdns' expectations for SVCB / HTTPS parameters.
  13. # WARNING: This is a global side-effect. It can't be done by extending a class, because dnspython hardcodes the use of
  14. # their dns.rdtypes.svcbbase.*Param classes in the global dns.rdtypes.svcbbase._class_for_key dictionary. We either have
  15. # to globally mess with that dict and insert our custom class, or we just mess with their classes directly.
  16. dns.rdtypes.svcbbase.ALPNParam.to_text = _strip_quotes_decorator(
  17. dns.rdtypes.svcbbase.ALPNParam.to_text
  18. )
  19. dns.rdtypes.svcbbase.IPv4HintParam.to_text = _strip_quotes_decorator(
  20. dns.rdtypes.svcbbase.IPv4HintParam.to_text
  21. )
  22. dns.rdtypes.svcbbase.IPv6HintParam.to_text = _strip_quotes_decorator(
  23. dns.rdtypes.svcbbase.IPv6HintParam.to_text
  24. )
  25. dns.rdtypes.svcbbase.MandatoryParam.to_text = _strip_quotes_decorator(
  26. dns.rdtypes.svcbbase.MandatoryParam.to_text
  27. )
  28. dns.rdtypes.svcbbase.PortParam.to_text = _strip_quotes_decorator(
  29. dns.rdtypes.svcbbase.PortParam.to_text
  30. )
  31. @dns.immutable.immutable
  32. class CERT(dns.rdtypes.ANY.CERT.CERT):
  33. def to_text(self, origin=None, relativize=True, **kw):
  34. certificate_type = str(
  35. self.certificate_type
  36. ) # upstream implementation calls _ctype_to_text
  37. algorithm = str(
  38. self.algorithm
  39. ) # upstream implementation calls dns.dnssec.algorithm_to_text
  40. return "%s %d %s %s" % (
  41. certificate_type,
  42. self.key_tag,
  43. algorithm,
  44. dns.rdata._base64ify(self.certificate, **kw),
  45. )
  46. @dns.immutable.immutable
  47. class AAAA(dns.rdtypes.IN.AAAA.AAAA):
  48. def to_text(self, origin=None, relativize=True, **kw):
  49. address = super().to_text(origin, relativize, **kw)
  50. return IPv6Address(address).compressed
  51. @dns.immutable.immutable
  52. class LongQuotedTXT(dns.rdtypes.txtbase.TXTBase):
  53. """
  54. A TXT record like RFC 1035, but
  55. - allows arbitrarily long tokens, and
  56. - all tokens must be quoted.
  57. """
  58. def __init__(self, rdclass, rdtype, strings):
  59. # Same as in parent class, but with max_length=None. Note that we are calling __init__ from the grandparent.
  60. super(dns.rdtypes.txtbase.TXTBase, self).__init__(rdclass, rdtype)
  61. self.strings = self._as_tuple(
  62. strings, lambda x: self._as_bytes(x, True, max_length=None)
  63. )
  64. @classmethod
  65. def from_text(cls, rdclass, rdtype, tok, origin=None, relativize=True):
  66. strings = []
  67. for token in tok.get_remaining():
  68. token = token.unescape_to_bytes()
  69. # The 'if' below is always true in the current code, but we
  70. # are leaving this check in in case things change some day.
  71. if not token.is_quoted_string():
  72. raise dns.exception.SyntaxError("Content must be quoted.")
  73. strings.append(token.value)
  74. if len(strings) == 0:
  75. raise dns.exception.UnexpectedEnd
  76. return cls(rdclass, rdtype, strings)
  77. def _to_wire(self, file, compress=None, origin=None, canonicalize=False):
  78. for long_s in self.strings:
  79. for s in [long_s[i : i + 255] for i in range(0, max(len(long_s), 1), 255)]:
  80. l = len(s)
  81. assert l < 256
  82. file.write(struct.pack("!B", l))
  83. file.write(s)
  84. def _NameMixin(name_field, *, allow_root, hostname=True):
  85. pattern = (
  86. # https://github.com/PowerDNS/pdns/blob/4646277d05f293777a3d2423a3b188ccdf42c6bc/pdns/dnsname.cc#L419
  87. re.compile(r"^(([A-Za-z0-9]([A-Za-z0-9-]*[A-Za-z0-9])?)\.)+$")
  88. if hostname
  89. else
  90. # https://github.com/PowerDNS/pdns/blob/4646277d05f293777a3d2423a3b188ccdf42c6bc/pdns/dnsname.cc#L473
  91. # with the exception of [/@ :\\]
  92. re.compile(r"^(([A-Za-z0-9_*-]+)\.)+$")
  93. )
  94. class Mixin:
  95. def to_text(self, origin=None, relativize=True, **kw):
  96. name = getattr(self, name_field)
  97. if (
  98. not (allow_root and name == dns.name.root)
  99. and pattern.match(str(name)) is None
  100. ):
  101. raise ValueError(f"invalid {name_field}: {name}")
  102. return super().to_text(origin, relativize, **kw)
  103. return Mixin
  104. @dns.immutable.immutable
  105. class CNAME(
  106. _NameMixin("target", allow_root=True, hostname=False), dns.rdtypes.ANY.CNAME.CNAME
  107. ):
  108. pass
  109. @dns.immutable.immutable
  110. class MX(_NameMixin("exchange", allow_root=True), dns.rdtypes.ANY.MX.MX):
  111. pass
  112. @dns.immutable.immutable
  113. class NS(_NameMixin("target", allow_root=False), dns.rdtypes.ANY.NS.NS):
  114. pass
  115. @dns.immutable.immutable
  116. class SRV(_NameMixin("target", allow_root=True), dns.rdtypes.IN.SRV.SRV):
  117. pass