dyndns.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184
  1. import base64
  2. import binascii
  3. from functools import cached_property
  4. from rest_framework import generics
  5. from rest_framework.authentication import get_authorization_header
  6. from rest_framework.exceptions import NotFound, ValidationError
  7. from rest_framework.response import Response
  8. from rest_framework.settings import api_settings
  9. from desecapi import metrics
  10. from desecapi.authentication import (
  11. BasicTokenAuthentication,
  12. TokenAuthentication,
  13. URLParamAuthentication,
  14. )
  15. from desecapi.exceptions import ConcurrencyException
  16. from desecapi.models import Domain
  17. from desecapi.pdns_change_tracker import PDNSChangeTracker
  18. from desecapi.permissions import TokenHasDomainDynDNSPermission
  19. from desecapi.renderers import PlainTextRenderer
  20. from desecapi.serializers import RRsetSerializer
  class DynDNS12UpdateView(generics.GenericAPIView):
      """Dynamic DNS update endpoint handling GET requests.

      Determines the target name (``qname``) and the new A/AAAA addresses from
      the request, writes them through ``RRsetSerializer`` with a TTL of 60, and
      answers with the plain-text body ``good``.
      """

      authentication_classes = (
          TokenAuthentication,
          BasicTokenAuthentication,
          URLParamAuthentication,
      )
      permission_classes = (TokenHasDomainDynDNSPermission,)
      renderer_classes = [PlainTextRenderer]
      serializer_class = RRsetSerializer
      throttle_scope = "dyndns"

      @property
      def throttle_scope_bucket(self):
          """Name of the target domain (may raise NotFound).

          Presumably consumed by the project's throttle class to rate-limit per
          domain within the "dyndns" scope -- confirm against the throttle code.
          """
          return self.domain.name

      def _find_ip(self, param_keys, separator):
          """Determine the IP addresses to set for one record type.

          :param param_keys: query parameter names to try, in priority order
          :param separator: character identifying the address family
              (``"."`` for IPv4, ``":"`` for IPv6)
          :return: list of address strings; ``["preserve"]`` if the existing
              records are to be kept; ``[]`` if the parameter was given empty or
              no usable address was found
          :raises ValidationError: if one parameter mixes addresses, ``preserve``
              and/or an empty value
          """
          # Check URL parameters; the first key yielding a matching value wins.
          for param_key in param_keys:
              try:
                  params = {
                      param.strip()
                      for param in self.request.query_params[param_key].split(",")
                      if separator in param or param.strip() in ("", "preserve")
                  }
              except KeyError:
                  continue
              if len(params) > 1 and params & {"", "preserve"}:
                  raise ValidationError(
                      detail={
                          "detail": f'IP parameter "{param_key}" cannot have addresses and "preserve" at the same time.',
                          "code": "inconsistent-parameter",
                      }
                  )
              if params:
                  return [] if "" in params else list(params)

          # Fall back to the client's remote address. REMOTE_ADDR may be missing
          # (None); treat that as "no address" instead of failing with TypeError.
          client_ip = self.request.META.get("REMOTE_ADDR") or ""
          if separator in client_ip:
              return [client_ip]

          # give up
          return []

      def _basic_auth_username(self):
          """Return the username from an HTTP Basic ``Authorization`` header, or None.

          Only the ``Basic`` scheme (case-insensitive) is considered. Credentials
          of other schemes such as ``Token`` are not base64 ``user:password`` pairs,
          and decoding them could yield a bogus name.
          """
          try:
              auth = get_authorization_header(self.request).decode().split()
              if len(auth) < 2 or auth[0].lower() != "basic":
                  return None
              return base64.b64decode(auth[1].encode()).decode().split(":")[0]
          except (binascii.Error, UnicodeDecodeError):
              return None

      @cached_property
      def qname(self):
          """The DNS name to update.

          Sources, in order: the ``hostname`` parameter (unless it is ``"YES"``),
          ``host_id``, the HTTP Basic auth username (if non-empty and free of
          ``@``), ``username``, and finally the single domain owned by the
          requesting user. Names taken from the request are lowercased.

          :raises ValidationError: if nothing specifies the name and the user owns
              more than one domain
          :raises NotFound: ("nohost") if nothing specifies the name and the user
              owns no domain
          """
          # hostname parameter
          try:
              if self.request.query_params["hostname"] != "YES":
                  return self.request.query_params["hostname"].lower()
          except KeyError:
              pass

          # host_id parameter
          try:
              return self.request.query_params["host_id"].lower()
          except KeyError:
              pass

          # http basic auth username; names containing "@" are skipped
          # (presumably an account identifier rather than a domain -- confirm)
          domain_name = self._basic_auth_username()
          if domain_name and "@" not in domain_name:
              return domain_name.lower()

          # username parameter
          try:
              return self.request.query_params["username"].lower()
          except KeyError:
              pass

          # only domain associated with this user account
          try:
              return self.request.user.domains.get().name
          except Domain.MultipleObjectsReturned:
              raise ValidationError(
                  detail={
                      "detail": "Request does not properly specify domain for update.",
                      "code": "domain-unspecified",
                  }
              )
          except Domain.DoesNotExist:
              metrics.get("desecapi_dynDNS12_domain_not_found").inc()
              raise NotFound("nohost")

      @cached_property
      def domain(self):
          """The requesting user's domain with the longest name matching ``qname``.

          Matching is delegated to ``Domain.objects.filter_qname`` (presumably
          the domains that ``qname`` falls under -- confirm in the manager).

          :raises NotFound: ("nohost") if no domain matches or the lookup raises
              ValueError
          """
          try:
              return Domain.objects.filter_qname(
                  self.qname, owner=self.request.user
              ).order_by("-name_length")[0]
          except (IndexError, ValueError):
              raise NotFound("nohost")

      @property
      def subname(self):
          """Part of ``qname`` in front of ``.<domain name>``; "" at the apex."""
          return self.qname.rpartition(f".{self.domain.name}")[0]

      def get_serializer_context(self):
          """Extend the default context with the target domain and a minimum TTL of 60."""
          return {
              **super().get_serializer_context(),
              "domain": self.domain,
              "minimum_ttl": 60,
          }

      def get_queryset(self):
          """Existing A and AAAA RRsets of the target domain at ``subname``."""
          return self.domain.rrset_set.filter(
              subname=self.subname, type__in=["A", "AAAA"]
          )

      def get(self, request, *args, **kwargs):
          """Write the requested A/AAAA records (TTL 60) and respond with ``good``.

          A record type whose IP parameter is ``preserve`` is left out of the
          update. Validation errors are re-raised unchanged, except that a
          ``unique`` non-field error becomes a ``ConcurrencyException``
          (presumably another request created the same RRset concurrently).
          """
          instances = self.get_queryset().all()
          record_params = {
              "A": self._find_ip(["myip", "myipv4", "ip"], separator="."),
              "AAAA": self._find_ip(["myipv6", "ipv6", "myip", "ip"], separator=":"),
          }
          data = [
              {
                  "type": type_,
                  "subname": self.subname,
                  "ttl": 60,
                  "records": ip_params,
              }
              for type_, ip_params in record_params.items()
              if "preserve" not in ip_params
          ]
          serializer = self.get_serializer(instances, data=data, many=True, partial=True)
          try:
              serializer.is_valid(raise_exception=True)
          except ValidationError as e:
              # Per-item errors arrive as a list of dicts (one per item), but a
              # list-level error can arrive as a single dict. Iterating that dict
              # would yield its string keys and crash on `.get`, so normalize first.
              errors = e.detail if isinstance(e.detail, list) else [e.detail]
              if any(
                  getattr(non_field_error, "code", "") == "unique"
                  for err in errors
                  if isinstance(err, dict)
                  for non_field_error in err.get(api_settings.NON_FIELD_ERRORS_KEY, [])
              ):
                  raise ConcurrencyException from e
              raise
          # NOTE(review): presumably the tracker propagates the saved changes to
          # the DNS backend when the block exits -- confirm in pdns_change_tracker.
          with PDNSChangeTracker():
              serializer.save()
          return Response("good", content_type="text/plain")