Browse Source

feat(api): include subname/type in Token default policies

Peter Thomassen 1 year ago
parent
commit
6ff9220adf

+ 58 - 0
api/desecapi/migrations/0036_remove_tokendomainpolicy_default_policy_on_insert_and_more.py

@@ -0,0 +1,58 @@
+# Generated by Django 5.0rc1 on 2023-11-29 18:13
+
+import pgtrigger.compiler
+import pgtrigger.migrations
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("desecapi", "0035_rename_perm_rrsets_tokendomainpolicy_write"),
+    ]
+
+    operations = [
+        pgtrigger.migrations.RemoveTrigger(
+            model_name="tokendomainpolicy",
+            name="default_policy_on_insert",
+        ),
+        pgtrigger.migrations.RemoveTrigger(
+            model_name="tokendomainpolicy",
+            name="default_policy_on_update",
+        ),
+        pgtrigger.migrations.RemoveTrigger(
+            model_name="tokendomainpolicy",
+            name="default_policy_on_delete",
+        ),
+        migrations.RemoveConstraint(
+            model_name="tokendomainpolicy",
+            name="unique_entry",
+        ),
+        migrations.RemoveConstraint(
+            model_name="tokendomainpolicy",
+            name="unique_entry_null_domain",
+        ),
+        pgtrigger.migrations.AddTrigger(
+            model_name="tokendomainpolicy",
+            trigger=pgtrigger.compiler.Trigger(
+                name="default_policy_primacy",
+                sql=pgtrigger.compiler.UpsertTriggerSql(
+                    constraint="CONSTRAINT",
+                    func="\n                    IF\n                        EXISTS(SELECT * FROM desecapi_tokendomainpolicy WHERE token_id = COALESCE(NEW.token_id, OLD.token_id)) AND NOT EXISTS(\n                            SELECT * FROM desecapi_tokendomainpolicy WHERE token_id = COALESCE(NEW.token_id, OLD.token_id) AND domain_id IS NULL AND subname IS NULL AND type IS NULL\n                        )\n                    THEN\n                        RAISE EXCEPTION 'Token policies without a default policy are not allowed.';\n                    END IF;\n                    RETURN NULL;\n                ",
+                    hash="5bac9e99f4e2fe1ba1cb410fb5ffd74a2e023a46",
+                    operation="INSERT OR UPDATE OR DELETE",
+                    pgid="pgtrigger_default_policy_primacy_9f1b1",
+                    table="desecapi_tokendomainpolicy",
+                    timing="DEFERRABLE INITIALLY DEFERRED",
+                    when="AFTER",
+                ),
+            ),
+        ),
+        migrations.AddConstraint(
+            model_name="tokendomainpolicy",
+            constraint=models.UniqueConstraint(
+                fields=("token", "domain", "subname", "type"),
+                name="unique_policy",
+                nulls_distinct=False,
+            ),
+        ),
+    ]

+ 75 - 69
api/desecapi/models/tokens.py

@@ -11,7 +11,7 @@ from django.contrib.auth.hashers import make_password
 from django.contrib.postgres.fields import ArrayField
 from django.core import validators
 from django.core.exceptions import ValidationError
-from django.db import models, transaction
+from django.db import models
 from django.db.models import F, Q
 from django.utils import timezone
 from django_prometheus.models import ExportModelOperationsMixin
