2
0
Эх сурвалжийг харах

Merge branch 'fetch-deepen' of git://github.com/pmrowla/dulwich

Jelmer Vernooij 4 жил өмнө
parent
commit
a64f326a26

+ 4 - 1
docs/tutorial/remote.txt

@@ -41,10 +41,13 @@ The client object can then be used to retrieve a pack. The ``fetch_pack``
 method takes a ``determine_wants`` callback argument, which allows the
 client to determine which objects it wants to end up with::
 
-   >>> def determine_wants(refs):
+   >>> def determine_wants(refs, depth=None):
    ...    # retrieve all objects
    ...    return refs.values()
 
+Note that the ``depth`` keyword argument will contain an optional requested
+shallow fetch depth.
+
 Another required object is a "graph walker", which is used to determine
 which objects that the client already has should not be sent again
 by the server. Here in the tutorial we'll just use a dummy graph walker

+ 1 - 1
dulwich/cli.py

@@ -116,7 +116,7 @@ class cmd_fetch_pack(Command):
             determine_wants = r.object_store.determine_wants_all
         else:
 
-            def determine_wants(x):
+            def determine_wants(x, **kwargs):
                 return [y for y in args if y not in r.object_store]
 
         client.fetch(path, r, determine_wants)

+ 2 - 2
dulwich/client.py

@@ -1023,7 +1023,7 @@ class TraditionalGitClient(GitClient):
                 return FetchPackResult(refs, symrefs, agent)
 
             try:
-                wants = determine_wants(refs)
+                wants = determine_wants(refs, depth=depth)
             except BaseException:
                 proto.write_pkt_line(None)
                 raise
@@ -2042,7 +2042,7 @@ class HttpGitClient(GitClient):
             symrefs,
             agent,
         ) = self._negotiate_upload_pack_capabilities(server_capabilities)
-        wants = determine_wants(refs)
+        wants = determine_wants(refs, depth=depth)
         if wants is not None:
             wants = [cid for cid in wants if cid != ZERO_SHA]
         if not wants:

+ 7 - 7
dulwich/contrib/test_swift_smoke.py

