Prechádzať zdrojové kódy

Support depth for local clones.

Jelmer Vernooij 2 rokov pred
rodič
commit
44631bcf86

+ 3 - 0
NEWS

@@ -1,5 +1,8 @@
 0.20.47	UNRELEASED
 
+ * Support ``depth`` for local clones.
+   (Jelmer Vernooij)
+
  * Fix Repo.reset_index.
    Previously, it instead took the union with the given tree.
    (Christian Sattler, #1072)

+ 12 - 8
dulwich/client.py

@@ -519,7 +519,7 @@ def _read_side_band64k_data(pkt_seq, channel_callbacks):
 
 
 def _handle_upload_pack_head(
-    proto, capabilities, graph_walker, wants, can_read, depth
+    proto, capabilities, graph_walker, wants, can_read, depth: Optional[int] = None
 ):
     """Handle the head of a 'git-upload-pack' request.
 
@@ -721,8 +721,10 @@ class GitClient(object):
         """
         raise NotImplementedError(self.send_pack)
 
-    def clone(self, path, target_path, mkdir: bool = True, bare=False, origin="origin",
-              checkout=None, branch=None, progress=None, depth=None):
+    def clone(self, path, target_path,
+              mkdir: bool = True, bare=False, origin="origin",
+              checkout=None, branch=None, progress=None,
+              depth: Optional[int] = None):
         """Clone a repository."""
         from .refs import _set_origin_head, _set_default_branch, _set_head
 
@@ -758,6 +760,7 @@ class GitClient(object):
 
             ref_message = b"clone: from " + encoded_path
             result = self.fetch(path, target, progress=progress, depth=depth)
+
             _import_remote_refs(
                 target.refs, origin, result.refs, message=ref_message)
 
@@ -857,7 +860,7 @@ class GitClient(object):
         graph_walker,
         pack_data,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Retrieve a pack from a git smart server.
 
@@ -1125,7 +1128,7 @@ class TraditionalGitClient(GitClient):
         graph_walker,
         pack_data,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Retrieve a pack from a git smart server.
 
@@ -1483,7 +1486,8 @@ class LocalGitClient(GitClient):
 
         return SendPackResult(new_refs, ref_status=ref_status)
 
-    def fetch(self, path, target, determine_wants=None, progress=None, depth=None):
+    def fetch(self, path, target, determine_wants=None, progress=None,
+              depth: Optional[int] = None):
         """Fetch into a target repository.
 
         Args:
@@ -1515,7 +1519,7 @@ class LocalGitClient(GitClient):
         graph_walker,
         pack_data,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Retrieve a pack from a git smart server.
 
@@ -2056,7 +2060,7 @@ class AbstractHttpGitClient(GitClient):
         graph_walker,
         pack_data,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Retrieve a pack from a git smart server.
 

+ 6 - 5
dulwich/object_store.py

@@ -1380,7 +1380,7 @@ class MissingObjectFinder(object):
     def add_todo(self, entries):
         self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
 
-    def next(self):
+    def __next__(self):
         while True:
             if not self.objects_to_send:
                 return None
@@ -1407,7 +1407,7 @@ class MissingObjectFinder(object):
         self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
         return (sha, name)
 
-    __next__ = next
+    next = __next__
 
 
 class ObjectStoreGraphWalker(object):
@@ -1418,7 +1418,7 @@ class ObjectStoreGraphWalker(object):
       get_parents: Function to retrieve parents in the local repo
     """
 
-    def __init__(self, local_heads, get_parents, shallow=None):
+    def __init__(self, local_heads, get_parents, shallow=None, update_shallow=None):
         """Create a new instance.
 
         Args:
@@ -1431,6 +1431,7 @@ class ObjectStoreGraphWalker(object):
         if shallow is None:
             shallow = set()
         self.shallow = shallow
+        self.update_shallow = update_shallow
 
     def ack(self, sha):
         """Ack that a revision and its ancestors are present in the source."""
@@ -1458,7 +1459,7 @@ class ObjectStoreGraphWalker(object):
 
             ancestors = new_ancestors
 
-    def next(self):
+    def __next__(self):
         """Iterate over ancestors of heads in the target."""
         if self.heads:
             ret = self.heads.pop()
@@ -1471,7 +1472,7 @@ class ObjectStoreGraphWalker(object):
             return ret
         return None
 
-    __next__ = next
+    next = __next__
 
 
 def commit_tree_changes(object_store, tree, changes):

+ 29 - 16
dulwich/repo.py

@@ -72,6 +72,7 @@ from dulwich.object_store import (
     MemoryObjectStore,
     BaseObjectStore,
     ObjectStoreGraphWalker,
+    find_shallow,
 )
 from dulwich.objects import (
     check_hexsha,
@@ -437,7 +438,8 @@ class BaseRepo(object):
         """
         raise NotImplementedError(self.open_index)
 
-    def fetch(self, target, determine_wants=None, progress=None, depth=None):
+    def fetch(self, target, determine_wants=None, progress=None,
+              depth: Optional[int] = None):
         """Fetch objects into another repository.
 
         Args:
@@ -450,9 +452,10 @@ class BaseRepo(object):
         """
         if determine_wants is None:
             determine_wants = target.object_store.determine_wants_all
+        graph_walker = target.get_graph_walker()
         count, pack_data = self.fetch_pack_data(
             determine_wants,
-            target.get_graph_walker(),
+            graph_walker,
             progress=progress,
             depth=depth,
         )
@@ -464,8 +467,9 @@ class BaseRepo(object):
         determine_wants,
         graph_walker,
         progress,
+        *,
         get_tagged=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Fetch the pack data required for a set of revisions.
 
@@ -484,7 +488,7 @@ class BaseRepo(object):
         """
         # TODO(jelmer): Fetch pack data directly, don't create objects first.
         objects = self.fetch_objects(
-            determine_wants, graph_walker, progress, get_tagged, depth=depth
+            determine_wants, graph_walker, progress, get_tagged=get_tagged, depth=depth
         )
         return pack_objects_to_data(objects)
 
@@ -493,8 +497,9 @@ class BaseRepo(object):
         determine_wants,
         graph_walker,
         progress,
+        *,
         get_tagged=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Fetch the missing objects required for a set of revisions.
 
@@ -511,9 +516,6 @@ class BaseRepo(object):
           depth: Shallow fetch depth
         Returns: iterator over objects, with __len__ implemented
         """
-        if depth not in (None, 0):
-            raise NotImplementedError("depth not supported yet")
-
         refs = {}
         for ref, sha in self.get_refs().items():
             try:
@@ -534,14 +536,23 @@ class BaseRepo(object):
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
 
-        shallows = getattr(graph_walker, "shallow", frozenset())
-        unshallows = getattr(graph_walker, "unshallow", frozenset())
+        current_shallow = set(graph_walker.shallow)
+
+        if depth not in (None, 0):
+            shallow, not_shallow = find_shallow(
+                self.object_store, wants, depth)
+            graph_walker.shallow.update(shallow - not_shallow)
+            new_shallow = graph_walker.shallow - current_shallow
+            unshallow = graph_walker.unshallow = not_shallow & current_shallow
+            graph_walker.update_shallow(new_shallow, unshallow)
+        else:
+            unshallow = getattr(graph_walker, "unshallow", frozenset())
 
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
             # this interface.
 
-            if shallows or unshallows:
+            if graph_walker.shallow or unshallow:
                 # Do not send a pack in shallow short-circuit path
                 return None
 
@@ -554,12 +565,13 @@ class BaseRepo(object):
 
         # Deal with shallow requests separately because the haves do
         # not reflect what objects are missing
-        if shallows or unshallows:
+        if graph_walker.shallow or unshallow:
             # TODO: filter the haves commits from iter_shas. the specific
             # commits aren't missing.
             haves = []
 
-        parents_provider = ParentsProvider(self.object_store, shallows=shallows)
+        parents_provider = ParentsProvider(
+            self.object_store, shallows=current_shallow)
 
         def get_parents(commit):
             return parents_provider.get_parents(commit.id, commit)
@@ -568,7 +580,7 @@ class BaseRepo(object):
             self.object_store.find_missing_objects(
                 haves,
                 wants,
-                self.get_shallow(),
+                graph_walker.shallow,
                 progress,
                 get_tagged,
                 get_parents=get_parents,
@@ -613,7 +625,8 @@ class BaseRepo(object):
             ]
         parents_provider = ParentsProvider(self.object_store)
         return ObjectStoreGraphWalker(
-            heads, parents_provider.get_parents, shallow=self.get_shallow()
+            heads, parents_provider.get_parents, shallow=self.get_shallow(),
+            update_shallow=self.update_shallow,
         )
 
     def get_refs(self) -> Dict[bytes, bytes]:
@@ -1472,7 +1485,7 @@ class Repo(BaseRepo):
         checkout=None,
         branch=None,
         progress=None,
-        depth=None,
+        depth: Optional[int] = None,
     ):
         """Clone this repository.
 

+ 3 - 0
dulwich/server.py

@@ -667,6 +667,9 @@ class _ProtocolGraphWalker(object):
         new_shallow = self.shallow - self.client_shallow
         unshallow = self.unshallow = not_shallow & self.client_shallow
 
+        self.update_shallow(new_shallow, unshallow)
+
+    def update_shallow(self, new_shallow: List[bytes], unshallow: List[bytes]):
         for sha in sorted(new_shallow):
             self.proto.write_pkt_line(format_shallow_line(sha))
         for sha in sorted(unshallow):

+ 3 - 3
dulwich/tests/test_client.py

@@ -184,7 +184,7 @@ class GitClientTests(TestCase):
             self.assertEqual({}, heads)
             return []
 
-        ret = self.client.fetch_pack(b"bla", check_heads, None, None, None)
+        ret = self.client.fetch_pack(b"bla", check_heads, None, None)
         self.assertEqual({}, ret.refs)
         self.assertEqual({}, ret.symrefs)
         self.assertEqual(self.rout.getvalue(), b"0000")
@@ -197,7 +197,7 @@ class GitClientTests(TestCase):
             b"0000"
         )
         self.rin.seek(0)
-        ret = self.client.fetch_pack(b"bla", lambda heads, **kwargs: [], None, None, None)
+        ret = self.client.fetch_pack(b"bla", lambda heads, **kwargs: [], None, None)
         self.assertEqual(
             {b"HEAD": b"55dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7"}, ret.refs
         )
@@ -891,7 +891,7 @@ class LocalGitClientTests(TestCase):
         s = open_repo("a.git")
         self.addCleanup(tear_down_repo, s)
         out = BytesIO()
-        walker = {}
+        walker = MemoryRepo().get_graph_walker()
         ret = c.fetch_pack(
             s.path, lambda heads, **kwargs: [], graph_walker=walker, pack_data=out.write
         )

+ 37 - 0
dulwich/tests/test_porcelain.py

@@ -872,6 +872,43 @@ class CloneTests(PorcelainTestCase):
             target_repo.refs.get_symrefs(),
         )
 
+    def test_local_depth(self):
+        f1_1 = make_object(Blob, data=b"f1")
+        commit_spec = [[1], [2, 1], [3, 1, 2]]
+        trees = {
+            1: [(b"f1", f1_1), (b"f2", f1_1)],
+            2: [(b"f1", f1_1), (b"f2", f1_1)],
+            3: [(b"f1", f1_1), (b"f2", f1_1)],
+        }
+
+        c1, c2, c3 = build_commit_graph(self.repo.object_store, commit_spec, trees)
+        self.repo.refs[b"refs/heads/master"] = c3.id
+        self.repo.refs[b"refs/tags/foo"] = c3.id
+        target_path = tempfile.mkdtemp()
+        errstream = BytesIO()
+        self.addCleanup(shutil.rmtree, target_path)
+        r = porcelain.clone(
+            self.repo.path, target_path, checkout=False, errstream=errstream,
+            depth=1
+        )
+        self.addCleanup(r.close)
+        self.assertEqual(r.path, target_path)
+        target_repo = Repo(target_path)
+        self.assertEqual([c3.id], [w.commit.id for w in target_repo.get_walker()])
+        self.assertEqual(0, len(target_repo.open_index()))
+        self.assertEqual(c3.id, target_repo.refs[b"refs/tags/foo"])
+        self.assertNotIn(b"f1", os.listdir(target_path))
+        self.assertNotIn(b"f2", os.listdir(target_path))
+        c = r.get_config()
+        encoded_path = self.repo.path
+        if not isinstance(encoded_path, bytes):
+            encoded_path = encoded_path.encode("utf-8")
+        self.assertEqual(encoded_path, c.get((b"remote", b"origin"), b"url"))
+        self.assertEqual(
+            b"+refs/heads/*:refs/remotes/origin/*",
+            c.get((b"remote", b"origin"), b"fetch"),
+        )
+
 
 class InitTests(TestCase):
     def test_non_bare(self):