Explorar o código

refactor(api): simplify record management

Peter Thomassen %!s(int64=7) %!d(string=hai) anos
pai
achega
fda7c07c60
Modificáronse 4 ficheiros con 59 adicións e 123 borrados
  1. 39 104
      api/desecapi/models.py
  2. 1 2
      api/desecapi/pdns.py
  3. 16 10
      api/desecapi/serializers.py
  4. 3 7
      api/desecapi/views.py

+ 39 - 104
api/desecapi/models.py

@@ -7,7 +7,6 @@ from desecapi import pdns, mixins
 import datetime, uuid
 import datetime, uuid
 from django.core.validators import MinValueValidator
 from django.core.validators import MinValueValidator
 from rest_framework.authtoken.models import Token
 from rest_framework.authtoken.models import Token
-from collections import Counter
 
 
 
 
 class MyUserManager(BaseUserManager):
 class MyUserManager(BaseUserManager):
@@ -103,8 +102,6 @@ class Domain(models.Model, mixins.SetterMixin):
     name = models.CharField(max_length=191, unique=True)
     name = models.CharField(max_length=191, unique=True)
     owner = models.ForeignKey(settings.AUTH_USER_MODEL, related_name='domains')
     owner = models.ForeignKey(settings.AUTH_USER_MODEL, related_name='domains')
     _dirtyName = False
     _dirtyName = False
-    _ns_records_data = [{'content': 'ns1.desec.io.'},
-                        {'content': 'ns2.desec.io.'}]
 
 
     def setter_name(self, val):
     def setter_name(self, val):
         if val != self.name:
         if val != self.name:
@@ -136,10 +133,6 @@ class Domain(models.Model, mixins.SetterMixin):
 
 
         return name
         return name
 
 
-    # When this is made a property, looping over Domain.rrsets breaks
-    def get_rrsets(self):
-        return RRset.objects.filter(domain=self)
-
     def _create_pdns_zone(self):
     def _create_pdns_zone(self):
         """
         """
         Create zone on pdns.  This will also import any RRsets that may have
         Create zone on pdns.  This will also import any RRsets that may have
@@ -150,7 +143,7 @@ class Domain(models.Model, mixins.SetterMixin):
         # Import RRsets that may have been created (e.g. during captcha lock).
         # Import RRsets that may have been created (e.g. during captcha lock).
         # Don't perform if we do not know of any RRsets (it would delete all
         # Don't perform if we do not know of any RRsets (it would delete all
         # existing records from pdns).
         # existing records from pdns).
-        rrsets = self.get_rrsets()
+        rrsets = self.rrset_set.all()
         if rrsets:
         if rrsets:
             pdns.set_rrsets(self, rrsets)
             pdns.set_rrsets(self, rrsets)
 
 
@@ -169,72 +162,35 @@ class Domain(models.Model, mixins.SetterMixin):
             if (e.status_code == 422 and e.detail.endswith(' already exists')):
             if (e.status_code == 422 and e.detail.endswith(' already exists')):
                 # Zone exists, purge it by deleting all RRsets and sync
                 # Zone exists, purge it by deleting all RRsets and sync
                 pdns.set_rrsets(self, [], notify=False)
                 pdns.set_rrsets(self, [], notify=False)
-                pdns.set_rrsets(self, self.get_rrsets())
+                pdns.set_rrsets(self, self.rrset_set.all())
             else:
             else:
                 raise e
                 raise e
 
 
     @transaction.atomic
     @transaction.atomic
     def sync_from_pdns(self):
     def sync_from_pdns(self):