@@ -127,7 +127,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertEqual(remote_refs, None)
 
     def test_push_commit(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
@@ -144,7 +144,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertEqual(sha, remote_sha)
 
     def test_push_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"]}
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
@@ -161,7 +161,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertEqual(sha, remote_sha)
 
     def test_push_multiple_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {
                 "refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"],
                 "refs/heads/master": local_repo.refs["refs/heads/master"],
@@ -191,7 +191,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertDictEqual(local_shas, remote_shas)
 
     def test_push_data_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
@@ -243,7 +243,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         for f in files:
             self.assertEqual(os.path.isfile(f), True)
 
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
         os.mkdir(os.path.join(self.temp_d, "test"))
@@ -259,7 +259,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         )
 
     def test_push_remove_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {
                 "refs/heads/pullr-108": objects.ZERO_SHA,
                 "refs/heads/master": local_repo.refs["refs/heads/master"],
@@ -276,7 +276,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertNotIn("refs/heads/pullr-108", swift_repo.refs.allkeys())
 
     def test_push_annotated_tag(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {
                 "refs/heads/master": local_repo.refs["HEAD"],
                 "refs/tags/v1.0": local_repo.refs["refs/tags/v1.0"],

+ 44 - 3
dulwich/object_store.py

@@ -64,6 +64,7 @@ from dulwich.pack import (
     PackIndexer,
     PackStreamCopier,
 )
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.refs import ANNOTATED_TAG_SUFFIX
 
 INFODIR = "info"
@@ -73,11 +74,18 @@ PACKDIR = "pack"
 class BaseObjectStore(object):
     """Object store interface."""
 
-    def determine_wants_all(self, refs):
+    def determine_wants_all(self, refs, depth=None):
+        def _want_deepen(sha):
+            if not depth:
+                return False
+            if depth == DEPTH_INFINITE:
+                return True
+            return depth > self._get_depth(sha)
+
         return [
             sha
             for (ref, sha) in refs.items()
-            if sha not in self
+            if (sha not in self or _want_deepen(sha))
             and not ref.endswith(ANNOTATED_TAG_SUFFIX)
             and not sha == ZERO_SHA
         ]
@@ -350,6 +358,36 @@ class BaseObjectStore(object):
                 queue.extend(get_parents(cmt))
         return (commits, bases)
 
+    def _get_depth(
+        self, head, get_parents=lambda commit: commit.parents, max_depth=None,
+    ):
+        """Return the current available depth for the given head.
+        For commits with multiple parents, the largest possible depth will be
+        returned.
+
+        Args:
+            head: commit to start from
+            get_parents: optional function for getting the parents of a commit
+            max_depth: maximum depth to search
+        """
+        if head not in self:
+            return 0
+        current_depth = 1
+        queue = [(head, current_depth)]
+        while queue and (max_depth is None or current_depth < max_depth):
+            e, depth = queue.pop(0)
+            current_depth = max(current_depth, depth)
+            cmt = self[e]
+            if isinstance(cmt, Tag):
+                _cls, sha = cmt.object
+                cmt = self[sha]
+            queue.extend(
+                (parent, depth + 1)
+                for parent in get_parents(cmt)
+                if parent in self
+            )
+        return current_depth
+
     def close(self):
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
@@ -1353,7 +1391,10 @@ class ObjectStoreGraphWalker(object):
         """Iterate over ancestors of heads in the target."""
         if self.heads:
             ret = self.heads.pop()
-            ps = self.get_parents(ret)
+            try:
+                ps = self.get_parents(ret)
+            except KeyError:
+                return None
             self.parents[ret] = ps
             self.heads.update([p for p in ps if p not in self.parents])
             return ret

+ 1 - 1
dulwich/porcelain.py

@@ -1154,7 +1154,7 @@ def pull(
             refspecs = [b"HEAD"]
         selected_refs = []
 
-        def determine_wants(remote_refs):
+        def determine_wants(remote_refs, **kwargs):
             selected_refs.extend(
                 parse_reftuples(remote_refs, r.refs, refspecs, force=force)
             )

+ 2 - 0
dulwich/protocol.py

@@ -107,6 +107,8 @@ KNOWN_RECEIVE_CAPABILITIES = set(
     ]
 )
 
+DEPTH_INFINITE = 0x7FFFFFFF
+
 
 def agent_string():
     return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")

+ 7 - 2
dulwich/repo.py

@@ -495,7 +495,7 @@ class BaseRepo(object):
                     refs[ref + ANNOTATED_TAG_SUFFIX] = obj.object[1]
                 refs[ref] = sha
 
-        wants = determine_wants(refs)
+        wants = determine_wants(refs, depth=depth)
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
 
@@ -698,7 +698,12 @@ class BaseRepo(object):
             shallow.update(new_shallow)
         if new_unshallow:
             shallow.difference_update(new_unshallow)
-        self._put_named_file("shallow", b"".join([sha + b"\n" for sha in shallow]))
+        if shallow:
+            self._put_named_file(
+                "shallow", b"".join([sha + b"\n" for sha in shallow])
+            )
+        else:
+            self._del_named_file("shallow")
 
     def get_peeled(self, ref):
         """Get the peeled value of a ref.

+ 3 - 3
dulwich/server.py

@@ -372,8 +372,8 @@ class UploadPackHandler(PackHandler):
         )
         wants = []
 
-        def wants_wrapper(refs):
-            wants.extend(graph_walker.determine_wants(refs))
+        def wants_wrapper(refs, **kwargs):
+            wants.extend(graph_walker.determine_wants(refs, **kwargs))
             return wants
 
         objects_iter = self.repo.fetch_objects(
@@ -573,7 +573,7 @@ class _ProtocolGraphWalker(object):
         self._cache_index = 0
         self._impl = None
 
-    def determine_wants(self, heads):
+    def determine_wants(self, heads, depth=None):
         """Determine the wants for a set of heads.
 
         The given heads are advertised to the client, who then specifies which

+ 2 - 2
dulwich/tests/compat/test_client.py

@@ -276,7 +276,7 @@ class DulwichClientTestBase(object):
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
 
-            def dw(refs):
+            def dw(refs, **kwargs):
                 return list(refs.values())
 
             result = c.fetch(
@@ -317,7 +317,7 @@ class DulwichClientTestBase(object):
             result = c.fetch(
                 self._build_path("/server_new.export"),
                 dest,
-                lambda refs: [protocol.ZERO_SHA],
+                lambda refs, **kwargs: [protocol.ZERO_SHA],
             )
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])

+ 5 - 5
dulwich/tests/test_client.py

@@ -160,7 +160,7 @@ class GitClientTests(TestCase):
         self.rin.write(b"0000")
         self.rin.seek(0)
 
-        def check_heads(heads):
+        def check_heads(heads, **kwargs):
             self.assertEqual(heads, {})
             return []
 
@@ -178,7 +178,7 @@ class GitClientTests(TestCase):
         )
         self.rin.seek(0)
 
-        def check_heads(heads):
+        def check_heads(heads, **kwargs):
             self.assertEqual({}, heads)
             return []
 
@@ -195,7 +195,7 @@ class GitClientTests(TestCase):
             b"0000"
         )
         self.rin.seek(0)
