Prechádzať zdrojové kódy

Merge branch 'fetch-deepen' of git://github.com/pmrowla/dulwich

Jelmer Vernooij 4 rokov pred
rodič
commit
a64f326a26

+ 4 - 1
docs/tutorial/remote.txt

@@ -41,10 +41,13 @@ The client object can then be used to retrieve a pack. The ``fetch_pack``
 method takes a ``determine_wants`` callback argument, which allows the
 method takes a ``determine_wants`` callback argument, which allows the
 client to determine which objects it wants to end up with::
 client to determine which objects it wants to end up with::
 
 
-   >>> def determine_wants(refs):
+   >>> def determine_wants(refs, depth=None):
    ...    # retrieve all objects
    ...    # retrieve all objects
    ...    return refs.values()
    ...    return refs.values()
 
 
+Note that the ``depth`` keyword argument carries the requested shallow
+fetch depth, if any; it defaults to ``None``.
+
 Another required object is a "graph walker", which is used to determine
 Another required object is a "graph walker", which is used to determine
 which objects that the client already has should not be sent again
 which objects that the client already has should not be sent again
 by the server. Here in the tutorial we'll just use a dummy graph walker
 by the server. Here in the tutorial we'll just use a dummy graph walker

+ 1 - 1
dulwich/cli.py

@@ -116,7 +116,7 @@ class cmd_fetch_pack(Command):
             determine_wants = r.object_store.determine_wants_all
             determine_wants = r.object_store.determine_wants_all
         else:
         else:
 
 
-            def determine_wants(x):
+            def determine_wants(x, **kwargs):
                 return [y for y in args if y not in r.object_store]
                 return [y for y in args if y not in r.object_store]
 
 
         client.fetch(path, r, determine_wants)
         client.fetch(path, r, determine_wants)

+ 2 - 2
dulwich/client.py

@@ -1023,7 +1023,7 @@ class TraditionalGitClient(GitClient):
                 return FetchPackResult(refs, symrefs, agent)
                 return FetchPackResult(refs, symrefs, agent)
 
 
             try:
             try:
-                wants = determine_wants(refs)
+                wants = determine_wants(refs, depth=depth)
             except BaseException:
             except BaseException:
                 proto.write_pkt_line(None)
                 proto.write_pkt_line(None)
                 raise
                 raise
@@ -2042,7 +2042,7 @@ class HttpGitClient(GitClient):
             symrefs,
             symrefs,
             agent,
             agent,
         ) = self._negotiate_upload_pack_capabilities(server_capabilities)
         ) = self._negotiate_upload_pack_capabilities(server_capabilities)
-        wants = determine_wants(refs)
+        wants = determine_wants(refs, depth=depth)
         if wants is not None:
         if wants is not None:
             wants = [cid for cid in wants if cid != ZERO_SHA]
             wants = [cid for cid in wants if cid != ZERO_SHA]
         if not wants:
         if not wants:

+ 7 - 7
dulwich/contrib/test_swift_smoke.py

