# test_replication.py (4.9 KB)
  1. import json
  2. import os
  3. import random
  4. import string
  5. import time
  6. from datetime import datetime
  7. from tempfile import TemporaryDirectory
  8. from django.test import testcases
  9. from rest_framework import status
  10. from desecapi.replication import Repository
  11. from desecapi.tests.base import DesecTestCase
  12. class ReplicationTest(DesecTestCase):
  13. def test_serials(self):
  14. url = self.reverse('v1:serial')
  15. zones = [
  16. {'name': 'test.example.', 'edited_serial': 12345},
  17. {'name': 'example.org.', 'edited_serial': 54321},
  18. ]
  19. serials = {zone['name']: zone['edited_serial'] for zone in zones}
  20. pdns_requests = [{
  21. 'method': 'GET',
  22. 'uri': self.get_full_pdns_url(r'/zones', ns='MASTER'),
  23. 'status': 200,
  24. 'body': json.dumps(zones),
  25. }]
  26. # Run twice to make sure cache output varies on remote address
  27. for i in range(2):
  28. response = self.client.get(path=url, REMOTE_ADDR='123.8.0.2')
  29. self.assertStatus(response, status.HTTP_401_UNAUTHORIZED)
  30. with self.assertPdnsRequests(pdns_requests):
  31. response = self.client.get(path=url, REMOTE_ADDR='10.8.0.2')
  32. self.assertStatus(response, status.HTTP_200_OK)
  33. self.assertEqual(response.data, serials)
  34. # Do not expect pdns request in next iteration (result will be cached)
  35. pdns_requests = []
  36. class RepositoryTest(testcases.TestCase):
  37. def assertGit(self, path):
  38. self.assertTrue(
  39. os.path.exists(os.path.join(path, '.git')),
  40. f'Expected a git repository at {path} but did not find .git subdirectory.'
  41. )
  42. def assertHead(self, repo, message=None, sha=None):
  43. actual_sha, actual_message = repo.get_head()
  44. if actual_sha is None:
  45. self.fail(f'Expected HEAD to have commit message "{message}" and hash "{sha}", but repository has no '
  46. f'commits.')
  47. if sha:
  48. self.assertEqual(actual_sha, sha, f'Expected HEAD to have hash "{sha}" but had "{actual_sha}".')
  49. if message:
  50. self.assertIn(
  51. message, actual_message,
  52. f'Expected "{message}" to appear in the last commit message, but only found "{actual_message}".',
  53. )
  54. def assertHasCommit(self, repo: Repository, commit_id):
  55. self.assertIsNotNone(
  56. repo.get_commit(commit_id)[0], f'Expected repository to have commit {commit_id}, but it had not.'
  57. )
  58. def assertHasCommits(self, repo: Repository, commit_id_list):
  59. for commit in commit_id_list:
  60. self.assertHasCommit(repo, commit)
  61. def assertHasNotCommit(self, repo: Repository, commit_id):
  62. self.assertIsNone(
  63. repo.get_commit(commit_id)[0], f'Expected repository to not have commit {commit_id}, but it had.'
  64. )
  65. def assertHasNotCommits(self, repo: Repository, commit_id_list):
  66. for commit in commit_id_list:
  67. self.assertHasNotCommit(repo, commit)
  68. def assertNoCommits(self, repo: Repository):
  69. head = repo.get_head()
  70. self.assertEqual(head, (None, None), f'Expected that repository has no commits, but HEAD was {head}.')
  71. @staticmethod
  72. def _random_string(length):
  73. return ''.join(random.choices(string.ascii_lowercase, k=length))
  74. def _random_commit(self, repo: Repository, message=''):
  75. with open(os.path.join(repo.path, self._random_string(16)), 'w') as f:
  76. f.write(self._random_string(500))
  77. repo.commit_all(message)
  78. return repo.get_head()[0]
  79. def _random_commits(self, num, repo: Repository, message=''):
  80. return [self._random_commit(repo, message) for _ in range(num)]
  81. def test_init(self):
  82. with TemporaryDirectory() as path:
  83. repo = Repository(path)
  84. repo.init()
  85. self.assertGit(path)
  86. def test_commit(self):
  87. with TemporaryDirectory() as path:
  88. repo = Repository(path)
  89. repo.init()
  90. repo.commit_all('commit1')
  91. self.assertNoCommits(repo)
  92. with open(os.path.join(path, 'test_commit'), 'w') as f:
  93. f.write('foo')
  94. repo.commit_all('commit2')
  95. self.assertHead(repo, message='commit2')
  96. def test_remove_history(self):
  97. with TemporaryDirectory() as path:
  98. repo = Repository(path)
  99. repo.init()
  100. remove = self._random_commits(5, repo, 'to be removed') # we're going to remove these 'old' commits
  101. keep = self._random_commits(1, repo, 'anchor to be kept') # as sync anchor, the last 'old' commit is kept
  102. cutoff = datetime.now()
  103. time.sleep(1)
  104. keep += self._random_commits(5, repo, 'to be kept') # we're going to keep these 'new' commits
  105. self.assertHasCommits(repo, remove + keep)
  106. repo.remove_history(before=cutoff)
  107. self.assertHasCommits(repo, keep)
  108. self.assertHasNotCommits(repo, remove)