-        RRset.objects.filter(domain=self).delete()
-        rrset_datas = [rrset_data for rrset_data in pdns.get_rrset_datas(self)
-                       if rrset_data['type'] not in RRset.RESTRICTED_TYPES]
-        # Can't do bulk create because we need records creation in RRset.save()
-        for rrset_data in rrset_datas:
-            RRset(**rrset_data).save(sync=False)
-
-    @transaction.atomic
-    def set_rrsets(self, rrsets):
-        """
-        Writes the provided RRsets to the database, overriding any existing
-        RRsets of the same subname and type.  If the user account is not locked
-        for captcha, also inform pdns about the new RRsets.
-        """
-        for rrset in rrsets:
-            if rrset.domain != self:
-                raise ValueError(
-                    'Cannot set RRset for domain %s on domain %s.' % (
-                    rrset.domain.name, self.name))
-            if rrset.type in RRset.RESTRICTED_TYPES:
-                raise ValueError(
-                    'You cannot tinker with the %s RRset.' % rrset.type)
-
-        pdns_rrsets = []
-        for rrset in rrsets:
-            # Look up old RRset to see if it needs updating.  If exists and
-            # outdated, delete it so that we can bulk-create it later.
-            try:
-                old_rrset = self.rrset_set.get(subname=rrset.subname,
-                                               type=rrset.type)
-                old_rrset.ttl = rrset.ttl
-                old_rrset.records_data = rrset.records_data
-                rrset = old_rrset
-            except RRset.DoesNotExist:
-                pass
-
-            # At this point, rrset is an RRset to be created or possibly to be
-            # updated.  RRset.save() will decide what to write to the database.
-            if rrset.pk is None or 'records' in rrset.get_dirties():
-                pdns_rrsets.append(rrset)
-
-            rrset.save(sync=False)
-
-        if not self.owner.captcha_required:
-            pdns.set_rrsets(self, pdns_rrsets)
+        self.rrset_set.all().delete()
+        for rrset_data in pdns.get_rrset_datas(self):
+            if rrset_data['type'] in RRset.RESTRICTED_TYPES:
+                continue
+            records = rrset_data.pop('records')
+            rrset = self.rrset_set.create(**rrset_data)
+            rrset.set_rrs(records, sync=False)
 
 
     @transaction.atomic
     @transaction.atomic
     def delete(self, *args, **kwargs):
     def delete(self, *args, **kwargs):
         # Delete delegation for dynDNS domains (direct child of dedyn.io)
         # Delete delegation for dynDNS domains (direct child of dedyn.io)
         subname, parent_pdns_id = self.pdns_id.split('.', 1)
         subname, parent_pdns_id = self.pdns_id.split('.', 1)
         if parent_pdns_id == 'dedyn.io.':
         if parent_pdns_id == 'dedyn.io.':
-            parent = Domain.objects.filter(name='dedyn.io').first()
-
-            if parent:
-                rrsets = RRset.objects.filter(domain=parent, subname=subname,
-                                              type__in=['NS', 'DS']).all()
+            try:
+                parent = Domain.objects.get(name='dedyn.io')
+            except Domain.DoesNotExist:
+                pass
+            else:
+                rrsets = parent.rrset_set.filter(subname=subname,
+                                                 type__in=['NS', 'DS']).all()
+                # Need to go RRset by RRset to trigger pdns sync
                 for rrset in rrsets:
                 for rrset in rrsets:
-                    rrset.records_data = []
-
-                parent.set_rrsets(rrsets)
+                    rrset.delete()
 
 
         # Delete domain
         # Delete domain
         super().delete(*args, **kwargs)
         super().delete(*args, **kwargs)
@@ -256,12 +212,15 @@ class Domain(models.Model, mixins.SetterMixin):
         # parent. Don't notify slaves (we first have to enable DNSSEC).
         # parent. Don't notify slaves (we first have to enable DNSSEC).
         subname, parent_pdns_id = self.pdns_id.split('.', 1)
         subname, parent_pdns_id = self.pdns_id.split('.', 1)
         if parent_pdns_id == 'dedyn.io.':
         if parent_pdns_id == 'dedyn.io.':
