test_filter_branch.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # test_filter_branch.py -- Tests for filter_branch module
  2. # Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
  5. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  6. # General Public License as public by the Free Software Foundation; version 2.0
  7. # or (at your option) any later version. You can redistribute it and/or
  8. # modify it under the terms of either of these two licenses.
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # You should have received a copy of the licenses; if not, see
  17. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  18. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  19. # License, Version 2.0.
  20. #
  21. """Tests for dulwich.filter_branch."""
  22. import unittest
  23. from dulwich.filter_branch import CommitFilter, filter_refs
  24. from dulwich.object_store import MemoryObjectStore
  25. from dulwich.objects import Commit, Tree
  26. from dulwich.refs import DictRefsContainer
  27. class CommitFilterTests(unittest.TestCase):
  28. """Tests for CommitFilter class."""
  29. def setUp(self):
  30. self.store = MemoryObjectStore()
  31. self.refs = DictRefsContainer({})
  32. # Create test commits
  33. tree = Tree()
  34. self.store.add_object(tree)
  35. self.c1 = Commit()
  36. self.c1.tree = tree.id
  37. self.c1.author = self.c1.committer = b"Test User <test@example.com>"
  38. self.c1.author_time = self.c1.commit_time = 1000
  39. self.c1.author_timezone = self.c1.commit_timezone = 0
  40. self.c1.message = b"First commit"
  41. self.store.add_object(self.c1)
  42. self.c2 = Commit()
  43. self.c2.tree = tree.id
  44. self.c2.parents = [self.c1.id]
  45. self.c2.author = self.c2.committer = b"Test User <test@example.com>"
  46. self.c2.author_time = self.c2.commit_time = 2000
  47. self.c2.author_timezone = self.c2.commit_timezone = 0
  48. self.c2.message = b"Second commit"
  49. self.store.add_object(self.c2)
  50. def test_filter_author(self):
  51. """Test filtering author."""
  52. def new_author(old):
  53. return b"New Author <new@example.com>"
  54. filter = CommitFilter(self.store, filter_author=new_author)
  55. new_sha = filter.process_commit(self.c2.id)
  56. self.assertNotEqual(new_sha, self.c2.id)
  57. new_commit = self.store[new_sha]
  58. self.assertEqual(new_commit.author, b"New Author <new@example.com>")
  59. self.assertEqual(new_commit.committer, self.c2.committer)
  60. def test_filter_message(self):
  61. """Test filtering message."""
  62. def prefix_message(msg):
  63. return b"[PREFIX] " + msg
  64. filter = CommitFilter(self.store, filter_message=prefix_message)
  65. new_sha = filter.process_commit(self.c2.id)
  66. self.assertNotEqual(new_sha, self.c2.id)
  67. new_commit = self.store[new_sha]
  68. self.assertEqual(new_commit.message, b"[PREFIX] Second commit")
  69. def test_filter_fn(self):
  70. """Test custom filter function."""
  71. def custom_filter(commit):
  72. return {
  73. "author": b"Custom <custom@example.com>",
  74. "message": b"Custom: " + commit.message,
  75. }
  76. filter = CommitFilter(self.store, filter_fn=custom_filter)
  77. new_sha = filter.process_commit(self.c2.id)
  78. self.assertNotEqual(new_sha, self.c2.id)
  79. new_commit = self.store[new_sha]
  80. self.assertEqual(new_commit.author, b"Custom <custom@example.com>")
  81. self.assertEqual(new_commit.message, b"Custom: Second commit")
  82. def test_no_changes(self):
  83. """Test commit with no changes."""
  84. filter = CommitFilter(self.store)
  85. new_sha = filter.process_commit(self.c2.id)
  86. self.assertEqual(new_sha, self.c2.id)
  87. def test_parent_rewriting(self):
  88. """Test that parent commits are rewritten."""
  89. def new_author(old):
  90. return b"New Author <new@example.com>"
  91. filter = CommitFilter(self.store, filter_author=new_author)
  92. new_sha = filter.process_commit(self.c2.id)
  93. # Check that parent was also rewritten
  94. new_commit = self.store[new_sha]
  95. self.assertEqual(len(new_commit.parents), 1)
  96. new_parent_sha = new_commit.parents[0]
  97. self.assertNotEqual(new_parent_sha, self.c1.id)
  98. new_parent = self.store[new_parent_sha]
  99. self.assertEqual(new_parent.author, b"New Author <new@example.com>")
  100. class FilterRefsTests(unittest.TestCase):
  101. """Tests for filter_refs function."""
  102. def setUp(self):
  103. self.store = MemoryObjectStore()
  104. self.refs = DictRefsContainer({})
  105. # Create test commits
  106. tree = Tree()
  107. self.store.add_object(tree)
  108. c1 = Commit()
  109. c1.tree = tree.id
  110. c1.author = c1.committer = b"Test User <test@example.com>"
  111. c1.author_time = c1.commit_time = 1000
  112. c1.author_timezone = c1.commit_timezone = 0
  113. c1.message = b"First commit"
  114. self.store.add_object(c1)
  115. self.refs[b"refs/heads/master"] = c1.id
  116. self.c1_id = c1.id
  117. def test_filter_refs_basic(self):
  118. """Test basic ref filtering."""
  119. def new_author(old):
  120. return b"New Author <new@example.com>"
  121. filter = CommitFilter(self.store, filter_author=new_author)
  122. result = filter_refs(
  123. self.refs,
  124. self.store,
  125. [b"refs/heads/master"],
  126. filter,
  127. )
  128. # Check mapping
  129. self.assertEqual(len(result), 1)
  130. self.assertIn(self.c1_id, result)
  131. self.assertNotEqual(result[self.c1_id], self.c1_id)
  132. # Check ref was updated
  133. new_sha = self.refs[b"refs/heads/master"]
  134. self.assertEqual(new_sha, result[self.c1_id])
  135. # Check original was saved
  136. original_sha = self.refs[b"refs/original/refs/heads/master"]
  137. self.assertEqual(original_sha, self.c1_id)
  138. def test_filter_refs_already_filtered(self):
  139. """Test error when refs already filtered."""
  140. # Set up an "already filtered" state
  141. self.refs[b"refs/original/refs/heads/master"] = b"0" * 40
  142. filter = CommitFilter(self.store)
  143. with self.assertRaises(ValueError) as cm:
  144. filter_refs(
  145. self.refs,
  146. self.store,
  147. [b"refs/heads/master"],
  148. filter,
  149. )
  150. self.assertIn("filtered already", str(cm.exception))
  151. def test_filter_refs_force(self):
  152. """Test force filtering."""
  153. # Set up an "already filtered" state
  154. self.refs[b"refs/original/refs/heads/master"] = b"0" * 40
  155. filter = CommitFilter(self.store)
  156. # Should not raise with force=True
  157. result = filter_refs(
  158. self.refs,
  159. self.store,
  160. [b"refs/heads/master"],
  161. filter,
  162. force=True,
  163. )
  164. self.assertEqual(len(result), 1)