123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143 |
- import base64
- import binascii
- from functools import cached_property
- from rest_framework import generics
- from rest_framework.authentication import get_authorization_header
- from rest_framework.exceptions import NotFound, ValidationError
- from rest_framework.response import Response
- from desecapi import metrics
- from desecapi.authentication import BasicTokenAuthentication, TokenAuthentication, URLParamAuthentication
- from desecapi.exceptions import ConcurrencyException
- from desecapi.models import Domain
- from desecapi.pdns_change_tracker import PDNSChangeTracker
- from desecapi.permissions import TokenHasDomainDynDNSPermission
- from desecapi.renderers import PlainTextRenderer
- from desecapi.serializers import RRsetSerializer
- class DynDNS12UpdateView(generics.GenericAPIView):
- authentication_classes = (TokenAuthentication, BasicTokenAuthentication, URLParamAuthentication,)
- permission_classes = (TokenHasDomainDynDNSPermission,)
- renderer_classes = [PlainTextRenderer]
- serializer_class = RRsetSerializer
- throttle_scope = 'dyndns'
- @property
- def throttle_scope_bucket(self):
- return self.domain.name
- def _find_ip(self, params, version):
- if version == 4:
- look_for = '.'
- elif version == 6:
- look_for = ':'
- else:
- raise Exception
- # Check URL parameters
- for p in params:
- if p in self.request.query_params:
- if not len(self.request.query_params[p]):
- return None
- if look_for in self.request.query_params[p]:
- return self.request.query_params[p]
- # Check remote IP address
- client_ip = self.request.META.get('REMOTE_ADDR')
- if look_for in client_ip:
- return client_ip
- # give up
- return None
- @cached_property
- def qname(self):
- # hostname parameter
- try:
- if self.request.query_params['hostname'] != 'YES':
- return self.request.query_params['hostname'].lower()
- except KeyError:
- pass
- # host_id parameter
- try:
- return self.request.query_params['host_id'].lower()
- except KeyError:
- pass
- # http basic auth username
- try:
- domain_name = base64.b64decode(
- get_authorization_header(self.request).decode().split(' ')[1].encode()).decode().split(':')[0]
- if domain_name and '@' not in domain_name:
- return domain_name.lower()
- except (binascii.Error, IndexError, UnicodeDecodeError):
- pass
- # username parameter
- try:
- return self.request.query_params['username'].lower()
- except KeyError:
- pass
- # only domain associated with this user account
- try:
- return self.request.user.domains.get().name
- except Domain.MultipleObjectsReturned:
- raise ValidationError(detail={
- "detail": "Request does not properly specify domain for update.",
- "code": "domain-unspecified"
- })
- except Domain.DoesNotExist:
- metrics.get('desecapi_dynDNS12_domain_not_found').inc()
- raise NotFound('nohost')
- @cached_property
- def domain(self):
- try:
- return Domain.objects.filter_qname(self.qname, owner=self.request.user).order_by('-name_length')[0]
- except (IndexError, ValueError):
- raise NotFound('nohost')
- @property
- def subname(self):
- return self.qname.rpartition(f'.{self.domain.name}')[0]
- def get_serializer_context(self):
- return {**super().get_serializer_context(), 'domain': self.domain, 'minimum_ttl': 60}
- def get_queryset(self):
- return self.domain.rrset_set.filter(subname=self.subname, type__in=['A', 'AAAA'])
- def get(self, request, *args, **kwargs):
- instances = self.get_queryset().all()
- ipv4 = self._find_ip(['myip', 'myipv4', 'ip'], version=4)
- ipv6 = self._find_ip(['myipv6', 'ipv6', 'myip', 'ip'], version=6)
- data = [
- {'type': 'A', 'subname': self.subname, 'ttl': 60, 'records': [ipv4] if ipv4 else []},
- {'type': 'AAAA', 'subname': self.subname, 'ttl': 60, 'records': [ipv6] if ipv6 else []},
- ]
- serializer = self.get_serializer(instances, data=data, many=True, partial=True)
- try:
- serializer.is_valid(raise_exception=True)
- except ValidationError as e:
- if any(
- any(
- getattr(non_field_error, 'code', '') == 'unique'
- for non_field_error
- in err.get('non_field_errors', [])
- )
- for err in e.detail
- ):
- raise ConcurrencyException from e
- raise e
- with PDNSChangeTracker():
- serializer.save()
- return Response('good', content_type='text/plain')
|