Просмотр исходного кода

fix(api): wrap Domain database writes in transaction

- Affects Domain.save(), Domain.delete()
- We start a transaction and issue SQL to write the Django Domain
  object to the database
  - If successful, we perform the necessary pdns operations
  - If successful, commit the transaction, if not, error out
- Test case included
Peter Thomassen 8 лет назад
Родитель
Сommit
de680aba73
2 измененных файлов с 51 добавлено и 16 удалено
  1. 36 16
      api/desecapi/models.py
  2. 15 0
      api/desecapi/tests/testdomains.py

+ 36 - 16
api/desecapi/models.py

@@ -1,5 +1,5 @@
 from django.conf import settings
-from django.db import models
+from django.db import models, transaction
 from django.contrib.auth.models import (
     BaseUserManager, AbstractBaseUser
 )
@@ -98,6 +98,26 @@ class Domain(models.Model):
     arecord = models.CharField(max_length=255, blank=True)
     aaaarecord = models.CharField(max_length=1024, blank=True)
     owner = models.ForeignKey(settings.AUTH_USER_MODEL, related_name='domains')
+    _dirtyRecords = False
+
+    def __setattr__(self, attrname, val):
+        setter_func = 'setter_' + attrname
+        if attrname in self.__dict__ and callable(getattr(self, setter_func, None)):
+            super(Domain, self).__setattr__(attrname, getattr(self, setter_func)(val))
+        else:
+            super(Domain, self).__setattr__(attrname, val)
+
+    def setter_arecord(self, val):
+        if val != self.arecord:
+            self._dirtyRecords = True
+
+        return val
+
+    def setter_aaaarecord(self, val):
+        if val != self.aaaarecord:
+            self._dirtyRecords = True
+
+        return val
 
     def pdns_resync(self):
         """
@@ -112,7 +132,7 @@ class Domain(models.Model):
         # update zone to latest information
         pdns.set_dyn_records(self.name, self.arecord, self.aaaarecord)
 
-    def pdns_sync(self):
+    def pdns_sync(self, new_domain):
         """
         Command pdns updates as indicated by the local changes.
         """
@@ -121,36 +141,36 @@ class Domain(models.Model):
             # suspend all updates
             return
 
-        new_domain = self.id is None
-        changes_required = False
-
-        # if this zone is new, create it
+        # if this zone is new, create it and set dirty flag if necessary
         if new_domain:
             pdns.create_zone(self.name)
-
-        # check if current A and AAAA record values require updating pdns
-        if new_domain:
-            changes_required = bool(self.arecord) or bool(self.aaaarecord)
-        else:
-            orig_domain = Domain.objects.get(id=self.id)
-            changes_required = self.arecord != orig_domain.arecord or self.aaaarecord != orig_domain.aaaarecord
+            self._dirtyRecords = bool(self.arecord) or bool(self.aaaarecord)
 
         # make changes if necessary
-        if changes_required:
+        if self._dirtyRecords:
             pdns.set_dyn_records(self.name, self.arecord, self.aaaarecord)
 
+        self._dirtyRecords = False
+
+    @transaction.atomic
     def delete(self, *args, **kwargs):
+        super(Domain, self).delete(*args, **kwargs)
+
         pdns.delete_zone(self.name)
         if self.name.endswith('.dedyn.io'):
             pdns.set_rrset('dedyn.io', self.name, 'DS', '')
             pdns.set_rrset('dedyn.io', self.name, 'NS', '')
-        super(Domain, self).delete(*args, **kwargs)
 
+    @transaction.atomic
     def save(self, *args, **kwargs):
+        # Record here if this is a new domain (self.pk is only None until we call super.save())
+        new_domain = self.pk is None
+
         self.updated = timezone.now()
-        self.pdns_sync()
         super(Domain, self).save(*args, **kwargs)
 
+        self.pdns_sync(new_domain)
+
     class Meta:
         ordering = ('created',)
 

+ 15 - 0
api/desecapi/tests/testdomains.py

@@ -194,6 +194,21 @@ class AuthenticatedDomainTests(APITestCase):
         self.assertTrue(("/%d" % self.ownedDomains[1].pk) in url)
         self.assertTrue("/" + self.ownedDomains[1].name in urlByName)
 
+    def testRollback(self):
+        name = utils.generateDomainname()
+
+        httpretty.enable()
+        httpretty.register_uri(httpretty.POST, settings.NSLORD_PDNS_API + '/zones', body="some error", status=500)
+
+        url = reverse('domain-list')
+        data = {'name': name}
+        try:
+            response = self.client.post(url, data)
+        except:
+            pass
+
+        self.assertFalse(Domain.objects.filter(name=name).exists())
+
 
 class AuthenticatedDynDomainTests(APITestCase):
     def setUp(self):