@@ -127,7 +127,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertEqual(remote_refs, None)
         self.assertEqual(remote_refs, None)
 
 
     def test_push_commit(self):
     def test_push_commit(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
@@ -144,7 +144,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertEqual(sha, remote_sha)
         self.assertEqual(sha, remote_sha)
 
 
     def test_push_branch(self):
     def test_push_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"]}
             return {"refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"]}
 
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
@@ -161,7 +161,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertEqual(sha, remote_sha)
         self.assertEqual(sha, remote_sha)
 
 
     def test_push_multiple_branch(self):
     def test_push_multiple_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {
             return {
                 "refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"],
                 "refs/heads/mybranch": local_repo.refs["refs/heads/mybranch"],
                 "refs/heads/master": local_repo.refs["refs/heads/master"],
                 "refs/heads/master": local_repo.refs["refs/heads/master"],
@@ -191,7 +191,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertDictEqual(local_shas, remote_shas)
         self.assertDictEqual(local_shas, remote_shas)
 
 
     def test_push_data_branch(self):
     def test_push_data_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
 
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
         local_repo = repo.Repo.init(self.temp_d, mkdir=True)
@@ -243,7 +243,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         for f in files:
         for f in files:
             self.assertEqual(os.path.isfile(f), True)
             self.assertEqual(os.path.isfile(f), True)
 
 
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {"refs/heads/master": local_repo.refs["HEAD"]}
             return {"refs/heads/master": local_repo.refs["HEAD"]}
 
 
         os.mkdir(os.path.join(self.temp_d, "test"))
         os.mkdir(os.path.join(self.temp_d, "test"))
@@ -259,7 +259,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         )
         )
 
 
     def test_push_remove_branch(self):
     def test_push_remove_branch(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {
             return {
                 "refs/heads/pullr-108": objects.ZERO_SHA,
                 "refs/heads/pullr-108": objects.ZERO_SHA,
                 "refs/heads/master": local_repo.refs["refs/heads/master"],
                 "refs/heads/master": local_repo.refs["refs/heads/master"],
@@ -276,7 +276,7 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         self.assertNotIn("refs/heads/pullr-108", swift_repo.refs.allkeys())
         self.assertNotIn("refs/heads/pullr-108", swift_repo.refs.allkeys())
 
 
     def test_push_annotated_tag(self):
     def test_push_annotated_tag(self):
-        def determine_wants(*args):
+        def determine_wants(*args, **kwargs):
             return {
             return {
                 "refs/heads/master": local_repo.refs["HEAD"],
                 "refs/heads/master": local_repo.refs["HEAD"],
                 "refs/tags/v1.0": local_repo.refs["refs/tags/v1.0"],
                 "refs/tags/v1.0": local_repo.refs["refs/tags/v1.0"],

+ 44 - 3
dulwich/object_store.py

@@ -64,6 +64,7 @@ from dulwich.pack import (
     PackIndexer,
     PackIndexer,
     PackStreamCopier,
     PackStreamCopier,
 )
 )
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.refs import ANNOTATED_TAG_SUFFIX
 from dulwich.refs import ANNOTATED_TAG_SUFFIX
 
 
 INFODIR = "info"
 INFODIR = "info"
@@ -73,11 +74,18 @@ PACKDIR = "pack"
 class BaseObjectStore(object):
 class BaseObjectStore(object):
     """Object store interface."""
     """Object store interface."""
 
 
-    def determine_wants_all(self, refs):
+    def determine_wants_all(self, refs, depth=None):
+        def _want_deepen(sha):
+            if not depth:
+                return False
+            if depth == DEPTH_INFINITE:
+                return True
+            return depth > self._get_depth(sha)
+
         return [
         return [
             sha
             sha
             for (ref, sha) in refs.items()
             for (ref, sha) in refs.items()
-            if sha not in self
+            if (sha not in self or _want_deepen(sha))
             and not ref.endswith(ANNOTATED_TAG_SUFFIX)
             and not ref.endswith(ANNOTATED_TAG_SUFFIX)
             and not sha == ZERO_SHA
             and not sha == ZERO_SHA
         ]
         ]
@@ -350,6 +358,36 @@ class BaseObjectStore(object):
                 queue.extend(get_parents(cmt))
                 queue.extend(get_parents(cmt))
         return (commits, bases)
         return (commits, bases)
 
 
+    def _get_depth(
+        self, head, get_parents=lambda commit: commit.parents, max_depth=None,
+    ):
+        """Return the current available depth for the given head.
+        For commits with multiple parents, the largest possible depth will be
+        returned.
+
+        Args:
+            head: commit to start from
+            get_parents: optional function for getting the parents of a commit
+            max_depth: maximum depth to search
+        """
+        if head not in self:
+            return 0
+        current_depth = 1
+        queue = [(head, current_depth)]
+        while queue and (max_depth is None or current_depth < max_depth):
+            e, depth = queue.pop(0)
+            current_depth = max(current_depth, depth)
+            cmt = self[e]
+            if isinstance(cmt, Tag):
+                _cls, sha = cmt.object
+                cmt = self[sha]
+            queue.extend(
+                (parent, depth + 1)
+                for parent in get_parents(cmt)
+                if parent in self
+            )
+        return current_depth
+
     def close(self):
     def close(self):
         """Close any files opened by this object store."""
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
         # Default implementation is a NO-OP
@@ -1353,7 +1391,10 @@ class ObjectStoreGraphWalker(object):
         """Iterate over ancestors of heads in the target."""
         """Iterate over ancestors of heads in the target."""
         if self.heads:
         if self.heads:
             ret = self.heads.pop()
             ret = self.heads.pop()
-            ps = self.get_parents(ret)
+            try:
+                ps = self.get_parents(ret)
+            except KeyError:
+                return None
             self.parents[ret] = ps
             self.parents[ret] = ps
             self.heads.update([p for p in ps if p not in self.parents])
             self.heads.update([p for p in ps if p not in self.parents])
             return ret
             return ret

+ 1 - 1
dulwich/porcelain.py

@@ -1154,7 +1154,7 @@ def pull(
             refspecs = [b"HEAD"]
             refspecs = [b"HEAD"]
         selected_refs = []
         selected_refs = []
 
 
-        def determine_wants(remote_refs):
+        def determine_wants(remote_refs, **kwargs):
             selected_refs.extend(
             selected_refs.extend(
                 parse_reftuples(remote_refs, r.refs, refspecs, force=force)
                 parse_reftuples(remote_refs, r.refs, refspecs, force=force)
             )
             )

+ 2 - 0
dulwich/protocol.py

@@ -107,6 +107,8 @@ KNOWN_RECEIVE_CAPABILITIES = set(
     ]
     ]
 )
 )
 
 
+DEPTH_INFINITE = 0x7FFFFFFF
+
 
 
 def agent_string():
 def agent_string():
     return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")
     return ("dulwich/%d.%d.%d" % dulwich.__version__).encode("ascii")

+ 7 - 2
dulwich/repo.py

@@ -495,7 +495,7 @@ class BaseRepo(object):
                     refs[ref + ANNOTATED_TAG_SUFFIX] = obj.object[1]
                     refs[ref + ANNOTATED_TAG_SUFFIX] = obj.object[1]
                 refs[ref] = sha
                 refs[ref] = sha
 
 
-        wants = determine_wants(refs)
+        wants = determine_wants(refs, depth=depth)
         if not isinstance(wants, list):
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
             raise TypeError("determine_wants() did not return a list")
 
 
@@ -698,7 +698,12 @@ class BaseRepo(object):
             shallow.update(new_shallow)
             shallow.update(new_shallow)
         if new_unshallow:
         if new_unshallow:
             shallow.difference_update(new_unshallow)
             shallow.difference_update(new_unshallow)
-        self._put_named_file("shallow", b"".join([sha + b"\n" for sha in shallow]))
+        if shallow:
+            self._put_named_file(
+                "shallow", b"".join([sha + b"\n" for sha in shallow])
+            )
+        else:
+            self._del_named_file("shallow")
 
 
     def get_peeled(self, ref):
     def get_peeled(self, ref):
         """Get the peeled value of a ref.
         """Get the peeled value of a ref.

+ 3 - 3
dulwich/server.py

@@ -372,8 +372,8 @@ class UploadPackHandler(PackHandler):
         )
         )
         wants = []
         wants = []
 
 
