Pārlūkot izejas kodu

Add reachability abstraction (#1994)

Jelmer Vernooij 2 mēneši atpakaļ
vecāks
revīzija
9fefb91eb2
1 mainītis faili ar 177 papildinājumiem un 7 dzēšanām
  1. 177 7
      dulwich/object_store.py

+ 177 - 7
dulwich/object_store.py

@@ -104,6 +104,62 @@ class GraphWalker(Protocol):
         ...
 
 
+class ObjectReachabilityProvider(Protocol):
+    """Structural interface for answering object reachability queries.
+
+    Implementations may use plain graph traversal or optimized bitmap
+    indexes; callers see the same three query methods either way.
+    """
+
+    def get_reachable_commits(
+        self,
+        heads: Iterable[bytes],
+        exclude: Iterable[bytes] | None = None,
+        shallow: Set[bytes] | None = None,
+    ) -> set[bytes]:
+        """Return the commits reachable from heads, minus those in exclude.
+
+        Args:
+          heads: Commit SHAs to start the traversal from
+          exclude: Commit SHAs to exclude (and their ancestors), or None
+          shallow: Shallow boundary commit SHAs (traversal stops here), or None
+
+        Returns:
+          Set of commit SHAs reachable from heads but not from exclude
+        """
+        ...
+
+    def get_reachable_objects(
+        self,
+        commits: Iterable[bytes],
+        exclude_commits: Iterable[bytes] | None = None,
+    ) -> set[bytes]:
+        """Return the object SHAs reachable from the given commits.
+
+        Args:
+          commits: Starting commit SHAs
+          exclude_commits: Commits whose objects are left out, or None
+
+        Returns:
+          Set of all object SHAs (commits, trees, blobs, tags)
+        """
+        ...
+
+    def get_tree_objects(
+        self,
+        tree_shas: Iterable[bytes],
+    ) -> set[bytes]:
+        """Return the trees and blobs reachable from the given trees.
+
+        Args:
+          tree_shas: Tree SHAs to start the walk from
+
+        Returns:
+          Set of tree and blob SHAs
+        """
+        ...
+
+
 INFODIR = "info"
 PACKDIR = "pack"
 
@@ -304,6 +360,18 @@ class BaseObjectStore:
         """
         raise NotImplementedError(self.add_objects)
 
+    def get_reachability_provider(self) -> ObjectReachabilityProvider:
+        """Return a reachability provider bound to this object store.
+
+        The base implementation wraps the store in GraphTraversalReachability.
+        Subclasses can override this to supply an optimized implementation
+        (e.g., one backed by bitmap indexes).
+
+        Returns:
+          ObjectReachabilityProvider instance
+        """
+        return GraphTraversalReachability(self)
+
     def tree_changes(
         self,
         source: bytes | None,
@@ -2220,6 +2288,7 @@ class MissingObjectFinder:
         if shallow is None:
             shallow = set()
         self._get_parents = get_parents
+        reachability = object_store.get_reachability_provider()
         # process Commits and Tags differently
         # Note, while haves may list commits/tags not available locally,
         # and such SHAs would get filtered out by _split_commits_and_tags,
@@ -2233,12 +2302,9 @@ class MissingObjectFinder:
         )
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
-        all_ancestors = _collect_ancestors(
-            object_store,
-            have_commits,
-            shallow=frozenset(shallow),
-            get_parents=self._get_parents,
-        )[0]
+        all_ancestors = reachability.get_reachable_commits(
+            have_commits, exclude=None, shallow=shallow
+        )
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
@@ -2258,7 +2324,8 @@ class MissingObjectFinder:
             self.remote_has.add(h)
             cmt = object_store[h]
             assert isinstance(cmt, Commit)
-            _collect_filetree_revs(object_store, cmt.tree, self.remote_has)
+            tree_objects = reachability.get_tree_objects([cmt.tree])
+            self.remote_has.update(tree_objects)
         # record tags we have as visited, too
         for t in have_tags:
             self.remote_has.add(t)
@@ -2965,3 +3032,106 @@ def peel_sha(store: ObjectContainer, sha: bytes) -> tuple[ShaFile, ShaFile]:
         obj_class, sha = obj.object
         obj = store[sha]
     return unpeeled, obj
+
+
+# ObjectReachabilityProvider implementation
+
+
+class GraphTraversalReachability:
+    """Graph-walking implementation of ObjectReachabilityProvider.
+
+    Delegates to the module's existing traversal helpers
+    (_collect_ancestors, _collect_filetree_revs); it adds no index or
+    cache of its own, so every query walks the store again.
+    """
+
+    def __init__(self, object_store: BaseObjectStore) -> None:
+        """Bind the provider to an object store.
+
+        Args:
+          object_store: Store that all traversals read objects from
+        """
+        self.store = object_store
+
+    def get_reachable_commits(
+        self,
+        heads: Iterable[bytes],
+        exclude: Iterable[bytes] | None = None,
+        shallow: Set[bytes] | None = None,
+    ) -> set[bytes]:
+        """Return the commits reachable from heads, minus those in exclude.
+
+        Passes exclude to _collect_ancestors as its third argument.
+
+        Args:
+          heads: Commit SHAs to start the traversal from
+          exclude: Commit SHAs to exclude; None or empty means no exclusion
+          shallow: Shallow boundary commit SHAs; None or empty means none
+
+        Returns:
+          First element of the _collect_ancestors result (the commit set)
+        """
+        exclude_set = frozenset(exclude) if exclude else frozenset()
+        shallow_set = frozenset(shallow) if shallow else frozenset()
+
+        commits, _bases = _collect_ancestors(
+            self.store, heads, exclude_set, shallow_set
+        )
+        return commits
+
+    def get_tree_objects(
+        self,
+        tree_shas: Iterable[bytes],
+    ) -> set[bytes]:
+        """Return the trees and blobs reachable from the given trees.
+
+        Runs _collect_filetree_revs once per tree into one shared set.
+
+        Args:
+          tree_shas: Tree SHAs to start the walk from
+
+        Returns:
+          Set of SHAs accumulated by _collect_filetree_revs
+        """
+        result: set[bytes] = set()
+        for tree_sha in tree_shas:
+            _collect_filetree_revs(self.store, tree_sha, result)
+        return result
+
+    def get_reachable_objects(
+        self,
+        commits: Iterable[bytes],
+        exclude_commits: Iterable[bytes] | None = None,
+    ) -> set[bytes]:
+        """Return the given commits plus the contents of their trees.
+
+        Args:
+          commits: Commit SHAs; their ancestors are not followed here
+          exclude_commits: Commits whose own objects are subtracted
+
+        Returns:
+          Set of object SHAs (the input SHAs, trees, blobs)
+        """
+        commits_set = set(commits)
+        result = set(commits_set)
+
+        # Gather the root tree of each input that resolves to a Commit
+        tree_shas = []
+        for commit_sha in commits_set:
+            try:
+                commit = self.store[commit_sha]
+                if isinstance(commit, Commit):
+                    tree_shas.append(commit.tree)
+            except KeyError:
+                # Not in the store: its SHA stays in result, no tree walked
+                continue
+
+        # Walk the gathered root trees and add what they contain
+        result.update(self.get_tree_objects(tree_shas))
+
+        # Subtract the same computation run over exclude_commits
+        if exclude_commits:
+            exclude_objects = self.get_reachable_objects(exclude_commits, None)
+            result -= exclude_objects
+
+        return result