dyndns.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. import base64
  2. import binascii
  3. from functools import cached_property
  4. from rest_framework import generics
  5. from rest_framework.authentication import get_authorization_header
  6. from rest_framework.exceptions import NotFound, ValidationError
  7. from rest_framework.response import Response
  8. from desecapi import metrics
  9. from desecapi.authentication import (
  10. BasicTokenAuthentication,
  11. TokenAuthentication,
  12. URLParamAuthentication,
  13. )
  14. from desecapi.exceptions import ConcurrencyException
  15. from desecapi.models import Domain
  16. from desecapi.pdns_change_tracker import PDNSChangeTracker
  17. from desecapi.permissions import TokenHasDomainDynDNSPermission
  18. from desecapi.renderers import PlainTextRenderer
  19. from desecapi.serializers import RRsetSerializer
  20. class DynDNS12UpdateView(generics.GenericAPIView):
  21. authentication_classes = (
  22. TokenAuthentication,
  23. BasicTokenAuthentication,
  24. URLParamAuthentication,
  25. )
  26. permission_classes = (TokenHasDomainDynDNSPermission,)
  27. renderer_classes = [PlainTextRenderer]
  28. serializer_class = RRsetSerializer
  29. throttle_scope = "dyndns"
  30. @property
  31. def throttle_scope_bucket(self):
  32. return self.domain.name
  33. def _find_ip(self, params, version):
  34. if version == 4:
  35. look_for = "."
  36. elif version == 6:
  37. look_for = ":"
  38. else:
  39. raise Exception
  40. # Check URL parameters
  41. for p in params:
  42. if p in self.request.query_params:
  43. if not len(self.request.query_params[p]):
  44. return None
  45. if look_for in self.request.query_params[p]:
  46. return self.request.query_params[p]
  47. # Check remote IP address
  48. client_ip = self.request.META.get("REMOTE_ADDR")
  49. if look_for in client_ip:
  50. return client_ip
  51. # give up
  52. return None
  53. @cached_property
  54. def qname(self):
  55. # hostname parameter
  56. try:
  57. if self.request.query_params["hostname"] != "YES":
  58. return self.request.query_params["hostname"].lower()
  59. except KeyError:
  60. pass
  61. # host_id parameter
  62. try:
  63. return self.request.query_params["host_id"].lower()
  64. except KeyError:
  65. pass
  66. # http basic auth username
  67. try:
  68. domain_name = (
  69. base64.b64decode(
  70. get_authorization_header(self.request)
  71. .decode()
  72. .split(" ")[1]
  73. .encode()
  74. )
  75. .decode()
  76. .split(":")[0]
  77. )
  78. if domain_name and "@" not in domain_name:
  79. return domain_name.lower()
  80. except (binascii.Error, IndexError, UnicodeDecodeError):
  81. pass
  82. # username parameter
  83. try:
  84. return self.request.query_params["username"].lower()
  85. except KeyError:
  86. pass
  87. # only domain associated with this user account
  88. try:
  89. return self.request.user.domains.get().name
  90. except Domain.MultipleObjectsReturned:
  91. raise ValidationError(
  92. detail={
  93. "detail": "Request does not properly specify domain for update.",
  94. "code": "domain-unspecified",
  95. }
  96. )
  97. except Domain.DoesNotExist:
  98. metrics.get("desecapi_dynDNS12_domain_not_found").inc()
  99. raise NotFound("nohost")
  100. @cached_property
  101. def domain(self):
  102. try:
  103. return Domain.objects.filter_qname(
  104. self.qname, owner=self.request.user
  105. ).order_by("-name_length")[0]
  106. except (IndexError, ValueError):
  107. raise NotFound("nohost")
  108. @property
  109. def subname(self):
  110. return self.qname.rpartition(f".{self.domain.name}")[0]
  111. def get_serializer_context(self):
  112. return {
  113. **super().get_serializer_context(),
  114. "domain": self.domain,
  115. "minimum_ttl": 60,
  116. }
  117. def get_queryset(self):
  118. return self.domain.rrset_set.filter(
  119. subname=self.subname, type__in=["A", "AAAA"]
  120. )
  121. def get(self, request, *args, **kwargs):
  122. instances = self.get_queryset().all()
  123. ipv4 = self._find_ip(["myip", "myipv4", "ip"], version=4)
  124. ipv6 = self._find_ip(["myipv6", "ipv6", "myip", "ip"], version=6)
  125. data = [
  126. {
  127. "type": "A",
  128. "subname": self.subname,
  129. "ttl": 60,
  130. "records": [ipv4] if ipv4 else [],
  131. },
  132. {
  133. "type": "AAAA",
  134. "subname": self.subname,
  135. "ttl": 60,
  136. "records": [ipv6] if ipv6 else [],
  137. },
  138. ]
  139. serializer = self.get_serializer(instances, data=data, many=True, partial=True)
  140. try:
  141. serializer.is_valid(raise_exception=True)
  142. except ValidationError as e:
  143. if any(
  144. any(
  145. getattr(non_field_error, "code", "") == "unique"
  146. for non_field_error in err.get("non_field_errors", [])
  147. )
  148. for err in e.detail
  149. ):
  150. raise ConcurrencyException from e
  151. raise e
  152. with PDNSChangeTracker():
  153. serializer.save()
  154. return Response("good", content_type="text/plain")