Browse Source

Add tests for reachability providers

Jelmer Vernooij 2 months ago
parent
commit
8327cd6af8
2 changed files with 263 additions and 2 deletions
  1. 4 2
      dulwich/bitmap.py
  2. 259 0
      tests/test_bitmap.py

+ 4 - 2
dulwich/bitmap.py

@@ -33,7 +33,7 @@ import struct
 from collections import deque
 from collections.abc import Callable, Iterable, Iterator
 from io import BytesIO
-from typing import IO, TYPE_CHECKING, Callable, Optional
+from typing import IO, TYPE_CHECKING, Optional
 
 from .file import GitFile
 from .objects import Blob, Commit, Tag, Tree
@@ -1143,7 +1143,9 @@ def generate_bitmap(
     return pack_bitmap
 
 
-def find_commit_bitmaps(commit_shas: set[bytes], packs: Iterable[Pack]) -> dict[bytes, tuple]:
+def find_commit_bitmaps(
+    commit_shas: set[bytes], packs: Iterable[Pack]
+) -> dict[bytes, tuple]:
     """Find which packs have bitmaps for the given commits.
 
     Args:

+ 259 - 0
tests/test_bitmap.py

@@ -39,6 +39,7 @@ from dulwich.bitmap import (
     read_bitmap_file,
     write_bitmap_file,
 )
+from dulwich.object_store import BitmapReachability, GraphTraversalReachability
 
 
 class EWAHCompressionTests(unittest.TestCase):
@@ -903,3 +904,261 @@ class BitmapConfigTests(unittest.TestCase):
         config = ConfigFile()
         config.set((b"pack",), b"useBitmapIndex", b"false")
         self.assertFalse(config.get_boolean((b"pack",), b"useBitmapIndex", True))
+
+
+class ReachabilityProviderTests(unittest.TestCase):
+    """Tests for ObjectReachabilityProvider implementations."""
+
+    def setUp(self):
+        """Set up test repository with commits."""
+        from dulwich.object_store import DiskObjectStore
+        from dulwich.objects import Blob, Commit, Tree
+
+        self.test_dir = tempfile.mkdtemp()
+        self.store = DiskObjectStore(self.test_dir)
+
+        # Create a simple commit history:
+        # commit1 -> commit2 -> commit3
+        #         \-> commit4
+
+        # Create two blobs and two trees (tree1: blob1; tree2: blob1 + blob2)
+        self.blob1 = Blob.from_string(b"test content 1")
+        self.store.add_object(self.blob1)
+
+        self.blob2 = Blob.from_string(b"test content 2")
+        self.store.add_object(self.blob2)
+
+        self.tree1 = Tree()
+        self.tree1[b"file1.txt"] = (0o100644, self.blob1.id)
+        self.store.add_object(self.tree1)
+
+        self.tree2 = Tree()
+        self.tree2[b"file1.txt"] = (0o100644, self.blob1.id)
+        self.tree2[b"file2.txt"] = (0o100644, self.blob2.id)
+        self.store.add_object(self.tree2)
+
+        # Create commit1 (root)
+        self.commit1 = Commit()
+        self.commit1.tree = self.tree1.id
+        self.commit1.message = b"First commit"
+        self.commit1.author = self.commit1.committer = b"Test <test@example.com>"
+        self.commit1.author_time = self.commit1.commit_time = 1234567890
+        self.commit1.author_timezone = self.commit1.commit_timezone = 0
+        self.store.add_object(self.commit1)
+
+        # Create commit2 (child of commit1)
+        self.commit2 = Commit()
+        self.commit2.tree = self.tree1.id
+        self.commit2.parents = [self.commit1.id]
+        self.commit2.message = b"Second commit"
+        self.commit2.author = self.commit2.committer = b"Test <test@example.com>"
+        self.commit2.author_time = self.commit2.commit_time = 1234567891
+        self.commit2.author_timezone = self.commit2.commit_timezone = 0
+        self.store.add_object(self.commit2)
+
+        # Create commit3 (child of commit2)
+        self.commit3 = Commit()
+        self.commit3.tree = self.tree2.id
+        self.commit3.parents = [self.commit2.id]
+        self.commit3.message = b"Third commit"
+        self.commit3.author = self.commit3.committer = b"Test <test@example.com>"
+        self.commit3.author_time = self.commit3.commit_time = 1234567892
+        self.commit3.author_timezone = self.commit3.commit_timezone = 0
+        self.store.add_object(self.commit3)
+
+        # Create commit4 (child of commit1, creates a branch)
+        self.commit4 = Commit()
+        self.commit4.tree = self.tree2.id
+        self.commit4.parents = [self.commit1.id]
+        self.commit4.message = b"Fourth commit"
+        self.commit4.author = self.commit4.committer = b"Test <test@example.com>"
+        self.commit4.author_time = self.commit4.commit_time = 1234567893
+        self.commit4.author_timezone = self.commit4.commit_timezone = 0
+        self.store.add_object(self.commit4)
+
+    def tearDown(self):
+        """Clean up test directory."""
+        import shutil
+
+        shutil.rmtree(self.test_dir)
+
+    def test_graph_traversal_reachability_single_commit(self):
+        """Test GraphTraversalReachability with single commit."""
+        from dulwich.object_store import GraphTraversalReachability
+
+        provider = GraphTraversalReachability(self.store)
+
+        # Get reachable commits from commit1
+        reachable = provider.get_reachable_commits(
+            [self.commit1.id], exclude=None, shallow=None
+        )
+
+        # Should only include commit1
+        self.assertEqual({self.commit1.id}, reachable)
+
+    def test_graph_traversal_reachability_linear_history(self):
+        """Test GraphTraversalReachability with linear history."""
+        from dulwich.object_store import GraphTraversalReachability
+
+        provider = GraphTraversalReachability(self.store)
+
+        # Get reachable commits from commit3
+        reachable = provider.get_reachable_commits(
+            [self.commit3.id], exclude=None, shallow=None
+        )
+
+        # Should include commit3, commit2, and commit1
+        expected = {self.commit1.id, self.commit2.id, self.commit3.id}
+        self.assertEqual(expected, reachable)
+
+    def test_graph_traversal_reachability_with_exclusion(self):
+        """Test GraphTraversalReachability with exclusion."""
+        from dulwich.object_store import GraphTraversalReachability
+
+        provider = GraphTraversalReachability(self.store)
+
+        # Get commits reachable from commit3 but not from commit1
+        reachable = provider.get_reachable_commits(
+            [self.commit3.id], exclude=[self.commit1.id], shallow=None
+        )
+
+        # Should include commit3 and commit2, but not commit1
+        expected = {self.commit2.id, self.commit3.id}
+        self.assertEqual(expected, reachable)
+
+    def test_graph_traversal_reachability_branching(self):
+        """Test GraphTraversalReachability with branching history."""
+        from dulwich.object_store import GraphTraversalReachability
+
+        provider = GraphTraversalReachability(self.store)
+
+        # Get reachable commits from both commit3 and commit4
+        reachable = provider.get_reachable_commits(
+            [self.commit3.id, self.commit4.id], exclude=None, shallow=None
+        )
+
+        # Should include all commits
+        expected = {self.commit1.id, self.commit2.id, self.commit3.id, self.commit4.id}
+        self.assertEqual(expected, reachable)
+
+    def test_graph_traversal_reachable_objects(self):
+        """Test GraphTraversalReachability.get_reachable_objects()."""
+        from dulwich.object_store import GraphTraversalReachability
+
+        provider = GraphTraversalReachability(self.store)
+
+        # Get all objects reachable from commit3
+        reachable = provider.get_reachable_objects(
+            [self.commit3.id], exclude_commits=None
+        )
+
+        # Should include commit3, blob1, and blob2 (but not tree objects themselves)
+        self.assertIn(self.commit3.id, reachable)
+        self.assertIn(self.blob1.id, reachable)
+        self.assertIn(self.blob2.id, reachable)
+        # Verify at least 3 objects
+        self.assertGreaterEqual(len(reachable), 3)
+
+    def test_graph_traversal_reachable_objects_with_exclusion(self):
+        """Test GraphTraversalReachability.get_reachable_objects() with exclusion."""
+        from dulwich.object_store import GraphTraversalReachability
+
+        provider = GraphTraversalReachability(self.store)
+
+        # Get objects reachable from commit3 but not from commit2
+        reachable = provider.get_reachable_objects(
+            [self.commit3.id], exclude_commits=[self.commit2.id]
+        )
+
+        # commit2 uses tree1 (blob1 only); commit3 uses tree2 (blob1 + blob2).
+        # The result should therefore include commit3 and blob2 (new in commit3).
+        # blob1 is expected to be excluded (reachable from commit2); not asserted here.
+        self.assertIn(self.commit3.id, reachable)
+        self.assertIn(self.blob2.id, reachable)
+
+    def test_get_reachability_provider_without_bitmaps(self):
+        """Test get_reachability_provider returns GraphTraversalReachability when no bitmaps."""
+        from dulwich.object_store import (
+            GraphTraversalReachability,
+            get_reachability_provider,
+        )
+
+        provider = get_reachability_provider(self.store)
+
+        # Should return GraphTraversalReachability when no bitmaps available
+        self.assertIsInstance(provider, GraphTraversalReachability)
+
+    def test_get_reachability_provider_prefer_bitmaps_false(self):
+        """Test get_reachability_provider with prefer_bitmaps=False."""
+        from dulwich.object_store import (
+            GraphTraversalReachability,
+            get_reachability_provider,
+        )
+
+        provider = get_reachability_provider(self.store, prefer_bitmaps=False)
+
+        # Should return GraphTraversalReachability when prefer_bitmaps=False
+        self.assertIsInstance(provider, GraphTraversalReachability)
+
+    def test_bitmap_reachability_fallback_without_bitmaps(self):
+        """Test BitmapReachability falls back to graph traversal without bitmaps."""
+        provider = BitmapReachability(self.store)
+
+        # Without bitmaps, should fall back to graph traversal
+        reachable = provider.get_reachable_commits(
+            [self.commit3.id], exclude=None, shallow=None
+        )
+
+        # Should still work via fallback
+        expected = {self.commit1.id, self.commit2.id, self.commit3.id}
+        self.assertEqual(expected, reachable)
+
+    def test_bitmap_reachability_fallback_with_shallow(self):
+        """Test BitmapReachability falls back for shallow clones."""
+        provider = BitmapReachability(self.store)
+
+        # With shallow boundary, should fall back to graph traversal
+        reachable = provider.get_reachable_commits(
+            [self.commit3.id], exclude=None, shallow={self.commit2.id}
+        )
+
+        # Should include commit3 and commit2 (the shallow boundary commit itself is
+        # included) but not commit1, which lies beyond the shallow boundary.
+        self.assertEqual({self.commit2.id, self.commit3.id}, reachable)
+
+    def test_reachability_provider_protocol(self):
+        """Test that both providers implement the same interface."""
+        graph_provider = GraphTraversalReachability(self.store)
+        bitmap_provider = BitmapReachability(self.store)
+
+        # Both should have the same methods
+        for method in [
+            "get_reachable_commits",
+            "get_reachable_objects",
+            "get_tree_objects",
+        ]:
+            self.assertTrue(hasattr(graph_provider, method))
+            self.assertTrue(hasattr(bitmap_provider, method))
+
+    def test_graph_traversal_vs_bitmap_consistency(self):
+        """Test that GraphTraversalReachability and BitmapReachability produce same results."""
+        graph_provider = GraphTraversalReachability(self.store)
+        bitmap_provider = BitmapReachability(self.store)  # Will use fallback
+
+        # Test get_reachable_commits
+        graph_commits = graph_provider.get_reachable_commits(
+            [self.commit3.id], exclude=[self.commit1.id], shallow=None
+        )
+        bitmap_commits = bitmap_provider.get_reachable_commits(
+            [self.commit3.id], exclude=[self.commit1.id], shallow=None
+        )
+        self.assertEqual(graph_commits, bitmap_commits)
+
+        # Test get_reachable_objects
+        graph_objects = graph_provider.get_reachable_objects(
+            [self.commit3.id], exclude_commits=None
+        )
+        bitmap_objects = bitmap_provider.get_reachable_objects(
+            [self.commit3.id], exclude_commits=None
+        )
+        self.assertEqual(graph_objects, bitmap_objects)