-        def wants_wrapper(refs):
-            wants.extend(graph_walker.determine_wants(refs))
+        def wants_wrapper(refs, **kwargs):
+            wants.extend(graph_walker.determine_wants(refs, **kwargs))
             return wants
             return wants
 
 
         objects_iter = self.repo.fetch_objects(
         objects_iter = self.repo.fetch_objects(
@@ -573,7 +573,7 @@ class _ProtocolGraphWalker(object):
         self._cache_index = 0
         self._cache_index = 0
         self._impl = None
         self._impl = None
 
 
-    def determine_wants(self, heads):
+    def determine_wants(self, heads, depth=None):
         """Determine the wants for a set of heads.
         """Determine the wants for a set of heads.
 
 
         The given heads are advertised to the client, who then specifies which
         The given heads are advertised to the client, who then specifies which

+ 2 - 2
dulwich/tests/compat/test_client.py

@@ -276,7 +276,7 @@ class DulwichClientTestBase(object):
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])
             self.assertDestEqualsSrc()
             self.assertDestEqualsSrc()
 
 
-            def dw(refs):
+            def dw(refs, **kwargs):
                 return list(refs.values())
                 return list(refs.values())
 
 
             result = c.fetch(
             result = c.fetch(
@@ -317,7 +317,7 @@ class DulwichClientTestBase(object):
             result = c.fetch(
             result = c.fetch(
                 self._build_path("/server_new.export"),
                 self._build_path("/server_new.export"),
                 dest,
                 dest,
-                lambda refs: [protocol.ZERO_SHA],
+                lambda refs, **kwargs: [protocol.ZERO_SHA],
             )
             )
             for r in result.refs.items():
             for r in result.refs.items():
                 dest.refs.set_if_equals(r[0], None, r[1])
                 dest.refs.set_if_equals(r[0], None, r[1])

+ 5 - 5
dulwich/tests/test_client.py

@@ -160,7 +160,7 @@ class GitClientTests(TestCase):
         self.rin.write(b"0000")
         self.rin.write(b"0000")
         self.rin.seek(0)
         self.rin.seek(0)
 
 
-        def check_heads(heads):
+        def check_heads(heads, **kwargs):
             self.assertEqual(heads, {})
             self.assertEqual(heads, {})
             return []
             return []
 
 
@@ -178,7 +178,7 @@ class GitClientTests(TestCase):
         )
         )
         self.rin.seek(0)
         self.rin.seek(0)
 
 
-        def check_heads(heads):
+        def check_heads(heads, **kwargs):
             self.assertEqual({}, heads)
             self.assertEqual({}, heads)
             return []
             return []
 
 
@@ -195,7 +195,7 @@ class GitClientTests(TestCase):
             b"0000"
             b"0000"
         )
         )
         self.rin.seek(0)
         self.rin.seek(0)
