Prechádzať zdrojové kódy

Use Repo.get_parents, so graft points are considered.

Jelmer Vernooij 4 rokov pred
rodič
commit
24cafe790a
3 zmenil súbory, kde vykonal 30 pridanie a 38 odobranie
  1. 11 20
      dulwich/graph.py
  2. 5 5
      dulwich/porcelain.py
  3. 14 13
      dulwich/tests/test_graph.py

+ 11 - 20
dulwich/graph.py

@@ -82,18 +82,15 @@ def _find_lcas(lookup_parents, c1, c2s):
     return results
 
 
-def find_merge_base(object_store, commit_ids):
+def find_merge_base(repo, commit_ids):
     """Find lowest common ancestors of commit_ids[0] and *any* of commits_ids[1:]
 
     Args:
-      object_store: object store
-      commit_ids:  list of commit ids
+      repo: Repository object
+      commit_ids: list of commit ids
     Returns:
       list of lowest common ancestor commit_ids
     """
-    def lookup_parents(commit_id):
-        return object_store[commit_id].parents
-
     if not commit_ids:
         return []
     c1 = commit_ids[0]
@@ -102,51 +99,45 @@ def find_merge_base(object_store, commit_ids):
     c2s = commit_ids[1:]
     if c1 in c2s:
         return [c1]
-    return _find_lcas(lookup_parents, c1, c2s)
+    return _find_lcas(repo.get_parents, c1, c2s)
 
 
-def find_octopus_base(object_store, commit_ids):
+def find_octopus_base(repo, commit_ids):
     """Find lowest common ancestors of *all* provided commit_ids
 
     Args:
-      object_store: Object store
+      repo: Repository
       commit_ids:  list of commit ids
     Returns:
       list of lowest common ancestor commit_ids
     """
 
-    def lookup_parents(commit_id):
-        return object_store[commit_id].parents
-
     if not commit_ids:
         return []
     if len(commit_ids) <= 2:
-        return find_merge_base(object_store, commit_ids)
+        return find_merge_base(repo, commit_ids)
     lcas = [commit_ids[0]]
     others = commit_ids[1:]
     for cmt in others:
         next_lcas = []
         for ca in lcas:
-            res = _find_lcas(lookup_parents, cmt, [ca])
+            res = _find_lcas(repo.get_parents, cmt, [ca])
             next_lcas.extend(res)
         lcas = next_lcas[:]
     return lcas
 
 
-def can_fast_forward(object_store, c1, c2):
+def can_fast_forward(repo, c1, c2):
     """Is it possible to fast-forward from c1 to c2?
 
     Args:
-      object_store: Store to retrieve objects from
+      repo: Repository to retrieve objects from
       c1: Commit id for first commit
       c2: Commit id for second commit
     """
     if c1 == c2:
         return True
 
-    def lookup_parents(commit_id):
-        return object_store[commit_id].parents
-
     # Algorithm: Find the common ancestor
-    lcas = _find_lcas(lookup_parents, c1, [c2])
+    lcas = _find_lcas(repo.get_parents, c1, [c2])
     return lcas == [c1]

+ 5 - 5
dulwich/porcelain.py

@@ -232,16 +232,16 @@ class DivergedBranches(Error):
     """Branches have diverged and fast-forward is not possible."""
 
 
-def check_diverged(store, current_sha, new_sha):
+def check_diverged(repo, current_sha, new_sha):
     """Check if updating to a sha can be done with fast forwarding.
 
     Args:
-      store: Object store
+      repo: Repository object
       current_sha: Current head sha
       new_sha: New head sha
     """
     try:
-        can = can_fast_forward(store, current_sha, new_sha)
+        can = can_fast_forward(repo, current_sha, new_sha)
     except KeyError:
         can = False
     if not can:
@@ -969,7 +969,7 @@ def push(repo, remote_location=None, refspecs=None,
                     remote_changed_refs[rh] = None
                 else:
                     if not force_ref:
-                        check_diverged(r.object_store, refs[rh], r.refs[lh])
+                        check_diverged(r, refs[rh], r.refs[lh])
                     new_refs[rh] = r.refs[lh]
                     remote_changed_refs[rh] = r.refs[lh]
             return new_refs
@@ -1036,7 +1036,7 @@ def pull(repo, remote_location=None, refspecs=None,
         for (lh, rh, force_ref) in selected_refs:
             try:
                 check_diverged(
-                    r.object_store, r.refs[rh], fetch_result.refs[lh])
+                    r, r.refs[rh], fetch_result.refs[lh])
             except DivergedBranches:
                 if fast_forward:
                     raise

+ 14 - 13
dulwich/tests/test_graph.py

@@ -23,7 +23,7 @@
 
 from dulwich.tests import TestCase
 from dulwich.tests.utils import make_commit
-from dulwich.object_store import MemoryObjectStore
+from dulwich.repo import MemoryRepo
 
 from dulwich.graph import _find_lcas, can_fast_forward
 
@@ -161,24 +161,25 @@ class FindMergeBaseTests(TestCase):
 class CanFastForwardTests(TestCase):
 
     def test_ff(self):
-        store = MemoryObjectStore()
+        r = MemoryRepo()
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2 = make_commit(parents=[c1.id])
-        store.add_objects([(base, None), (c1, None), (c2, None)])
-        self.assertTrue(can_fast_forward(store, c1.id, c1.id))
-        self.assertTrue(can_fast_forward(store, base.id, c1.id))
-        self.assertTrue(can_fast_forward(store, c1.id, c2.id))
-        self.assertFalse(can_fast_forward(store, c2.id, c1.id))
+        r.object_store.add_objects([(base, None), (c1, None), (c2, None)])
+        self.assertTrue(can_fast_forward(r, c1.id, c1.id))
+        self.assertTrue(can_fast_forward(r, base.id, c1.id))
+        self.assertTrue(can_fast_forward(r, c1.id, c2.id))
+        self.assertFalse(can_fast_forward(r, c2.id, c1.id))
 
     def test_diverged(self):
-        store = MemoryObjectStore()
+        r = MemoryRepo()
         base = make_commit()
         c1 = make_commit(parents=[base.id])
         c2a = make_commit(parents=[c1.id], message=b'2a')
         c2b = make_commit(parents=[c1.id], message=b'2b')
-        store.add_objects([(base, None), (c1, None), (c2a, None), (c2b, None)])
-        self.assertTrue(can_fast_forward(store, c1.id, c2a.id))
-        self.assertTrue(can_fast_forward(store, c1.id, c2b.id))
-        self.assertFalse(can_fast_forward(store, c2a.id, c2b.id))
-        self.assertFalse(can_fast_forward(store, c2b.id, c2a.id))
+        r.object_store.add_objects(
+            [(base, None), (c1, None), (c2a, None), (c2b, None)])
+        self.assertTrue(can_fast_forward(r, c1.id, c2a.id))
+        self.assertTrue(can_fast_forward(r, c1.id, c2b.id))
+        self.assertFalse(can_fast_forward(r, c2a.id, c2b.id))
+        self.assertFalse(can_fast_forward(r, c2b.id, c2a.id))