Ver código fonte

Fix pushing from a shallow clone

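send_pack now also takes the set of shallow commits and passes it on to
generate_pack_data, so that the search for missing objects stops at the
shallow boundary instead of trying to walk the parents of shallow commits.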
Brecht Machiels 5 anos atrás
pai
commit
6fc954bb53

+ 12 - 8
dulwich/client.py

@@ -361,7 +361,7 @@ class GitClient(object):
         """
         raise NotImplementedError(cls.from_parsedurl)
 
-    def send_pack(self, path, update_refs, generate_pack_data,
+    def send_pack(self, path, update_refs, shallow, generate_pack_data,
                   progress=None):
         """Upload a pack to a remote repository.
 
@@ -370,6 +370,7 @@ class GitClient(object):
           update_refs: Function to determine changes to remote refs. Receive
             dict with existing remote refs, returns dict with
             changed refs (name -> sha, where sha=ZERO_SHA for deletions)
+          shallow: Set of shallow commits, forwarded to generate_pack_data
           generate_pack_data: Function that can return a tuple
             with number of objects and list of pack data to include
           progress: Optional progress function
@@ -772,7 +773,7 @@ class TraditionalGitClient(GitClient):
         """
         raise NotImplementedError()
 
-    def send_pack(self, path, update_refs, generate_pack_data,
+    def send_pack(self, path, update_refs, shallow, generate_pack_data,
                   progress=None):
         """Upload a pack to a remote repository.
 
@@ -781,6 +782,7 @@ class TraditionalGitClient(GitClient):
           update_refs: Function to determine changes to remote refs.
         Receive dict with existing remote refs, returns dict with
         changed refs (name -> sha, where sha=ZERO_SHA for deletions)
+          shallow: Set of shallow commits, forwarded to generate_pack_data
           generate_pack_data: Function that can return a tuple with
         number of objects and pack data to upload.
           progress: Optional callback called with progress updates
@@ -842,7 +844,7 @@ class TraditionalGitClient(GitClient):
                     set(new_refs.items()).issubset(set(old_refs.items()))):
                 return new_refs
             pack_data_count, pack_data = generate_pack_data(
-                have, want,
+                have, want, shallow,
                 ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities))
 
             dowrite = bool(pack_data_count)
@@ -1116,7 +1118,7 @@ class LocalGitClient(GitClient):
             path = path.decode(sys.getfilesystemencoding())
         return closing(Repo(path))
 
-    def send_pack(self, path, update_refs, generate_pack_data,
+    def send_pack(self, path, update_refs, shallow, generate_pack_data,
                   progress=None):
         """Upload a pack to a remote repository.
 
@@ -1125,6 +1127,7 @@ class LocalGitClient(GitClient):
           update_refs: Function to determine changes to remote refs.
         Receive dict with existing remote refs, returns dict with
         changed refs (name -> sha, where sha=ZERO_SHA for deletions)
+          shallow: Set of shallow commits, forwarded to generate_pack_data
           generate_pack_data: Function that can return a tuple
         with number of items and pack data to upload.
           progress: Optional progress function
@@ -1160,7 +1163,7 @@ class LocalGitClient(GitClient):
                 return new_refs
 
             target.object_store.add_pack_data(
-                *generate_pack_data(have, want, ofs_delta=True))
+                *generate_pack_data(have, want, shallow, ofs_delta=True))
 
             for refname, new_sha1 in new_refs.items():
                 old_sha1 = old_refs.get(refname, ZERO_SHA)
@@ -1670,15 +1673,16 @@ class HttpGitClient(GitClient):
                                    % resp.content_type)
         return resp, read
 
-    def send_pack(self, path, update_refs, generate_pack_data,
+    def send_pack(self, path, update_refs, shallow, generate_pack_data,
                   progress=None):
         """Upload a pack to a remote repository.
 
         Args:
           path: Repository path (as bytestring)
           update_refs: Function to determine changes to remote refs.
-        Receive dict with existing remote refs, returns dict with
+        Receives dict with existing remote refs, returns dict with
         changed refs (name -> sha, where sha=ZERO_SHA for deletions)
+          shallow: Set of shallow commits, forwarded to generate_pack_data
           generate_pack_data: Function that can return a tuple
         with number of elements and pack data to upload.
           progress: Optional progress function
@@ -1716,7 +1720,7 @@ class HttpGitClient(GitClient):
         if not want and set(new_refs.items()).issubset(set(old_refs.items())):
             return new_refs
         pack_data_count, pack_data = generate_pack_data(
-                have, want,
+                have, want, shallow,
                 ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities))
         if pack_data_count:
             write_pack_data(req_proto.write_file(), pack_data_count, pack_data)

+ 14 - 14
dulwich/contrib/test_swift_smoke.py

@@ -140,8 +140,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         tcp_client = client.TCPGitClient(self.server_address,
                                          port=self.port)
-        tcp_client.send_pack(self.fakerepo,
-                             determine_wants,
+        tcp_client.send_pack(self.fakerepo, determine_wants,
+                             local_repo.get_shallow(),
                              local_repo.object_store.generate_pack_data)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         remote_sha = swift_repo.refs.read_loose_ref('refs/heads/master')
@@ -160,8 +160,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         tcp_client = client.TCPGitClient(self.server_address,
                                          port=self.port)
-        tcp_client.send_pack("/fakerepo",
-                             determine_wants,
+        tcp_client.send_pack("/fakerepo", determine_wants,
+                             local_repo.get_shallow(),
                              local_repo.object_store.generate_pack_data)
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
         remote_sha = swift_repo.refs.read_loose_ref('refs/heads/mybranch')
@@ -187,8 +187,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         tcp_client = client.TCPGitClient(self.server_address,
                                          port=self.port)
-        tcp_client.send_pack(self.fakerepo,
-                             determine_wants,
+        tcp_client.send_pack(self.fakerepo, determine_wants,
+                             local_repo.get_shallow(),
                              local_repo.object_store.generate_pack_data)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         for branch in ('master', 'mybranch', 'pullr-108'):
@@ -212,8 +212,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         tcp_client = client.TCPGitClient(self.server_address,
                                          port=self.port)
-        tcp_client.send_pack(self.fakerepo,
-                             determine_wants,
+        tcp_client.send_pack(self.fakerepo, determine_wants,
+                             local_repo.get_shallow(),
                              local_repo.object_store.generate_pack_data)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         commit_sha = swift_repo.refs.read_loose_ref('refs/heads/master')
@@ -259,8 +259,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         local_repo.stage(files)
         local_repo.do_commit('Test commit', 'fbo@localhost',
                              ref='refs/heads/master')
-        tcp_client.send_pack("/fakerepo",
-                             determine_wants,
+        tcp_client.send_pack("/fakerepo", determine_wants,
+                             local_repo.get_shallow(),
                              local_repo.object_store.generate_pack_data)
 
     def test_push_remove_branch(self):
@@ -275,8 +275,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         local_repo = repo.Repo(self.temp_d)
         tcp_client = client.TCPGitClient(self.server_address,
                                          port=self.port)
-        tcp_client.send_pack(self.fakerepo,
-                             determine_wants,
+        tcp_client.send_pack(self.fakerepo, determine_wants,
+                             local_repo.get_shallow(),
                              local_repo.object_store.generate_pack_data)
         swift_repo = swift.SwiftRepo("fakerepo", self.conf)
         self.assertNotIn('refs/heads/pullr-108', swift_repo.refs.allkeys())
@@ -302,8 +302,8 @@ class SwiftRepoSmokeTest(unittest.TestCase):
         swift.SwiftRepo.init_bare(self.scon, self.conf)
         tcp_client = client.TCPGitClient(self.server_address,
                                          port=self.port)
-        tcp_client.send_pack(self.fakerepo,
-                             determine_wants,
+        tcp_client.send_pack(self.fakerepo, determine_wants,
+                             local_repo.get_shallow(),
                              local_repo.object_store.generate_pack_data)
         swift_repo = swift.SwiftRepo(self.fakerepo, self.conf)
         tag_sha = swift_repo.refs.read_loose_ref('refs/tags/v1.0')

+ 19 - 11
dulwich/object_store.py

@@ -201,7 +201,7 @@ class BaseObjectStore(object):
                  not stat.S_ISDIR(entry.mode)) or include_trees):
                 yield entry
 
-    def find_missing_objects(self, haves, wants, progress=None,
+    def find_missing_objects(self, haves, wants, shallow, progress=None,
                              get_tagged=None,
                              get_parents=lambda commit: commit.parents,
                              depth=None):
@@ -210,6 +210,7 @@ class BaseObjectStore(object):
         Args:
           haves: Iterable over SHAs already in common.
           wants: Iterable over SHAs of objects to fetch.
+          shallow: Set of shallow commit SHA1s whose parents are not walked
           progress: Simple progress function that will be called with
             updated progress strings.
           get_tagged: Function that returns a dict of pointed-to sha ->
@@ -218,8 +219,8 @@ class BaseObjectStore(object):
             commit.
         Returns: Iterator over (sha, path) pairs.
         """
-        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged,
-                                     get_parents=get_parents)
+        finder = MissingObjectFinder(self, haves, wants, shallow, progress,
+                                     get_tagged, get_parents=get_parents)
         return iter(finder.next, None)
 
     def find_common_revisions(self, graphwalker):
@@ -238,28 +239,32 @@ class BaseObjectStore(object):
             sha = next(graphwalker)
         return haves
 
-    def generate_pack_contents(self, have, want, progress=None):
+    def generate_pack_contents(self, have, want, shallow, progress=None):
         """Iterate over the contents of a pack file.
 
         Args:
           have: List of SHA1s of objects that should not be sent
           want: List of SHA1s of objects that should be sent
+          shallow: Set of shallow commit SHA1s whose parents are not walked
           progress: Optional progress reporting method
         """
-        return self.iter_shas(self.find_missing_objects(have, want, progress))
+        missing = self.find_missing_objects(have, want, shallow, progress)
+        return self.iter_shas(missing)
 
-    def generate_pack_data(self, have, want, progress=None, ofs_delta=True):
+    def generate_pack_data(self, have, want, shallow, progress=None,
+                           ofs_delta=True):
         """Generate pack data objects for a set of wants/haves.
 
         Args:
           have: List of SHA1s of objects that should not be sent
           want: List of SHA1s of objects that should be sent
+          shallow: Set of shallow commit SHA1s whose parents are not walked
           ofs_delta: Whether OFS deltas can be included
           progress: Optional progress reporting method
         """
         # TODO(jelmer): More efficient implementation
         return pack_objects_to_data(
-            self.generate_pack_contents(have, want, progress))
+            self.generate_pack_contents(have, want, shallow, progress))
 
     def peel_sha(self, sha):
         """Peel all tags from a SHA.
@@ -277,7 +282,7 @@ class BaseObjectStore(object):
             obj = self[sha]
         return obj
 
-    def _collect_ancestors(self, heads, common=set(),
+    def _collect_ancestors(self, heads, common=set(), shallow=set(),
                            get_parents=lambda commit: commit.parents):
         """Collect all ancestors of heads up to (excluding) those in common.
 
@@ -301,6 +306,8 @@ class BaseObjectStore(object):
                 bases.add(e)
             elif e not in commits:
                 commits.add(e)
+                if e in shallow:
+                    continue
                 cmt = self[e]
                 queue.extend(get_parents(cmt))
         return (commits, bases)
@@ -1162,7 +1169,7 @@ class MissingObjectFinder(object):
       tagged: dict of pointed-to sha -> tag sha for including tags
     """
 
-    def __init__(self, object_store, haves, wants, progress=None,
+    def __init__(self, object_store, haves, wants, shallow, progress=None,
                  get_tagged=None, get_parents=lambda commit: commit.parents):
         self.object_store = object_store
         self._get_parents = get_parents
@@ -1178,12 +1185,13 @@ class MissingObjectFinder(object):
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
         all_ancestors = object_store._collect_ancestors(
-            have_commits, get_parents=self._get_parents)[0]
+            have_commits, shallow=shallow, get_parents=self._get_parents)[0]
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
         missing_commits, common_commits = object_store._collect_ancestors(
-            want_commits, all_ancestors, get_parents=self._get_parents)
+            want_commits, all_ancestors, shallow=shallow,
+            get_parents=self._get_parents)
         self.sha_done = set()
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally

+ 1 - 1
dulwich/porcelain.py

@@ -911,7 +911,7 @@ def push(repo, remote_location, refspecs,
         remote_location_bytes = client.get_url(path).encode(err_encoding)
         try:
             client.send_pack(
-                path, update_refs,
+                path, update_refs, r.get_shallow(),
                 generate_pack_data=r.object_store.generate_pack_data,
                 progress=errstream.write)
             errstream.write(

+ 2 - 2
dulwich/repo.py

@@ -468,8 +468,8 @@ class BaseRepo(object):
 
         return self.object_store.iter_shas(
           self.object_store.find_missing_objects(
-              haves, wants, progress,
-              get_tagged,
+              haves, wants, self.get_shallow(),
+              progress, get_tagged,
               get_parents=get_parents))
 
     def get_graph_walker(self, heads=None):

+ 9 - 8
dulwich/tests/compat/test_client.py

@@ -25,6 +25,7 @@ from io import BytesIO
 import os
 import select
 import signal
+import stat
 import subprocess
 import sys
 import tarfile
@@ -106,7 +107,7 @@ class DulwichClientTestBase(object):
             sendrefs = dict(src.get_refs())
             del sendrefs[b'HEAD']
             c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                        src.object_store.generate_pack_data)
+                        src.get_shallow(), src.object_store.generate_pack_data)
 
     def test_send_pack(self):
         self._do_send_pack()
@@ -152,7 +153,7 @@ class DulwichClientTestBase(object):
                     tree=tree_id)
             sendrefs = dict(local.get_refs())
             del sendrefs[b'HEAD']
-            c.send_pack(remote_path, lambda _: sendrefs,
+            c.send_pack(remote_path, lambda _: sendrefs, local.get_shallow(),
                         local.object_store.generate_pack_data)
         with repo.Repo(server_new_path) as remote:
             self.assertEqual(remote.head(), commit_id)
@@ -165,7 +166,7 @@ class DulwichClientTestBase(object):
             sendrefs = dict(src.get_refs())
             del sendrefs[b'HEAD']
             c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                        src.object_store.generate_pack_data)
+                        src.get_shallow(), src.object_store.generate_pack_data)
             self.assertDestEqualsSrc()
 
     def make_dummy_commit(self, dest):
@@ -202,8 +203,8 @@ class DulwichClientTestBase(object):
             sendrefs, gen_pack = self.compute_send(src)
             c = self._client()
             try:
-                c.send_pack(self._build_path('/dest'),
-                            lambda _: sendrefs, gen_pack)
+                c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
+                            src.get_shallow(), gen_pack)
             except errors.UpdateRefsError as e:
                 self.assertEqual('refs/heads/master failed to update',
                                  e.args[0])
@@ -222,7 +223,7 @@ class DulwichClientTestBase(object):
             c = self._client()
             try:
                 c.send_pack(self._build_path('/dest'), lambda _: sendrefs,
-                            gen_pack)
+                            src.get_shallow(), gen_pack)
             except errors.UpdateRefsError as e:
                 self.assertIn(
                         str(e),
@@ -315,12 +316,12 @@ class DulwichClientTestBase(object):
             sendrefs[b'refs/heads/abranch'] = b"00" * 20
             del sendrefs[b'HEAD']
 
-            def gen_pack(have, want, ofs_delta=False):
+            def gen_pack(have, want, shallow, ofs_delta=False):
                 return 0, []
             c = self._client()
             self.assertEqual(dest.refs[b"refs/heads/abranch"], dummy_commit)
             c.send_pack(
-                self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+                self._build_path('/dest'), lambda _: sendrefs, set(), gen_pack)
             self.assertFalse(b"refs/heads/abranch" in dest.refs)
 
     def test_get_refs(self):

+ 16 - 15
dulwich/tests/test_client.py

@@ -223,12 +223,12 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b'refs/foo/bar': commit.id, }
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, shallow, ofs_delta=False):
             return pack_objects_to_data([(commit, None), (tree, ''), ])
 
         self.assertRaises(UpdateRefsError,
                           self.client.send_pack, "blah",
-                          update_refs, generate_pack_data)
+                          update_refs, set(), generate_pack_data)
 
     def test_send_pack_none(self):
         self.rin.write(
@@ -244,10 +244,10 @@ class GitClientTests(TestCase):
                     b'310ca9477129b8586fa2afc779c1f57cf64bba6c'
             }
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, shallow, ofs_delta=False):
             return 0, []
 
-        self.client.send_pack(b'/', update_refs, generate_pack_data)
+        self.client.send_pack(b'/', update_refs, set(), generate_pack_data)
         self.assertEqual(self.rout.getvalue(), b'0000')
 
     def test_send_pack_keep_and_delete(self):
@@ -263,10 +263,10 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b'refs/heads/master': b'0' * 40}
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, shallow, ofs_delta=False):
             return 0, []
 
-        self.client.send_pack(b'/', update_refs, generate_pack_data)
+        self.client.send_pack(b'/', update_refs, set(), generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
             b'008b310ca9477129b8586fa2afc779c1f57cf64bba6c '
@@ -285,10 +285,10 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b'refs/heads/master': b'0' * 40}
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, shallow, ofs_delta=False):
             return 0, []
 
-        self.client.send_pack(b'/', update_refs, generate_pack_data)
+        self.client.send_pack(b'/', update_refs, set(), generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
             b'008b310ca9477129b8586fa2afc779c1f57cf64bba6c '
@@ -312,12 +312,12 @@ class GitClientTests(TestCase):
                     b'310ca9477129b8586fa2afc779c1f57cf64bba6c'
             }
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, shallow, ofs_delta=False):
             return 0, []
 
         f = BytesIO()
         write_pack_objects(f, {})
-        self.client.send_pack('/', update_refs, generate_pack_data)
+        self.client.send_pack('/', update_refs, set(), generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
             b'008b0000000000000000000000000000000000000000 '
@@ -351,12 +351,12 @@ class GitClientTests(TestCase):
                     b'310ca9477129b8586fa2afc779c1f57cf64bba6c'
             }
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, shallow, ofs_delta=False):
             return pack_objects_to_data([(commit, None), (tree, b''), ])
 
         f = BytesIO()
-        write_pack_data(f, *generate_pack_data(None, None))
-        self.client.send_pack(b'/', update_refs, generate_pack_data)
+        write_pack_data(f, *generate_pack_data(None, None, set()))
+        self.client.send_pack(b'/', update_refs, set(), generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
             b'008b0000000000000000000000000000000000000000 ' + commit.id +
@@ -378,12 +378,12 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b'refs/heads/master': b'0' * 40}
 
-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, shallow, ofs_delta=False):
             return 0, []
 
         self.assertRaises(UpdateRefsError,
                           self.client.send_pack, b"/",
-                          update_refs, generate_pack_data)
+                          update_refs, set(), generate_pack_data)
         self.assertEqual(self.rout.getvalue(), b'0000')
 
 
@@ -873,6 +873,7 @@ class LocalGitClientTests(TestCase):
         ref_name = b"refs/heads/" + branch
         new_refs = client.send_pack(target.path,
                                     lambda _: {ref_name: local.refs[ref_name]},
+                                    local.get_shallow(),
                                     local.object_store.generate_pack_data)
 
         self.assertEqual(local.refs[ref_name], new_refs[ref_name])

+ 2 - 2
dulwich/tests/test_missing_obj_finder.py

@@ -43,7 +43,7 @@ class MissingObjectFinderTest(TestCase):
         return self.commits[n-1]
 
     def assertMissingMatch(self, haves, wants, expected):
-        for sha, path in self.store.find_missing_objects(haves, wants):
+        for sha, path in self.store.find_missing_objects(haves, wants, set()):
             self.assertTrue(
                     sha in expected,
                     "(%s,%s) erroneously reported as missing" % (sha, path))
@@ -112,7 +112,7 @@ class MOFLinearRepoTest(MissingObjectFinderTest):
         haves = [self.cmt(1).id]
         wants = [self.cmt(3).id, bogus_sha]
         self.assertRaises(
-                KeyError, self.store.find_missing_objects, haves, wants)
+                KeyError, self.store.find_missing_objects, haves, wants, set())
 
     def test_no_changes(self):
         self.assertMissingMatch([self.cmt(3).id], [self.cmt(3).id], [])