dyndns.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174
  1. import base64
  2. import binascii
  3. from functools import cached_property
  4. from rest_framework import generics
  5. from rest_framework.authentication import get_authorization_header
  6. from rest_framework.exceptions import NotFound, ValidationError
  7. from rest_framework.response import Response
  8. from desecapi import metrics
  9. from desecapi.authentication import (
  10. BasicTokenAuthentication,
  11. TokenAuthentication,
  12. URLParamAuthentication,
  13. )
  14. from desecapi.exceptions import ConcurrencyException
  15. from desecapi.models import Domain
  16. from desecapi.pdns_change_tracker import PDNSChangeTracker
  17. from desecapi.permissions import TokenHasDomainDynDNSPermission
  18. from desecapi.renderers import PlainTextRenderer
  19. from desecapi.serializers import RRsetSerializer
  20. class DynDNS12UpdateView(generics.GenericAPIView):
  21. authentication_classes = (
  22. TokenAuthentication,
  23. BasicTokenAuthentication,
  24. URLParamAuthentication,
  25. )
  26. permission_classes = (TokenHasDomainDynDNSPermission,)
  27. renderer_classes = [PlainTextRenderer]
  28. serializer_class = RRsetSerializer
  29. throttle_scope = "dyndns"
  30. @property
  31. def throttle_scope_bucket(self):
  32. return self.domain.name
  33. def _find_ip(self, params, separator):
  34. # Check URL parameters
  35. for p in params:
  36. try:
  37. param = self.request.query_params[p]
  38. except KeyError:
  39. continue
  40. if not len(param):
  41. return None
  42. if separator in param:
  43. return param
  44. # Check remote IP address
  45. client_ip = self.request.META.get("REMOTE_ADDR")
  46. if separator in client_ip:
  47. return client_ip
  48. # give up
  49. return None
  50. @cached_property
  51. def qname(self):
  52. # hostname parameter
  53. try:
  54. if self.request.query_params["hostname"] != "YES":
  55. return self.request.query_params["hostname"].lower()
  56. except KeyError:
  57. pass
  58. # host_id parameter
  59. try:
  60. return self.request.query_params["host_id"].lower()
  61. except KeyError:
  62. pass
  63. # http basic auth username
  64. try:
  65. domain_name = (
  66. base64.b64decode(
  67. get_authorization_header(self.request)
  68. .decode()
  69. .split(" ")[1]
  70. .encode()
  71. )
  72. .decode()
  73. .split(":")[0]
  74. )
  75. if domain_name and "@" not in domain_name:
  76. return domain_name.lower()
  77. except (binascii.Error, IndexError, UnicodeDecodeError):
  78. pass
  79. # username parameter
  80. try:
  81. return self.request.query_params["username"].lower()
  82. except KeyError:
  83. pass
  84. # only domain associated with this user account
  85. try:
  86. return self.request.user.domains.get().name
  87. except Domain.MultipleObjectsReturned:
  88. raise ValidationError(
  89. detail={
  90. "detail": "Request does not properly specify domain for update.",
  91. "code": "domain-unspecified",
  92. }
  93. )
  94. except Domain.DoesNotExist:
  95. metrics.get("desecapi_dynDNS12_domain_not_found").inc()
  96. raise NotFound("nohost")
  97. @cached_property
  98. def domain(self):
  99. try:
  100. return Domain.objects.filter_qname(
  101. self.qname, owner=self.request.user
  102. ).order_by("-name_length")[0]
  103. except (IndexError, ValueError):
  104. raise NotFound("nohost")
  105. @property
  106. def subname(self):
  107. return self.qname.rpartition(f".{self.domain.name}")[0]
  108. def get_serializer_context(self):
  109. return {
  110. **super().get_serializer_context(),
  111. "domain": self.domain,
  112. "minimum_ttl": 60,
  113. }
  114. def get_queryset(self):
  115. return self.domain.rrset_set.filter(
  116. subname=self.subname, type__in=["A", "AAAA"]
  117. )
  118. def get(self, request, *args, **kwargs):
  119. instances = self.get_queryset().all()
  120. ipv4 = self._find_ip(["myip", "myipv4", "ip"], separator=".")
  121. ipv6 = self._find_ip(["myipv6", "ipv6", "myip", "ip"], separator=":")
  122. data = [
  123. {
  124. "type": "A",
  125. "subname": self.subname,
  126. "ttl": 60,
  127. "records": [ipv4] if ipv4 else [],
  128. },
  129. {
  130. "type": "AAAA",
  131. "subname": self.subname,
  132. "ttl": 60,
  133. "records": [ipv6] if ipv6 else [],
  134. },
  135. ]
  136. serializer = self.get_serializer(instances, data=data, many=True, partial=True)
  137. try:
  138. serializer.is_valid(raise_exception=True)
  139. except ValidationError as e:
  140. if any(
  141. any(
  142. getattr(non_field_error, "code", "") == "unique"
  143. for non_field_error in err.get("non_field_errors", [])
  144. )
  145. for err in e.detail
  146. ):
  147. raise ConcurrencyException from e
  148. raise e
  149. with PDNSChangeTracker():
  150. serializer.save()
  151. return Response("good", content_type="text/plain")