-            parent = Domain.objects.filter(name='dedyn.io').first()
-            if parent:
-                records_data = [{'content': x} for x in settings.DEFAULT_NS]
-                rrset = RRset(domain=parent, subname=subname, type='NS',
-                              ttl=60, records_data=records_data)
-                rrset.save(notify=False)
+            try:
+                parent = Domain.objects.get(name='dedyn.io')
+            except Domain.DoesNotExist:
+                return
+
+            with transaction.atomic():
+                rrset = parent.rrset_set.create(subname=subname, type='NS',
+                                                ttl=60)
+                rrset.set_rrs(settings.DEFAULT_NS, notify=False)
 
 
     def __str__(self):
     def __str__(self):
         """
         """
@@ -329,8 +288,7 @@ class RRset(models.Model, mixins.SetterMixin):
     class Meta:
     class Meta:
         unique_together = (("domain","subname","type"),)
         unique_together = (("domain","subname","type"),)
 
 
-    def __init__(self, *args, records_data=None, **kwargs):
-        self.records_data = records_data
+    def __init__(self, *args, **kwargs):
         self._dirties = set()
         self._dirties = set()
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
 
 
@@ -372,55 +330,32 @@ class RRset(models.Model, mixins.SetterMixin):
             raise ValidationError(errors)
             raise ValidationError(errors)
 
 
     def get_dirties(self):
     def get_dirties(self):
-        if self.records_data is not None and 'records' not in self._dirties \
-            and (self.pk is None
-                or Counter([x['content'] for x in self.records_data])
-                    != Counter(self.records.values_list('content', flat=True))
-                ):
-            self._dirties.add('records')
-
         return self._dirties
         return self._dirties
 
 
     @property
     @property
     def name(self):
     def name(self):
         return '.'.join(filter(None, [self.subname, self.domain.name])) + '.'
         return '.'.join(filter(None, [self.subname, self.domain.name])) + '.'
 
 
+    @transaction.atomic
+    def set_rrs(self, contents, sync=True, notify=True):
+        self.records.all().delete()
+        self.records.set([RR(content=x) for x in contents], bulk=False)
+        if sync and not self.domain.owner.captcha_required:
+            pdns.set_rrset(self, notify=notify)
+
     @transaction.atomic
     @transaction.atomic
     def delete(self, *args, **kwargs):
     def delete(self, *args, **kwargs):
         super().delete(*args, **kwargs)
         super().delete(*args, **kwargs)
         pdns.set_rrset(self)
         pdns.set_rrset(self)
-        self.records_data = None
         self._dirties = {}
         self._dirties = {}
 
 
-    @transaction.atomic
-    def save(self, sync=True, notify=True, *args, **kwargs):
-        new = self.pk is None
-
-        # Empty records data means deletion
-        if self.records_data == []:
-            if not new:
-                self.delete()
-            return
-
-        # The only thing that can change is the TTL
-        if new or 'ttl' in self.get_dirties():
+    def save(self, *args, **kwargs):
+        # If not new, the only thing that can change is the TTL
+        if self.created is None or 'ttl' in self.get_dirties():
             self.updated = timezone.now()
             self.updated = timezone.now()
             self.full_clean()
             self.full_clean()
             super().save(*args, **kwargs)
             super().save(*args, **kwargs)
-
-        # Create RRset contents
-        if 'records' in self.get_dirties():
-            self.records.all().delete()
-            records = [RR(rrset=self, **data) for data in self.records_data]
-            self.records.bulk_create(records)
-            self.records_data = None
-
-        # Sync to pdns if new or anything is dirty
-        if sync and not self.domain.owner.captcha_required \
-                and (new or self.get_dirties()):
-            pdns.set_rrset(self, notify=notify)
-
-        self._dirties = {}
+            self._dirties = {}
 
 
 
 
 class RR(models.Model):
 class RR(models.Model):

+ 1 - 2
api/desecapi/pdns.py