-        ret = self.client.fetch_pack(b"bla", lambda heads: [], None, None, None)
+        ret = self.client.fetch_pack(b"bla", lambda heads, **kwargs: [], None, None, None)
         self.assertEqual(
             {b"HEAD": b"55dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7"}, ret.refs
         )
@@ -831,7 +831,7 @@ class LocalGitClientTests(TestCase):
         out = BytesIO()
         walker = {}
         ret = c.fetch_pack(
-            s.path, lambda heads: [], graph_walker=walker, pack_data=out.write
+            s.path, lambda heads, **kwargs: [], graph_walker=walker, pack_data=out.write
         )
         self.assertEqual(
             {
@@ -857,7 +857,7 @@ class LocalGitClientTests(TestCase):
         walker = MemoryRepo().get_graph_walker()
         ret = c.fetch_pack(
             s.path,
-            lambda heads: [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"],
+            lambda heads, **kwargs: [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"],
             graph_walker=walker,
             pack_data=out.write,
         )

+ 49 - 0
dulwich/tests/test_object_store.py

@@ -23,6 +23,7 @@
 
 from contextlib import closing
 from io import BytesIO
+from unittest import skipUnless
 import os
 import shutil
 import stat
@@ -54,6 +55,7 @@ from dulwich.pack import (
     REF_DELTA,
     write_pack_objects,
 )
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.tests import (
     TestCase,
 )
@@ -63,6 +65,11 @@ from dulwich.tests.utils import (
     build_pack,
 )
 
+try:
+    from unittest.mock import patch
+except ImportError:
+    patch = None  # type: ignore
+
 
 testobject = make_object(Blob, data=b"yummy data")
 
@@ -79,6 +86,48 @@ class ObjectStoreTests(object):
             [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40})
         )
 
+    @skipUnless(patch, "Required mock.patch")
+    def test_determine_wants_all_depth(self):
+        self.store.add_object(testobject)
+        refs = {b"refs/heads/foo": testobject.id}
+        with patch.object(self.store, "_get_depth", return_value=1) as m:
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=0)
+            )
+            self.assertEqual(
+                [testobject.id],
+                self.store.determine_wants_all(refs, depth=DEPTH_INFINITE),
+            )
+            m.assert_not_called()
+
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=1)
+            )
+            m.assert_called_with(testobject.id)
+            self.assertEqual(
+                [testobject.id], self.store.determine_wants_all(refs, depth=2)
+            )
+
+    def test_get_depth(self):
+        self.assertEqual(
+            0, self.store._get_depth(testobject.id)
+        )
+
+        self.store.add_object(testobject)
+        self.assertEqual(
+            1, self.store._get_depth(testobject.id, get_parents=lambda x: [])
+        )
+
+        parent = make_object(Blob, data=b"parent data")
+        self.store.add_object(parent)
+        self.assertEqual(
+            2,
+            self.store._get_depth(
+                testobject.id,
+                get_parents=lambda x: [parent.id] if x == testobject else [],
+            ),
+        )
+
     def test_iter(self):
         self.assertEqual([], list(self.store))
 

+ 8 - 0
dulwich/tests/test_repository.py

@@ -827,6 +827,14 @@ class BuildRepoRootTests(TestCase):
             {b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"},
             self._repo.get_shallow(),
         )
+        self._repo.update_shallow(
+            None, [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"]
+        )
+        self.assertEqual(set(), self._repo.get_shallow())
+        self.assertEqual(
+            False,
+            os.path.exists(os.path.join(self._repo.controldir(), "shallow")),
+        )
 
     def test_build_repo(self):
         r = self._repo