@@ -89,61 +89,24 @@ class Token(ExportModelOperationsMixin("Token"), rest_framework.authtoken.models
     def make_hash(plain):
         return make_password(plain, salt="static", hasher="pbkdf2_sha256_iter1")
 
-    def get_policy(self, *, domain=None):
-        order_by = F("domain").asc(
-            nulls_last=True
-        )  # default Postgres sorting, but: explicit is better than implicit
+    def get_policy(self, *, domain=None, subname=None, type=None):
+        order_by = [
+            F(field).asc(
+                nulls_last=True  # default Postgres sorting, but: explicit is better than implicit
+            )
+            for field in ["domain", "subname", "type"]
+        ]
         return (
-            self.tokendomainpolicy_set.filter(Q(domain=domain) | Q(domain__isnull=True))
-            .order_by(order_by)
+            self.tokendomainpolicy_set.filter(
+                Q(domain=domain) | Q(domain__isnull=True),
+                Q(subname=subname) | Q(subname__isnull=True),
+                Q(type=type) | Q(type__isnull=True),
+            )
+            .order_by(*order_by)
             .first()
         )
 
-    @transaction.atomic
-    def delete(self):
-        # This is needed because Model.delete() emulates cascade delete via django.db.models.deletion.Collector.delete()
-        # which deletes related objects in pk order.  However, the default policy has to be deleted last.
-        # Perhaps this will change with https://code.djangoproject.com/ticket/21961
-        self.tokendomainpolicy_set.filter(domain__isnull=False).delete()
-        self.tokendomainpolicy_set.filter(domain__isnull=True).delete()
-        return super().delete()
-
-
-@pgtrigger.register(
-    # Ensure that token_user is consistent with token
-    pgtrigger.Trigger(
-        name="token_user",
-        operation=pgtrigger.Update | pgtrigger.Insert,
-        when=pgtrigger.Before,
-        func="NEW.token_user_id = (SELECT user_id FROM desecapi_token WHERE id = NEW.token_id); RETURN NEW;",
-    ),
-    # Ensure that if there is *any* domain policy for a given token, there is always one with domain=None.
-    pgtrigger.Trigger(
-        name="default_policy_on_insert",
-        operation=pgtrigger.Insert,
-        when=pgtrigger.Before,
-        # Trigger `condition` arguments (corresponding to WHEN clause) don't support subqueries, so we use `func`
-        func="IF (NEW.domain_id IS NOT NULL and NOT EXISTS(SELECT * FROM desecapi_tokendomainpolicy WHERE domain_id IS NULL AND token_id = NEW.token_id)) THEN "
-        "  RAISE EXCEPTION 'Cannot insert non-default policy into % table when default policy is not present', TG_TABLE_NAME; "
-        "END IF; RETURN NEW;",
-    ),
-    pgtrigger.Protect(
-        name="default_policy_on_update",
-        operation=pgtrigger.Update,
-        when=pgtrigger.Before,
-        condition=pgtrigger.Q(old__domain__isnull=True, new__domain__isnull=False),
-    ),
-    # Ideally, a deferred trigger (https://github.com/Opus10/django-pgtrigger/issues/14). Available in 3.4.0.
-    pgtrigger.Trigger(
-        name="default_policy_on_delete",
-        operation=pgtrigger.Delete,
-        when=pgtrigger.Before,
-        # Trigger `condition` arguments (corresponding to WHEN clause) don't support subqueries, so we use `func`
-        func="IF (OLD.domain_id IS NULL and EXISTS(SELECT * FROM desecapi_tokendomainpolicy WHERE domain_id IS NOT NULL AND token_id = OLD.token_id)) THEN "
-        "  RAISE EXCEPTION 'Cannot delete default policy from % table when non-default policy is present', TG_TABLE_NAME; "
-        "END IF; RETURN OLD;",
-    ),
-)
+
 class TokenDomainPolicy(ExportModelOperationsMixin("TokenDomainPolicy"), models.Model):
     id = models.UUIDField(primary_key=True, default=uuid.uuid4, editable=False)
     token = models.ForeignKey(Token, on_delete=models.CASCADE)
@@ -166,41 +129,84 @@ class TokenDomainPolicy(ExportModelOperationsMixin("TokenDomainPolicy"), models.
 
     class Meta:
         constraints = [
-            models.UniqueConstraint(fields=["token", "domain"], name="unique_entry"),
             models.UniqueConstraint(
-                fields=["token"],
-                condition=Q(domain__isnull=True),
-                name="unique_entry_null_domain",
+                name="unique_policy",
+                fields=["token", "domain", "subname", "type"],
+                nulls_distinct=False,
+            ),
+        ]
+        triggers = [
+            # Ensure that token_user is consistent with token (to fulfill compound FK constraint, see migration)
+            pgtrigger.Trigger(
+                name="token_user",
+                operation=pgtrigger.Update | pgtrigger.Insert,
+                when=pgtrigger.Before,
+                func="NEW.token_user_id = (SELECT user_id FROM desecapi_token WHERE id = NEW.token_id); RETURN NEW;",
+            ),
+            # Ensure that if there is *any* domain policy for a given token, there is always one with domain=None.
+            pgtrigger.Trigger(
+                name="default_policy_primacy",
+                operation=pgtrigger.Insert | pgtrigger.Update | pgtrigger.Delete,
+                when=pgtrigger.After,
+                timing=pgtrigger.Deferred,
+                func=pgtrigger.Func(
+                    """
+                    IF
+                        EXISTS(SELECT * FROM {meta.db_table} WHERE token_id = COALESCE(NEW.token_id, OLD.token_id)) AND NOT EXISTS(
+                            SELECT * FROM {meta.db_table} WHERE token_id = COALESCE(NEW.token_id, OLD.token_id) AND domain_id IS NULL AND subname IS NULL AND type IS NULL
+                        )
+                    THEN
+                        RAISE EXCEPTION 'Token policies without a default policy are not allowed.';
+                    END IF;
+                    RETURN NULL;
+                """
+                ),
             ),
         ]
 
+    @property
+    def is_default_policy(self):
+        default_policy = self.token.get_policy()
+        return default_policy is not None and self.pk == default_policy.pk
+
+    @property
+    def represents_default_policy(self):
+        return self.domain is None and self.subname is None and self.type is None
+
     def clean(self):
-        default_policy = self.token.get_policy(domain=None)
-        if not self._state.adding:  # update
-            # Can't change policy's default status ("domain NULLness") to maintain policy precedence
-            if (self.domain is None) != (self.pk == default_policy.pk):
+        if self._state.adding:  # create
+            # Can't violate policy precedence (default policy has to be first)
+            default_policy = self.token.get_policy()
+            if (default_policy is None) and not self.represents_default_policy:
                 raise ValidationError(
                     {
-                        "domain": "Policy precedence: Cannot disable default policy when others exist."
+                        "non_field_errors": [
+                            "Policy precedence: The first policy must be the default policy."
+                        ]
                     }
                 )
-        else:  # create
-            # Can't violate policy precedence (default policy has to be first)
-            if (self.domain is not None) and (default_policy is None):
+        else:  # update
+            # Can't make non-default policy default and vice versa
+            if self.is_default_policy != self.represents_default_policy:
                 raise ValidationError(
                     {
-                        "domain": "Policy precedence: The first policy must be the default policy."
+                        "non_field_errors": [
+                            "When using policies, there must be exactly one default policy."
+                        ]
                     }
                 )
 
     def delete(self, *args, **kwargs):
         # Can't delete default policy when others exist
-        if (self.domain is None) and self.token.tokendomainpolicy_set.exclude(
-            domain__isnull=True
-        ).exists():
+        if (
+            self.is_default_policy
+            and self.token.tokendomainpolicy_set.exclude(pk=self.pk).exists()
+        ):
             raise ValidationError(
                 {
-                    "domain": "Policy precedence: Can't delete default policy when there exist others."
+                    "non_field_errors": [
+                        "Policy precedence: Can't delete default policy when there exist others."
+                    ]
                 }
             )
         return super().delete(*args, **kwargs)

+ 90 - 29
api/desecapi/tests/test_token_domain_policy.py

@@ -57,6 +57,77 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
         self.token_manage = self.create_token(self.owner, perm_manage_tokens=True)
         self.other_token = self.create_token(self.user)
 
+    def test_get_policy(self):
+        def get_policy(domain, subname, type):
+            return self.token.get_policy(domain=domain, subname=subname, type=type)
+
+        def assertPolicy(policy, domain, subname, type):
+            self.assertEqual(policy.domain, domain)
+            self.assertEqual(policy.subname, subname)
+            self.assertEqual(policy.type, type)
+
+        qs = self.token.tokendomainpolicy_set
+
+        # Default policy is fallback for everything
+        qs.create(domain=None, subname=None, type=None)
+        for kwargs in [
+            dict(subname=subname, type=type_)
+            for subname in (None, "www")
+            for type_ in (None, "A")
+        ]:
+            policy = get_policy(self.my_domain, **kwargs)
+            assertPolicy(policy, None, None, None)
+
+        # Type wins over default
+        qs.create(domain=None, subname=None, type="A")
+        policy = get_policy(self.my_domain, "www", "A")
+        assertPolicy(policy, None, None, "A")
+
+        # Subname wins over type
+        qs.create(domain=None, subname="www", type=None)
+        policy = get_policy(self.my_domain, "www", "A")
+        assertPolicy(policy, None, "www", None)
+
+        # Most specific wins
+        qs.create(domain=None, subname="www", type="A")
+        policy = get_policy(self.my_domain, "www", "A")
+        assertPolicy(policy, None, "www", "A")
+
+        # Domain wins over default and over subname and type
+        qs.create(domain=self.my_domain, subname=None, type=None)
+        policy = get_policy(self.my_domain, None, None)
+        assertPolicy(policy, self.my_domain, None, None)
+
+        # Subname wins over default or domain default
+        qs.create(domain=self.my_domain, subname="www", type=None)
+        for kwargs in [
+            dict(subname="www", type=None),
+            dict(subname="www", type="A"),
+        ]:
+            policy = get_policy(self.my_domain, **kwargs)
+            assertPolicy(policy, self.my_domain, "www", None)
+
+        # Type wins over default or domain default
+        qs.create(domain=self.my_domain, subname=None, type="A")
+        for kwargs in [
+            dict(subname=None, type="A"),
+            dict(subname="www2", type="A"),
+        ]:
+            policy = get_policy(self.my_domain, **kwargs)
+            assertPolicy(policy, self.my_domain, None, "A")
+
+        # Subname wins over type
+        policy = get_policy(self.my_domain, "www", "A")
+        assertPolicy(policy, self.my_domain, "www", None)
+
+        # Subname + type wins over less specific
+        qs.create(domain=self.my_domain, subname="www", type="A")
+        policy = get_policy(self.my_domain, "www", "A")
+        assertPolicy(policy, self.my_domain, "www", "A")
+
+        # Check that we did all combinations
+        self.assertEqual(qs.count(), 2**3)
+
     def test_policy_lifecycle_without_management_permission(self):
         # Prepare (with management token)
         data = {"domain": None, "subname": None, "type": None, "perm_write": True}
@@ -104,11 +175,9 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
 
             # Change
             data = dict(perm_dyndns=False, perm_write=True)
-            policy_id = models.TokenDomainPolicy.objects.get(
-                token=target, domain__isnull=True
-            )
+            policy = target.get_policy()
             response = self.client.patch_policy(
-                target, using=self.token, policy_id=policy_id, data=data
+                target, using=self.token, policy_id=policy.pk, data=data
             )
             self.assertStatus(response, status.HTTP_403_FORBIDDEN)
 
@@ -130,7 +199,8 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
         ## without required field
         response = self.client.create_policy(self.token, using=self.token_manage)
         self.assertStatus(response, status.HTTP_400_BAD_REQUEST)
-        self.assertEqual(response.data["domain"], ["This field is required."])
+        for field in ["domain", "subname", "type"]:
+            self.assertEqual(response.data[field], ["This field is required."])
 
         ## without a default policy
         data = {"domain": self.my_domains[0].name, "subname": None, "type": None}
@@ -140,7 +210,7 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
             )
         self.assertStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(
-            response.data["domain"],
+            response.data["non_field_errors"],
             ["Policy precedence: The first policy must be the default policy."],
         )
 
@@ -257,12 +327,8 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
             )
         self.assertStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(
-            response.data,
-            {
-                "domain": [
-                    "Policy precedence: Cannot disable default policy when others exist."
-                ]
-            },
+            response.data["non_field_errors"],
+            ["When using policies, there must be exactly one default policy."],
         )
 
         ## partially modify the default policy
