Răsfoiți Sursa

Honor shallows when pushing from a shallow clone. Fixes: #794

Jelmer Vernooij 4 ani în urmă
părinte
comite
58df74905d
4 a modificat fișierele cu 51 adăugiri și 13 ștergeri
  1. 3 0
      NEWS
  2. 6 3
      dulwich/graph.py
  3. 31 10
      dulwich/repo.py
  4. 11 0
      dulwich/tests/test_repository.py

+ 3 - 0
NEWS

@@ -6,6 +6,9 @@
  * Fix pushing of new branches from porcelain.push.
    (Jelmer Vernooij, #788)
 
+ * Honor shallows when pushing from a shallow clone.
+   (Jelmer Vernooij, #794)
+
 0.20.5	2020-06-22
 
  * Print a clearer exception when setup.py is executed on Python < 3.5.

+ 6 - 3
dulwich/graph.py

@@ -99,7 +99,8 @@ def find_merge_base(repo, commit_ids):
     c2s = commit_ids[1:]
     if c1 in c2s:
         return [c1]
-    return _find_lcas(repo.get_parents, c1, c2s)
+    parents_provider = repo.parents_provider()
+    return _find_lcas(parents_provider.get_parents, c1, c2s)
 
 
 def find_octopus_base(repo, commit_ids):
@@ -116,12 +117,13 @@ def find_octopus_base(repo, commit_ids):
         return []
     if len(commit_ids) <= 2:
         return find_merge_base(repo, commit_ids)
+    parents_provider = repo.parents_provider()
     lcas = [commit_ids[0]]
     others = commit_ids[1:]
     for cmt in others:
         next_lcas = []
         for ca in lcas:
-            res = _find_lcas(repo.get_parents, cmt, [ca])
+            res = _find_lcas(parents_provider.get_parents, cmt, [ca])
             next_lcas.extend(res)
         lcas = next_lcas[:]
     return lcas
@@ -139,5 +141,6 @@ def can_fast_forward(repo, c1, c2):
         return True
 
     # Algorithm: Find the common ancestor
-    lcas = _find_lcas(repo.get_parents, c1, [c2])
+    parents_provider = repo.parents_provider()
+    lcas = _find_lcas(parents_provider.get_parents, c1, [c2])
     return lcas == [c1]

+ 31 - 10
dulwich/repo.py

@@ -297,6 +297,25 @@ def _set_filesystem_hidden(path):
     # Could implement other platform specific filesytem hiding here
 
 
+class ParentsProvider(object):
+
+    def __init__(self, store, grafts={}, shallows=[]):
+        self.store = store
+        self.grafts = grafts
+        self.shallows = set(shallows)
+
+    def get_parents(self, commit_id, commit=None):
+        try:
+            return self.grafts[commit_id]
+        except KeyError:
+            pass
+        if commit_id in self.shallows:
+            return []
+        if commit is None:
+            commit = self.store[commit_id]
+        return commit.parents
+
+
 class BaseRepo(object):
     """Base class for a git repository.
 
@@ -487,10 +506,11 @@ class BaseRepo(object):
             # commits aren't missing.
             haves = []
 
+        parents_provider = ParentsProvider(
+            self.object_store, shallows=shallows)
+
         def get_parents(commit):
-            if commit.id in shallows:
-                return []
-            return self.get_parents(commit.id, commit)
+            return parents_provider.get_parents(commit.id, commit)
 
         return self.object_store.iter_shas(
           self.object_store.find_missing_objects(
@@ -525,8 +545,9 @@ class BaseRepo(object):
             heads = [
                 sha for sha in self.refs.as_dict(b'refs/heads').values()
                 if sha in self.object_store]
+        parents_provider = ParentsProvider(self.object_store)
         return ObjectStoreGraphWalker(
-            heads, self.get_parents, shallow=self.get_shallow())
+            heads, parents_provider.get_parents, shallow=self.get_shallow())
 
     def get_refs(self) -> Dict[bytes, bytes]:
         """Get dictionary with all refs.
@@ -567,6 +588,11 @@ class BaseRepo(object):
         """
         return self.object_store[sha]
 
+    def parents_provider(self):
+        return ParentsProvider(
+            self.object_store, grafts=self._graftpoints,
+            shallows=self.get_shallow())
+
     def get_parents(self, sha: bytes, commit: Commit = None) -> List[bytes]:
         """Retrieve the parents of a specific commit.
 
@@ -578,12 +604,7 @@ class BaseRepo(object):
           commit: Optional commit matching the sha
         Returns: List of parents
         """
-        try:
-            return self._graftpoints[sha]
-        except KeyError:
-            if commit is None:
-                commit = self[sha]
-            return commit.parents
+        return self.parents_provider().get_parents(sha, commit)
 
     def get_config(self):
         """Retrieve the config object.

+ 11 - 0
dulwich/tests/test_repository.py

@@ -252,6 +252,17 @@ class RepositoryRootTests(TestCase):
         r = self.open_repo('a.git')
         self.assertEqual(r.get_peeled(b'HEAD'), r.head())
 
+    def test_get_parents(self):
+        r = self.open_repo('a.git')
+        self.assertEqual(
+            [b'2a72d929692c41d8554c07f6301757ba18a65d91'],
+            r.get_parents(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'))
+        r.update_shallow(
+                [b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'],
+                None)
+        self.assertEqual(
+            [], r.get_parents(b'a90fa2d900a17e99b433217e988c4eb4a2e9a097'))
+
     def test_get_walker(self):
         r = self.open_repo('a.git')
         # include defaults to [r.head()]