Browse Source

refactor(api): streamline functionality and use of DomainSerializer

Peter Thomassen 4 years ago
parent
commit
e9c45062dc
2 changed files with 16 additions and 13 deletions
  1. 4 1
      api/desecapi/models.py
  2. 12 12
      api/desecapi/serializers.py

+ 4 - 1
api/desecapi/models.py

@@ -18,7 +18,7 @@ import psl_dns
 import rest_framework.authtoken.models
 import rest_framework.authtoken.models
 from django.conf import settings
 from django.conf import settings
 from django.contrib.auth.hashers import make_password
 from django.contrib.auth.hashers import make_password
-from django.contrib.auth.models import BaseUserManager, AbstractBaseUser
+from django.contrib.auth.models import AbstractBaseUser, AnonymousUser, BaseUserManager
 from django.contrib.postgres.constraints import ExclusionConstraint
 from django.contrib.postgres.constraints import ExclusionConstraint
 from django.contrib.postgres.fields import ArrayField, CIEmailField, RangeOperators
 from django.contrib.postgres.fields import ArrayField, CIEmailField, RangeOperators
 from django.core.exceptions import ValidationError
 from django.core.exceptions import ValidationError
@@ -225,6 +225,9 @@ class Domain(ExportModelOperationsMixin('Domain'), models.Model):
     _keys = None
     _keys = None
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
+        if isinstance(kwargs.get('owner'), AnonymousUser):
+            kwargs = {**kwargs, 'owner': None}  # make a copy and override
+        # Avoid super().__init__(owner=None, ...) to not mess up *values instantiation in django.db.models.Model.from_db
         super().__init__(*args, **kwargs)
         super().__init__(*args, **kwargs)
         if self.pk is None and kwargs.get('renewal_state') is None and self.is_locally_registrable:
         if self.pk is None and kwargs.get('renewal_state') is None and self.is_locally_registrable:
             self.renewal_state = Domain.RenewalState.FRESH
             self.renewal_state = Domain.RenewalState.FRESH

+ 12 - 12
api/desecapi/serializers.py

@@ -7,7 +7,6 @@ from base64 import urlsafe_b64decode, urlsafe_b64encode, b64encode
 import django.core.exceptions
 import django.core.exceptions
 from captcha.audio import AudioCaptcha
 from captcha.audio import AudioCaptcha
 from captcha.image import ImageCaptcha
 from captcha.image import ImageCaptcha
-from django.contrib.auth.models import AnonymousUser
 from django.contrib.auth.password_validation import validate_password
 from django.contrib.auth.password_validation import validate_password
 from django.core.validators import MinValueValidator
 from django.core.validators import MinValueValidator
 from django.db.models import Model, Q
 from django.db.models import Model, Q
@@ -558,6 +557,10 @@ class RRsetListSerializer(serializers.ListSerializer):
 
 
 
 
 class DomainSerializer(serializers.ModelSerializer):
 class DomainSerializer(serializers.ModelSerializer):
+    default_error_messages = {
+        **serializers.Serializer.default_error_messages,
+        'name_unavailable': 'This domain name conflicts with an existing zone, or is disallowed by policy.',
+    }
 
 
     class Meta:
     class Meta:
         model = models.Domain
         model = models.Domain
@@ -579,18 +582,10 @@ class DomainSerializer(serializers.ModelSerializer):
         return fields
         return fields
 
 
     def validate_name(self, value):
     def validate_name(self, value):
-        self.raise_if_domain_unavailable(value, self.context['request'].user)
+        if not models.Domain(name=value, owner=self.context['request'].user).is_registrable():
+            raise serializers.ValidationError(self.default_error_messages['name_unavailable'], code='name_unavailable')
         return value
         return value
 
 
-    @staticmethod
-    def raise_if_domain_unavailable(domain_name: str, user: models.User):
-        user = user if not isinstance(user, AnonymousUser) else None
-        if not models.Domain(name=domain_name, owner=user).is_registrable():
-            raise serializers.ValidationError(
-                'This domain name conflicts with an existing zone, or is disallowed by policy.',
-                code='name_unavailable'
-            )
-
     def create(self, validated_data):
     def create(self, validated_data):
         if 'minimum_ttl' not in validated_data and models.Domain(name=validated_data['name']).is_locally_registrable:
         if 'minimum_ttl' not in validated_data and models.Domain(name=validated_data['name']).is_locally_registrable:
             validated_data.update(minimum_ttl=60)
             validated_data.update(minimum_ttl=60)
@@ -645,7 +640,12 @@ class RegisterAccountSerializer(UserSerializer):
         extra_kwargs = UserSerializer.Meta.extra_kwargs
         extra_kwargs = UserSerializer.Meta.extra_kwargs
 
 
     def validate_domain(self, value):
     def validate_domain(self, value):
-        DomainSerializer.raise_if_domain_unavailable(value, self.context['request'].user)
+        serializer = DomainSerializer(data=dict(name=value), context=self.context)
+        try:
+            serializer.is_valid(raise_exception=True)
+        except serializers.ValidationError:
+            raise serializers.ValidationError(serializer.default_error_messages['name_unavailable'],
+                                              code='name_unavailable')
         return value
         return value
 
 
     def create(self, validated_data):
     def create(self, validated_data):