Browse Source

refactor(api): streamline functionality and use of DomainSerializer

Peter Thomassen 4 years ago
parent
commit
e9c45062dc
2 changed files with 16 additions and 13 deletions
  1. 4 1
      api/desecapi/models.py
  2. 12 12
      api/desecapi/serializers.py

+ 4 - 1
api/desecapi/models.py

@@ -18,7 +18,7 @@ import psl_dns
 import rest_framework.authtoken.models
 from django.conf import settings
 from django.contrib.auth.hashers import make_password
-from django.contrib.auth.models import BaseUserManager, AbstractBaseUser
+from django.contrib.auth.models import AbstractBaseUser, AnonymousUser, BaseUserManager
 from django.contrib.postgres.constraints import ExclusionConstraint
 from django.contrib.postgres.fields import ArrayField, CIEmailField, RangeOperators
 from django.core.exceptions import ValidationError
@@ -225,6 +225,9 @@ class Domain(ExportModelOperationsMixin('Domain'), models.Model):
     _keys = None
 
     def __init__(self, *args, **kwargs):
+        if isinstance(kwargs.get('owner'), AnonymousUser):
+            kwargs = {**kwargs, 'owner': None}  # make a copy and override
+        # Avoid super().__init__(owner=None, ...) to not mess up *values instantiation in django.db.models.Model.from_db
         super().__init__(*args, **kwargs)
         if self.pk is None and kwargs.get('renewal_state') is None and self.is_locally_registrable:
             self.renewal_state = Domain.RenewalState.FRESH

+ 12 - 12
api/desecapi/serializers.py

@@ -7,7 +7,6 @@ from base64 import urlsafe_b64decode, urlsafe_b64encode, b64encode
 import django.core.exceptions
 from captcha.audio import AudioCaptcha
 from captcha.image import ImageCaptcha
-from django.contrib.auth.models import AnonymousUser
 from django.contrib.auth.password_validation import validate_password
 from django.core.validators import MinValueValidator
 from django.db.models import Model, Q
@@ -558,6 +557,10 @@ class RRsetListSerializer(serializers.ListSerializer):
 
 
 class DomainSerializer(serializers.ModelSerializer):
+    default_error_messages = {
+        **serializers.Serializer.default_error_messages,
+        'name_unavailable': 'This domain name conflicts with an existing zone, or is disallowed by policy.',
+    }
 
     class Meta:
         model = models.Domain
@@ -579,18 +582,10 @@ class DomainSerializer(serializers.ModelSerializer):
         return fields
 
     def validate_name(self, value):
-        self.raise_if_domain_unavailable(value, self.context['request'].user)
+        if not models.Domain(name=value, owner=self.context['request'].user).is_registrable():
+            raise serializers.ValidationError(self.default_error_messages['name_unavailable'], code='name_unavailable')
         return value
 
-    @staticmethod
-    def raise_if_domain_unavailable(domain_name: str, user: models.User):
-        user = user if not isinstance(user, AnonymousUser) else None
-        if not models.Domain(name=domain_name, owner=user).is_registrable():
-            raise serializers.ValidationError(
-                'This domain name conflicts with an existing zone, or is disallowed by policy.',
-                code='name_unavailable'
-            )
-
     def create(self, validated_data):
         if 'minimum_ttl' not in validated_data and models.Domain(name=validated_data['name']).is_locally_registrable:
             validated_data.update(minimum_ttl=60)
@@ -645,7 +640,12 @@ class RegisterAccountSerializer(UserSerializer):
         extra_kwargs = UserSerializer.Meta.extra_kwargs
 
     def validate_domain(self, value):
-        DomainSerializer.raise_if_domain_unavailable(value, self.context['request'].user)
+        serializer = DomainSerializer(data=dict(name=value), context=self.context)
+        try:
+            serializer.is_valid(raise_exception=True)
+        except serializers.ValidationError:
+            raise serializers.ValidationError(serializer.default_error_messages['name_unavailable'],
+                                              code='name_unavailable')
         return value
 
     def create(self, validated_data):