-        ret = self.client.fetch_pack(b"bla", lambda heads: [], None, None, None)
+        ret = self.client.fetch_pack(b"bla", lambda heads, **kwargs: [], None, None, None)
         self.assertEqual(
         self.assertEqual(
             {b"HEAD": b"55dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7"}, ret.refs
             {b"HEAD": b"55dcc6bf963f922e1ed5c4bbaaefcfacef57b1d7"}, ret.refs
         )
         )
@@ -831,7 +831,7 @@ class LocalGitClientTests(TestCase):
         out = BytesIO()
         out = BytesIO()
         walker = {}
         walker = {}
         ret = c.fetch_pack(
         ret = c.fetch_pack(
-            s.path, lambda heads: [], graph_walker=walker, pack_data=out.write
+            s.path, lambda heads, **kwargs: [], graph_walker=walker, pack_data=out.write
         )
         )
         self.assertEqual(
         self.assertEqual(
             {
             {
@@ -857,7 +857,7 @@ class LocalGitClientTests(TestCase):
         walker = MemoryRepo().get_graph_walker()
         walker = MemoryRepo().get_graph_walker()
         ret = c.fetch_pack(
         ret = c.fetch_pack(
             s.path,
             s.path,
-            lambda heads: [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"],
+            lambda heads, **kwargs: [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"],
             graph_walker=walker,
             graph_walker=walker,
             pack_data=out.write,
             pack_data=out.write,
         )
         )

+ 49 - 0
dulwich/tests/test_object_store.py

@@ -23,6 +23,7 @@
 
 
 from contextlib import closing
 from contextlib import closing
 from io import BytesIO
 from io import BytesIO
+from unittest import skipUnless
 import os
 import os
 import shutil
 import shutil
 import stat
 import stat
@@ -54,6 +55,7 @@ from dulwich.pack import (
     REF_DELTA,
     REF_DELTA,
     write_pack_objects,
     write_pack_objects,
 )
 )
+from dulwich.protocol import DEPTH_INFINITE
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
 )
 )
@@ -63,6 +65,11 @@ from dulwich.tests.utils import (
     build_pack,
     build_pack,
 )
 )
 
 
+try:
+    from unittest.mock import patch
+except ImportError:
+    patch = None  # type: ignore
+
 
 
 testobject = make_object(Blob, data=b"yummy data")
 testobject = make_object(Blob, data=b"yummy data")
 
 
@@ -79,6 +86,48 @@ class ObjectStoreTests(object):
             [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40})
             [], self.store.determine_wants_all({b"refs/heads/foo": b"0" * 40})
         )
         )
 
 
+    @skipUnless(patch, "Required mock.patch")
+    def test_determine_wants_all_depth(self):
+        self.store.add_object(testobject)
+        refs = {b"refs/heads/foo": testobject.id}
+        with patch.object(self.store, "_get_depth", return_value=1) as m:
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=0)
+            )
+            self.assertEqual(
+                [testobject.id],
+                self.store.determine_wants_all(refs, depth=DEPTH_INFINITE),
+            )
+            m.assert_not_called()
+
+            self.assertEqual(
+                [], self.store.determine_wants_all(refs, depth=1)
+            )
+            m.assert_called_with(testobject.id)
+            self.assertEqual(
+                [testobject.id], self.store.determine_wants_all(refs, depth=2)
+            )
+
+    def test_get_depth(self):
+        self.assertEqual(
+            0, self.store._get_depth(testobject.id)
+        )
+
+        self.store.add_object(testobject)
+        self.assertEqual(
+            1, self.store._get_depth(testobject.id, get_parents=lambda x: [])
+        )
+
+        parent = make_object(Blob, data=b"parent data")
+        self.store.add_object(parent)
+        self.assertEqual(
+            2,
+            self.store._get_depth(
+                testobject.id,
+                get_parents=lambda x: [parent.id] if x == testobject else [],
+            ),
+        )
+
     def test_iter(self):
     def test_iter(self):
         self.assertEqual([], list(self.store))
         self.assertEqual([], list(self.store))
 
 

+ 8 - 0
dulwich/tests/test_repository.py

@@ -827,6 +827,14 @@ class BuildRepoRootTests(TestCase):
             {b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"},
             {b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"},
             self._repo.get_shallow(),
             self._repo.get_shallow(),
         )
         )
+        self._repo.update_shallow(
+            None, [b"a90fa2d900a17e99b433217e988c4eb4a2e9a097"]
+        )
+        self.assertEqual(set(), self._repo.get_shallow())
+        self.assertEqual(
+            False,
+            os.path.exists(os.path.join(self._repo.controldir(), "shallow")),
+        )
 
 
     def test_build_repo(self):
     def test_build_repo(self):
         r = self._repo
         r = self._repo