@@ -119,8 +119,7 @@ def get_rrset_datas(domain):
     return [{'domain': domain,
     return [{'domain': domain,
              'subname': rrset['name'][:-(len(domain.name) + 2)],
              'subname': rrset['name'][:-(len(domain.name) + 2)],
              'type': rrset['type'],
              'type': rrset['type'],
-             'records_data': [{'content': record['content']}
-                              for record in rrset['records']],
+             'records': [record['content'] for record in rrset['records']],
              'ttl': rrset['ttl']}
              'ttl': rrset['ttl']}
             for rrset in get_zone(domain)['rrsets']]
             for rrset in get_zone(domain)['rrsets']]
 
 

+ 16 - 10
api/desecapi/serializers.py

@@ -1,6 +1,7 @@
 from rest_framework import serializers
 from rest_framework import serializers
 from desecapi.models import Domain, Donation, User, RR, RRset
 from desecapi.models import Domain, Donation, User, RR, RRset
 from djoser import serializers as djoserSerializers
 from djoser import serializers as djoserSerializers
+from django.db import transaction
 
 
 
 
 class RRSerializer(serializers.ModelSerializer):
 class RRSerializer(serializers.ModelSerializer):
@@ -20,25 +21,30 @@ class RRsetSerializer(serializers.ModelSerializer):
         model = RRset
         model = RRset
         fields = ('domain', 'subname', 'name', 'records', 'ttl', 'type',)
         fields = ('domain', 'subname', 'name', 'records', 'ttl', 'type',)
 
 
-    def _inject_records_data(self, validated_data):
+    def _set_records(self, instance):
         records_data = [{'content': x}
         records_data = [{'content': x}
                         for x in self.context['request'].data['records']]
                         for x in self.context['request'].data['records']]
-        rrs = RRSerializer(data=records_data, many=True, allow_empty=False)
-        if not rrs.is_valid():
-            errors = rrs.errors
+        rr_serializer = RRSerializer(data=records_data, many=True,
+                                     allow_empty=False)
+        if not rr_serializer.is_valid():
+            errors = rr_serializer.errors
             if 'non_field_errors' in errors:
             if 'non_field_errors' in errors:
                 errors['records'] = errors.pop('non_field_errors')
                 errors['records'] = errors.pop('non_field_errors')
             raise serializers.ValidationError(errors)
             raise serializers.ValidationError(errors)
+        instance.set_rrs([x['content'] for x in rr_serializer.validated_data])
 
 
-        return {'records_data': rrs.validated_data, **validated_data}
-
+    @transaction.atomic
     def create(self, validated_data):
     def create(self, validated_data):
-        validated_data = self._inject_records_data(validated_data)
-        return super().create(validated_data)
+        instance = super().create(validated_data)
+        self._set_records(instance)
+        return instance
 
 
+    @transaction.atomic
     def update(self, instance, validated_data):
     def update(self, instance, validated_data):
-        validated_data = self._inject_records_data(validated_data)
-        return super().update(instance, validated_data)
+        instance = super().update(instance, validated_data)
+        instance.records.all().delete()
+        self._set_records(instance)
+        return instance
 
 
     def get_records(self, obj):
     def get_records(self, obj):
         return list(obj.records.values_list('content', flat=True))
         return list(obj.records.values_list('content', flat=True))

+ 3 - 7
api/desecapi/views.py

@@ -366,16 +366,12 @@ class DynDNS12Update(APIView):
         if domain is None:
         if domain is None:
             raise Http404
             raise Http404
 
 
-        rrsets = []
         datas = {'A': self.findIPv4(request), 'AAAA': self.findIPv6(request)}
         datas = {'A': self.findIPv4(request), 'AAAA': self.findIPv6(request)}
 
 
         for type_, ip in datas.items():
         for type_, ip in datas.items():
-            records_data = [{'content': ip}] if ip is not None else []
-            rrset = RRset(domain=domain, subname='', ttl=60, type=type_,
-                          records_data=records_data)
-            rrsets.append(rrset)
-
-        domain.set_rrsets(rrsets)
+            rrset, _ = domain.rrset_set.update_or_create(subname='', type=type_,
+                                                         defaults={'ttl': 60})
+            rrset.set_rrs([ip] if ip is not None else [])
 
 
         return Response('good')
         return Response('good')