Переглянути джерело

Support "depth" argument in local clone (#1080)

* Factor out get_depth.

* Move find_shallow to object store.

* Support depth for local clones.
Jelmer Vernooij 1 місяць тому
батько
коміт
711e4000b1
8 змінених файлів з 232 додано та 100 видалено
  1. 3 3
      Cargo.lock
  2. 3 0
      NEWS
  3. 5 5
      dulwich/client.py
  4. 83 19
      dulwich/object_store.py
  5. 35 19
      dulwich/repo.py
  6. 5 43
      dulwich/server.py
  7. 89 1
      dulwich/tests/test_object_store.py
  8. 9 10
      tests/test_server.py

+ 3 - 3
Cargo.lock

@@ -10,7 +10,7 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26"
 
 [[package]]
 name = "diff-tree-py"
-version = "0.22.9"
+version = "0.23.0"
 dependencies = [
  "pyo3",
 ]
@@ -50,7 +50,7 @@ dependencies = [
 
 [[package]]
 name = "objects-py"
-version = "0.22.9"
+version = "0.23.0"
 dependencies = [
  "memchr",
  "pyo3",
@@ -64,7 +64,7 @@ checksum = "945462a4b81e43c4e3ba96bd7b49d834c6f61198356aa858733bc4acf3cbe62e"
 
 [[package]]
 name = "pack-py"
-version = "0.22.9"
+version = "0.23.0"
 dependencies = [
  "memchr",
  "pyo3",

+ 3 - 0
NEWS

@@ -1,5 +1,8 @@
 0.23.1	UNRELEASED
 
+ * Support ``depth`` for local clones.
+   (Jelmer Vernooij)
+
  * Add basic support for managing Notes. (Jelmer Vernooij)
 
  * Add basic ``cherry-pick`` subcommand.  (#1599, Jelmer Vernooij)

+ 5 - 5
dulwich/client.py

@@ -575,7 +575,7 @@ def _handle_upload_pack_head(
     graph_walker,
     wants,
     can_read,
-    depth,
+    depth: Optional[int],
     protocol_version,
 ):
     """Handle the head of a 'git-upload-pack' request.
@@ -831,7 +831,7 @@ class GitClient:
         checkout=None,
         branch=None,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
         filter_spec=None,
         protocol_version: Optional[int] = None,
@@ -1314,7 +1314,7 @@ class TraditionalGitClient(GitClient):
         graph_walker,
         pack_data,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
         filter_spec=None,
         protocol_version: Optional[int] = None,
@@ -1879,7 +1879,7 @@ class LocalGitClient(GitClient):
         graph_walker,
         pack_data,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
         filter_spec: Optional[bytes] = None,
         protocol_version: Optional[int] = None,
@@ -2650,7 +2650,7 @@ class AbstractHttpGitClient(GitClient):
         graph_walker,
         pack_data,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
         filter_spec=None,
         protocol_version: Optional[int] = None,

+ 83 - 19
dulwich/object_store.py

@@ -90,6 +90,83 @@ PACKDIR = "pack"
 PACK_MODE = 0o444 if sys.platform != "win32" else 0o644
 
 
+def find_shallow(store, heads, depth):
+    """Find shallow commits according to a given depth.
+
+    Args:
+      store: An ObjectStore for looking up objects.
+      heads: Iterable of head SHAs to start walking from.
+      depth: The depth of ancestors to include. A depth of one includes
+        only the heads themselves.
+    Returns: A tuple of (shallow, not_shallow), sets of SHAs that should be
+        considered shallow and unshallow according to the arguments. Note that
+        these sets may overlap if a commit is reachable along multiple paths.
+    """
+    parents = {}
+
+    def get_parents(sha):
+        result = parents.get(sha, None)
+        if not result:
+            result = store[sha].parents
+            parents[sha] = result
+        return result
+
+    todo = []  # stack of (sha, depth)
+    for head_sha in heads:
+        obj = store[head_sha]
+        # Peel tags if necessary
+        while isinstance(obj, Tag):
+            _, sha = obj.object
+            obj = store[sha]
+        if isinstance(obj, Commit):
+            todo.append((obj.id, 1))
+
+    not_shallow = set()
+    shallow = set()
+    while todo:
+        sha, cur_depth = todo.pop()
+        if cur_depth < depth:
+            not_shallow.add(sha)
+            new_depth = cur_depth + 1
+            todo.extend((p, new_depth) for p in get_parents(sha))
+        else:
+            shallow.add(sha)
+
+    return shallow, not_shallow
+
+
+def get_depth(
+    store,
+    head,
+    get_parents=lambda commit: commit.parents,
+    max_depth=None,
+):
+    """Return the current available depth for the given head.
+    For commits with multiple parents, the largest possible depth will be
+    returned.
+
+    Args:
+        head: commit to start from
+        get_parents: optional function for getting the parents of a commit
+        max_depth: maximum depth to search
+    """
+    if head not in store:
+        return 0
+    current_depth = 1
+    queue = [(head, current_depth)]
+    while queue and (max_depth is None or current_depth < max_depth):
+        e, depth = queue.pop(0)
+        current_depth = max(current_depth, depth)
+        cmt = store[e]
+        if isinstance(cmt, Tag):
+            _cls, sha = cmt.object
+            cmt = store[sha]
+        queue.extend(
+            (parent, depth + 1) for parent in get_parents(cmt) if parent in store
+        )
+    return current_depth
+
+
 class PackContainer(Protocol):
     def add_pack(self) -> tuple[BytesIO, Callable[[], None], Callable[[], None]]:
         """Add a new pack."""
@@ -334,21 +411,7 @@ class BaseObjectStore:
             get_parents: optional function for getting the parents of a commit
             max_depth: maximum depth to search
         """
-        if head not in self:
-            return 0
-        current_depth = 1
-        queue = [(head, current_depth)]
-        while queue and (max_depth is None or current_depth < max_depth):
-            e, depth = queue.pop(0)
-            current_depth = max(current_depth, depth)
-            cmt = self[e]
-            if isinstance(cmt, Tag):
-                _cls, sha = cmt.object
-                cmt = self[sha]
-            queue.extend(
-                (parent, depth + 1) for parent in get_parents(cmt) if parent in self
-            )
-        return current_depth
+        return get_depth(self, head, get_parents=get_parents, max_depth=max_depth)
 
     def close(self) -> None:
         """Close any files opened by this object store."""
@@ -401,7 +464,7 @@ class BaseObjectStore:
         return None
 
 
-class PackBasedObjectStore(BaseObjectStore):
+class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
     def __init__(self, pack_compression_level=-1, pack_index_version=None) -> None:
         self._pack_cache: dict[str, Pack] = {}
         self.pack_compression_level = pack_compression_level
@@ -663,9 +726,8 @@ class PackBasedObjectStore(BaseObjectStore):
 
     def iter_unpacked_subset(
         self,
-        shas,
-        *,
-        include_comp=False,
+        shas: set[bytes],
+        include_comp: bool = False,
         allow_missing: bool = False,
         convert_ofs_delta: bool = True,
     ) -> Iterator[UnpackedObject]:
@@ -1629,6 +1691,7 @@ class ObjectStoreGraphWalker:
         local_heads: Iterable[ObjectID],
         get_parents,
         shallow: Optional[set[ObjectID]] = None,
+        update_shallow=None,
     ) -> None:
         """Create a new instance.
 
@@ -1642,6 +1705,7 @@ class ObjectStoreGraphWalker:
         if shallow is None:
             shallow = set()
         self.shallow = shallow
+        self.update_shallow = update_shallow
 
     def nak(self) -> None:
         """Nothing in common was found."""

+ 35 - 19
dulwich/repo.py

@@ -79,6 +79,7 @@ from .object_store import (
     MissingObjectFinder,
     ObjectStoreGraphWalker,
     PackBasedObjectStore,
+    find_shallow,
     peel_sha,
 )
 from .objects import (
@@ -468,7 +469,9 @@ class BaseRepo:
         """
         raise NotImplementedError(self.open_index)
 
-    def fetch(self, target, determine_wants=None, progress=None, depth=None):
+    def fetch(
+        self, target, determine_wants=None, progress=None, depth: Optional[int] = None
+    ):
         """Fetch objects into another repository.
 
         Args:
@@ -495,8 +498,9 @@ class BaseRepo:
         determine_wants,
         graph_walker,
         progress,
+        *,
         get_tagged=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Fetch the pack data required for a set of revisions.
 
@@ -514,8 +518,10 @@ class BaseRepo:
         Returns: count and iterator over pack data
         """
         missing_objects = self.find_missing_objects(
-            determine_wants, graph_walker, progress, get_tagged, depth=depth
+            determine_wants, graph_walker, progress, get_tagged=get_tagged, depth=depth
         )
+        if missing_objects is None:
+            return 0, iter([])
         remote_has = missing_objects.get_remote_has()
         object_ids = list(missing_objects)
         return len(object_ids), generate_unpacked_objects(
@@ -527,8 +533,9 @@ class BaseRepo:
         determine_wants,
         graph_walker,
         progress,
+        *,
         get_tagged=None,
-        depth=None,
+        depth: Optional[int] = None,
     ) -> Optional[MissingObjectFinder]:
         """Fetch the missing objects required for a set of revisions.
 
@@ -545,25 +552,31 @@ class BaseRepo:
           depth: Shallow fetch depth
         Returns: iterator over objects, with __len__ implemented
         """
-        if depth not in (None, 0):
-            raise NotImplementedError("depth not supported yet")
-
         refs = serialize_refs(self.object_store, self.get_refs())
 
         wants = determine_wants(refs)
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
 
-        shallows: frozenset[ObjectID] = getattr(graph_walker, "shallow", frozenset())
-        unshallows: frozenset[ObjectID] = getattr(
-            graph_walker, "unshallow", frozenset()
-        )
+        current_shallow = set(getattr(graph_walker, "shallow", set()))
+
+        if depth not in (None, 0):
+            shallow, not_shallow = find_shallow(self.object_store, wants, depth)
+            # Only update if graph_walker has shallow attribute
+            if hasattr(graph_walker, "shallow"):
+                graph_walker.shallow.update(shallow - not_shallow)
+                new_shallow = graph_walker.shallow - current_shallow
+                unshallow = graph_walker.unshallow = not_shallow & current_shallow
+                if hasattr(graph_walker, "update_shallow"):
+                    graph_walker.update_shallow(new_shallow, unshallow)
+        else:
+            unshallow = getattr(graph_walker, "unshallow", frozenset())
 
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
             # this interface.
 
-            if shallows or unshallows:
+            if getattr(graph_walker, "shallow", set()) or unshallow:
                 # Do not send a pack in shallow short-circuit path
                 return None
 
@@ -586,12 +599,12 @@ class BaseRepo:
 
         # Deal with shallow requests separately because the haves do
         # not reflect what objects are missing
-        if shallows or unshallows:
+        if getattr(graph_walker, "shallow", set()) or unshallow:
             # TODO: filter the haves commits from iter_shas. the specific
             # commits aren't missing.
             haves = []
 
-        parents_provider = ParentsProvider(self.object_store, shallows=shallows)
+        parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
 
         def get_parents(commit):
             return parents_provider.get_parents(commit.id, commit)
@@ -600,7 +613,7 @@ class BaseRepo:
             self.object_store,
             haves=haves,
             wants=wants,
-            shallow=self.get_shallow(),
+            shallow=getattr(graph_walker, "shallow", set()),
             progress=progress,
             get_tagged=get_tagged,
             get_parents=get_parents,
@@ -649,7 +662,10 @@ class BaseRepo:
             ]
         parents_provider = ParentsProvider(self.object_store)
         return ObjectStoreGraphWalker(
-            heads, parents_provider.get_parents, shallow=self.get_shallow()
+            heads,
+            parents_provider.get_parents,
+            shallow=self.get_shallow(),
+            update_shallow=self.update_shallow,
         )
 
     def get_refs(self) -> dict[bytes, bytes]:
@@ -816,7 +832,7 @@ class BaseRepo:
 
         return Notes(self.object_store, self.refs)
 
-    def get_walker(self, include: Optional[list[bytes]] = None, *args, **kwargs):
+    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs):
         """Obtain a walker for this repository.
 
         Args:
@@ -852,7 +868,7 @@ class BaseRepo:
 
         kwargs["get_parents"] = lambda commit: self.get_parents(commit.id, commit)
 
-        return Walker(self.object_store, include, *args, **kwargs)
+        return Walker(self.object_store, include, **kwargs)
 
     def __getitem__(self, name: Union[ObjectID, Ref]):
         """Retrieve a Git object by SHA1 or ref.
@@ -1576,7 +1592,7 @@ class Repo(BaseRepo):
         checkout=None,
         branch=None,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
         symlinks=None,
     ) -> "Repo":
         """Clone this repository.

+ 5 - 43
dulwich/server.py

@@ -67,7 +67,7 @@ from .errors import (
     ObjectFormatException,
     UnexpectedCommandError,
 )
-from .object_store import peel_sha
+from .object_store import find_shallow
 from .objects import Commit, ObjectID, valid_hexsha
 from .pack import ObjectContainer, PackedObjectContainer, write_pack_from_container
 from .protocol import (
@@ -459,47 +459,6 @@ def _split_proto_line(line, allowed):
     raise GitProtocolError(f"Received invalid line from client: {line!r}")
 
 
-def _find_shallow(store: ObjectContainer, heads, depth):
-    """Find shallow commits according to a given depth.
-
-    Args:
-      store: An ObjectStore for looking up objects.
-      heads: Iterable of head SHAs to start walking from.
-      depth: The depth of ancestors to include. A depth of one includes
-        only the heads themselves.
-    Returns: A tuple of (shallow, not_shallow), sets of SHAs that should be
-        considered shallow and unshallow according to the arguments. Note that
-        these sets may overlap if a commit is reachable along multiple paths.
-    """
-    parents: dict[bytes, list[bytes]] = {}
-
-    def get_parents(sha):
-        result = parents.get(sha, None)
-        if not result:
-            result = store[sha].parents
-            parents[sha] = result
-        return result
-
-    todo = []  # stack of (sha, depth)
-    for head_sha in heads:
-        _unpeeled, peeled = peel_sha(store, head_sha)
-        if isinstance(peeled, Commit):
-            todo.append((peeled.id, 1))
-
-    not_shallow = set()
-    shallow = set()
-    while todo:
-        sha, cur_depth = todo.pop()
-        if cur_depth < depth:
-            not_shallow.add(sha)
-            new_depth = cur_depth + 1
-            todo.extend((p, new_depth) for p in get_parents(sha))
-        else:
-            shallow.add(sha)
-
-    return shallow, not_shallow
-
-
 def _want_satisfied(store: ObjectContainer, haves, want, earliest) -> bool:
     o = store[want]
     pending = collections.deque([o])
@@ -719,7 +678,7 @@ class _ProtocolGraphWalker:
             self.client_shallow.add(val)
         self.read_proto_line((None,))  # consume client's flush-pkt
 
-        shallow, not_shallow = _find_shallow(self.store, wants, depth)
+        shallow, not_shallow = find_shallow(self.store, wants, depth)
 
         # Update self.shallow instead of reassigning it since we passed a
         # reference to it before this method was called.
@@ -727,6 +686,9 @@ class _ProtocolGraphWalker:
         new_shallow = self.shallow - self.client_shallow
         unshallow = self.unshallow = not_shallow & self.client_shallow
 
+        self.update_shallow(new_shallow, unshallow)
+
+    def update_shallow(self, new_shallow, unshallow):
         for sha in sorted(new_shallow):
             self.proto.write_pkt_line(format_shallow_line(sha))
         for sha in sorted(unshallow):

+ 89 - 1
dulwich/tests/test_object_store.py

@@ -22,11 +22,14 @@
 """Tests for the object store interface."""
 
 from typing import TYPE_CHECKING, Any, Callable
+from unittest import TestCase
 from unittest.mock import patch
 
 from dulwich.index import commit_tree
 from dulwich.object_store import (
+    MemoryObjectStore,
     PackBasedObjectStore,
+    find_shallow,
     iter_tree_contents,
     peel_sha,
 )
@@ -37,7 +40,7 @@ from dulwich.objects import (
 )
 from dulwich.protocol import DEPTH_INFINITE
 
-from .utils import make_object, make_tag
+from .utils import make_commit, make_object, make_tag
 
 if TYPE_CHECKING:
     from dulwich.object_store import BaseObjectStore
@@ -395,3 +398,88 @@ class PackBasedObjectStoreTests(ObjectStoreTests):
         # Verify it's gone
         self.assertFalse(self.store.contains_loose(b1.id))
         self.assertNotIn(b1.id, self.store)
+
+
+class FindShallowTests(TestCase):
+    def setUp(self):
+        super().setUp()
+        self._store = MemoryObjectStore()
+
+    def make_commit(self, **attrs):
+        commit = make_commit(**attrs)
+        self._store.add_object(commit)
+        return commit
+
+    def make_linear_commits(self, n, message=b""):
+        commits = []
+        parents = []
+        for _ in range(n):
+            commits.append(self.make_commit(parents=parents, message=message))
+            parents = [commits[-1].id]
+        return commits
+
+    def assertSameElements(self, expected, actual):
+        self.assertEqual(set(expected), set(actual))
+
+    def test_linear(self):
+        c1, c2, c3 = self.make_linear_commits(3)
+
+        self.assertEqual((set([c3.id]), set([])), find_shallow(self._store, [c3.id], 1))
+        self.assertEqual(
+            (set([c2.id]), set([c3.id])),
+            find_shallow(self._store, [c3.id], 2),
+        )
+        self.assertEqual(
+            (set([c1.id]), set([c2.id, c3.id])),
+            find_shallow(self._store, [c3.id], 3),
+        )
+        self.assertEqual(
+            (set([]), set([c1.id, c2.id, c3.id])),
+            find_shallow(self._store, [c3.id], 4),
+        )
+
+    def test_multiple_independent(self):
+        a = self.make_linear_commits(2, message=b"a")
+        b = self.make_linear_commits(2, message=b"b")
+        c = self.make_linear_commits(2, message=b"c")
+        heads = [a[1].id, b[1].id, c[1].id]
+
+        self.assertEqual(
+            (set([a[0].id, b[0].id, c[0].id]), set(heads)),
+            find_shallow(self._store, heads, 2),
+        )
+
+    def test_multiple_overlapping(self):
+        # Create the following commit tree:
+        # 1--2
+        #  \
+        #   3--4
+        c1, c2 = self.make_linear_commits(2)
+        c3 = self.make_commit(parents=[c1.id])
+        c4 = self.make_commit(parents=[c3.id])
+
+        # 1 is shallow along the path from 4, but not along the path from 2.
+        self.assertEqual(
+            (set([c1.id]), set([c1.id, c2.id, c3.id, c4.id])),
+            find_shallow(self._store, [c2.id, c4.id], 3),
+        )
+
+    def test_merge(self):
+        c1 = self.make_commit()
+        c2 = self.make_commit()
+        c3 = self.make_commit(parents=[c1.id, c2.id])
+
+        self.assertEqual(
+            (set([c1.id, c2.id]), set([c3.id])),
+            find_shallow(self._store, [c3.id], 2),
+        )
+
+    def test_tag(self):
+        c1, c2 = self.make_linear_commits(2)
+        tag = make_tag(c2, name=b"tag")
+        self._store.add_object(tag)
+
+        self.assertEqual(
+            (set([c1.id]), set([c2.id])),
+            find_shallow(self._store, [tag.id], 2),
+        )

+ 9 - 10
tests/test_server.py

@@ -33,7 +33,7 @@ from dulwich.errors import (
     NotGitRepository,
     UnexpectedCommandError,
 )
-from dulwich.object_store import MemoryObjectStore
+from dulwich.object_store import MemoryObjectStore, find_shallow
 from dulwich.objects import Tree
 from dulwich.protocol import ZERO_SHA, format_capability_line
 from dulwich.repo import MemoryRepo, Repo
@@ -47,7 +47,6 @@ from dulwich.server import (
     ReceivePackHandler,
     SingleAckGraphWalkerImpl,
     UploadPackHandler,
-    _find_shallow,
     _ProtocolGraphWalker,
     _split_proto_line,
     serve_command,
@@ -271,18 +270,18 @@ class FindShallowTests(TestCase):
     def test_linear(self) -> None:
         c1, c2, c3 = self.make_linear_commits(3)
 
-        self.assertEqual(({c3.id}, set()), _find_shallow(self._store, [c3.id], 1))
+        self.assertEqual(({c3.id}, set()), find_shallow(self._store, [c3.id], 1))
         self.assertEqual(
             ({c2.id}, {c3.id}),
-            _find_shallow(self._store, [c3.id], 2),
+            find_shallow(self._store, [c3.id], 2),
         )
         self.assertEqual(
             ({c1.id}, {c2.id, c3.id}),
-            _find_shallow(self._store, [c3.id], 3),
+            find_shallow(self._store, [c3.id], 3),
         )
         self.assertEqual(
             (set(), {c1.id, c2.id, c3.id}),
-            _find_shallow(self._store, [c3.id], 4),
+            find_shallow(self._store, [c3.id], 4),
         )
 
     def test_multiple_independent(self) -> None:
@@ -293,7 +292,7 @@ class FindShallowTests(TestCase):
 
         self.assertEqual(
             ({a[0].id, b[0].id, c[0].id}, set(heads)),
-            _find_shallow(self._store, heads, 2),
+            find_shallow(self._store, heads, 2),
         )
 
     def test_multiple_overlapping(self) -> None:
@@ -308,7 +307,7 @@ class FindShallowTests(TestCase):
         # 1 is shallow along the path from 4, but not along the path from 2.
         self.assertEqual(
             ({c1.id}, {c1.id, c2.id, c3.id, c4.id}),
-            _find_shallow(self._store, [c2.id, c4.id], 3),
+            find_shallow(self._store, [c2.id, c4.id], 3),
         )
 
     def test_merge(self) -> None:
@@ -318,7 +317,7 @@ class FindShallowTests(TestCase):
 
         self.assertEqual(
             ({c1.id, c2.id}, {c3.id}),
-            _find_shallow(self._store, [c3.id], 2),
+            find_shallow(self._store, [c3.id], 2),
         )
 
     def test_tag(self) -> None:
@@ -328,7 +327,7 @@ class FindShallowTests(TestCase):
 
         self.assertEqual(
             ({c1.id}, {c2.id}),
-            _find_shallow(self._store, [tag.id], 2),
+            find_shallow(self._store, [tag.id], 2),
         )