@@ -291,12 +357,8 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
             )
         self.assertStatus(response, status.HTTP_400_BAD_REQUEST)
         self.assertEqual(
-            response.data,
-            {
-                "domain": [
-                    "Policy precedence: Can't delete default policy when there exist others."
-                ]
-            },
+            response.data["non_field_errors"],
+            ["Policy precedence: Can't delete default policy when there exist others."],
         )
 
         ## delete other policy
@@ -384,6 +446,7 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
             self.token, using=self.token_manage, data=data
         )
         self.assertStatus(response, status.HTTP_201_CREATED)
+        default_policy_id = response.data["id"]
 
         ## another policy
         data = {"domain": self.my_domains[0].name, "subname": None, "type": None}
@@ -400,13 +463,9 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
         self.assertStatus(response, status.HTTP_200_OK)
         self.assertEqual(response.data, self.default_data | data | {"id": policy_id})
 
-        policies = {
-            self.my_domains[0]: self.token.tokendomainpolicy_set.get(
-                domain__isnull=False
-            ),
-            self.my_domains[1]: self.token.tokendomainpolicy_set.get(
-                domain__isnull=True
-            ),
+        policy_id_by_domain = {
+            self.my_domains[0]: policy_id,
+            self.my_domains[1]: default_policy_id,
         }
 
         kwargs = dict(HTTP_AUTHORIZATION=f"Token {self.token.plain}")
@@ -414,12 +473,14 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
         # For each permission type
         for perm in self.default_data.keys():
             # For the domain with specific policy and for the domain covered by the default policy
-            for domain in policies.keys():
+            for domain in policy_id_by_domain.keys():
                 # For both possible values of the permission
                 for value in [True, False]:
                     # Set only that permission for that domain (on its effective policy)
                     _reset_policies(self.token)
-                    policy = policies[domain]
+                    policy = self.token.tokendomainpolicy_set.get(
+                        pk=policy_id_by_domain[domain]
+                    )
                     setattr(policy, perm, value)
                     policy.save()
 
@@ -505,8 +566,8 @@ class TokenDomainPolicyTestCase(DomainOwnerTestCase):
         domain = domains.pop()
         domain.delete()
         self.assertEqual(
-            list(map(lambda x: x.domain, self.token.tokendomainpolicy_set.all())),
-            domains,
+            set(policy.domain for policy in self.token.tokendomainpolicy_set.all()),
+            set(domains),
         )
 
     def test_token_deletion(self):