Browse Source

Improve graph performance (#1198)


Following the logic of actual git, this PR modifies graph.py to use a priority queue keyed on commit timestamps. This greatly speeds up "can fast forward" calculations, because the commit times tell the walk which parent to visit first and when the search can stop.

Using commit timestamps requires a small adjustment to test_graph.py.
Kevin Hendricks 1 year ago
parent
commit
2ea10938af
2 changed files with 126 additions and 25 deletions
  1. 96 21
      dulwich/graph.py
  2. 30 4
      dulwich/tests/test_graph.py

+ 96 - 21
dulwich/graph.py

@@ -20,11 +20,35 @@
 
 """Implementation of merge-base following the approach of git."""
 
-from collections import deque
-from typing import Deque
+from .lru_cache import LRUCache
 
+from heapq import heappush, heappop
 
-def _find_lcas(lookup_parents, c1, c2s):
+
+# Priority queue built on Python's heapq module, which only provides a min-heap.
+# A max-heap is emulated by negating the integer timestamps when they are
+# pushed and negating them again when they are popped.
+class WorkList(object):
+    """Max-priority queue of ``(timestamp, commit)`` pairs.
+
+    heapq only implements a min-heap, so entries are stored with the
+    timestamp negated; the entry with the largest timestamp is therefore
+    always the first one returned by get().
+    """
+
+    def __init__(self):
+        # heap of (-timestamp, commit) tuples maintained by heapq
+        self.pq = []
+
+    def add(self, item):
+        """Insert a ``(timestamp, commit)`` tuple into the queue."""
+        dt, cmt = item
+        heappush(self.pq, (-dt, cmt))
+
+    def get(self):
+        """Remove and return the entry with the largest timestamp.
+
+        Returns:
+          a ``(timestamp, commit)`` tuple, or None if the queue is empty
+        """
+        # heappop() raises IndexError on an empty heap, so emptiness has to
+        # be checked before popping rather than by testing the popped item
+        if not self.pq:
+            return None
+        pr, cmt = heappop(self.pq)
+        return -pr, cmt
+
+    def iter(self):
+        """Yield every queued ``(timestamp, commit)`` without removing it.
+
+        Entries come out in internal heap order, which is not fully sorted.
+        """
+        for pr, cmt in self.pq:
+            yield (-pr, cmt)
+
+
+def _find_lcas(lookup_parents, c1, c2s, lookup_stamp, min_stamp=0):
     cands = []
     cstates = {}
 
@@ -35,7 +59,7 @@ def _find_lcas(lookup_parents, c1, c2s):
     _LCA = 8  # potential LCA (Lowest Common Ancestor)
 
     def _has_candidates(wlst, cstates):
-        for cmt in wlst:
+        for dt, cmt in wlst.iter():
             if cmt in cstates:
                 if not ((cstates[cmt] & _DNC) == _DNC):
                     return True
@@ -43,18 +67,18 @@ def _find_lcas(lookup_parents, c1, c2s):
 
     # initialize the working list states with ancestry info
     # note possibility of c1 being one of c2s should be handled
-    wlst: Deque[bytes] = deque()
+    wlst = WorkList()
     cstates[c1] = _ANC_OF_1
-    wlst.append(c1)
+    wlst.add((lookup_stamp(c1), c1))
     for c2 in c2s:
         cflags = cstates.get(c2, 0)
         cstates[c2] = cflags | _ANC_OF_2
-        wlst.append(c2)
-    
+        wlst.add((lookup_stamp(c2), c2))
+
     # loop while at least one working list commit is still viable (not marked as _DNC)
     # adding any parents to the list in a breadth first manner
     while _has_candidates(wlst, cstates):
-        cmt = wlst.popleft()
+        dt, cmt = wlst.get()
         # Look only at ANCESTRY and _DNC flags so that already
         # found _LCAs can still be marked _DNC by lower _LCAS
         cflags = cstates[cmt] & (_ANC_OF_1 | _ANC_OF_2 | _DNC)
@@ -62,7 +86,7 @@ def _find_lcas(lookup_parents, c1, c2s):
             # potential common ancestor if not already in candidates add it
             if not (cstates[cmt] & _LCA) == _LCA:
                 cstates[cmt] = cstates[cmt] | _LCA
-                cands.append(cmt)
+                cands.append((dt, cmt))
             # mark any parents of this node _DNC as all parents
             # would be one generation further removed common ancestors
             cflags = cflags | _DNC
@@ -74,17 +98,24 @@ def _find_lcas(lookup_parents, c1, c2s):
                 # do not add it to the working list again
                 if ((pflags & cflags) == cflags):
                     continue
+                pdt = lookup_stamp(pcmt)
+                if pdt < min_stamp:
+                    continue
                 cstates[pcmt] = pflags | cflags
-                wlst.append(pcmt)
+                wlst.add((pdt, pcmt))
 
     # walk final candidates removing any superseded by _DNC by later lower _LCAs
+    # remove any duplicates and sort it so that earliest is first
     results = []
-    for cmt in cands:
-        if not ((cstates[cmt] & _DNC) == _DNC):
-            results.append(cmt)
-    return results
+    for dt, cmt in cands:
+        if not ((cstates[cmt] & _DNC) == _DNC) and not (dt, cmt) in results:
+            results.append((dt, cmt))
+    results.sort(key=lambda x: x[0])
+    lcas = [cmt for dt, cmt in results]
+    return lcas
 
 
+# actual git sorts these based on commit times
 def find_merge_base(repo, commit_ids):
     """Find lowest common ancestors of commit_ids[0] and *any* of commits_ids[1:].
 
@@ -94,6 +125,21 @@ def find_merge_base(repo, commit_ids):
     Returns:
       list of lowest common ancestor commit_ids
     """
+    cmtcache = LRUCache(max_cache=128)
+    parents_provider = repo.parents_provider()
+
+    def lookup_stamp(cmtid):
+        if cmtid not in cmtcache:
+            cmtcache[cmtid] = repo.object_store[cmtid]
+        return cmtcache[cmtid].commit_time
+
+    def lookup_parents(cmtid):
+        commit = None
+        if cmtid in cmtcache:
+            commit = cmtcache[cmtid]
+        # must use parents provider to handle grafts and shallow
+        return parents_provider.get_parents(cmtid, commit=commit)
+
     if not commit_ids:
         return []
     c1 = commit_ids[0]
@@ -102,8 +148,8 @@ def find_merge_base(repo, commit_ids):
     c2s = commit_ids[1:]
     if c1 in c2s:
         return [c1]
-    parents_provider = repo.parents_provider()
-    return _find_lcas(parents_provider.get_parents, c1, c2s)
+    lcas = _find_lcas(lookup_parents, c1, c2s, lookup_stamp)
+    return lcas
 
 
 def find_octopus_base(repo, commit_ids):
@@ -115,17 +161,31 @@ def find_octopus_base(repo, commit_ids):
     Returns:
       list of lowest common ancestor commit_ids
     """
+    cmtcache = LRUCache(max_cache=128)
+    parents_provider = repo.parents_provider()
+
+    def lookup_stamp(cmtid):
+        if cmtid not in cmtcache:
+            cmtcache[cmtid] = repo.object_store[cmtid]
+        return cmtcache[cmtid].commit_time
+
+    def lookup_parents(cmtid):
+        commit = None
+        if cmtid in cmtcache:
+            commit = cmtcache[cmtid]
+        # must use parents provider to handle grafts and shallow
+        return parents_provider.get_parents(cmtid, commit=commit)
+
     if not commit_ids:
         return []
     if len(commit_ids) <= 2:
         return find_merge_base(repo, commit_ids)
-    parents_provider = repo.parents_provider()
     lcas = [commit_ids[0]]
     others = commit_ids[1:]
     for cmt in others:
         next_lcas = []
         for ca in lcas:
-            res = _find_lcas(parents_provider.get_parents, cmt, [ca])
+            res = _find_lcas(lookup_parents, cmt, [ca], lookup_stamp)
             next_lcas.extend(res)
         lcas = next_lcas[:]
     return lcas
@@ -139,10 +199,25 @@ def can_fast_forward(repo, c1, c2):
       c1: Commit id for first commit
       c2: Commit id for second commit
     """
+    cmtcache = LRUCache(max_cache=128)
+    parents_provider = repo.parents_provider()
+
+    def lookup_stamp(cmtid):
+        if cmtid not in cmtcache:
+            cmtcache[cmtid] = repo.object_store[cmtid]
+        return cmtcache[cmtid].commit_time
+
+    def lookup_parents(cmtid):
+        commit = None
+        if cmtid in cmtcache:
+            commit = cmtcache[cmtid]
+        # must use parents provider to handle grafts and shallow
+        return parents_provider.get_parents(cmtid, commit=commit)
+
     if c1 == c2:
         return True
 
     # Algorithm: Find the common ancestor
-    parents_provider = repo.parents_provider()
-    lcas = _find_lcas(parents_provider.get_parents, c1, [c2])
+    min_stamp = lookup_stamp(c1)
+    lcas = _find_lcas(lookup_parents, c1, [c2], lookup_stamp, min_stamp=min_stamp)
     return lcas == [c1]

+ 30 - 4
dulwich/tests/test_graph.py

@@ -1,4 +1,4 @@
-# test_index.py -- Tests for merge
+# test_graph.py -- Tests for merge base
 # Copyright (c) 2020 Kevin B. Hendricks, Stratford Ontario Canada
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
@@ -21,7 +21,7 @@
 
 from dulwich.tests import TestCase
 
-from ..graph import _find_lcas, can_fast_forward
+from ..graph import _find_lcas, can_fast_forward, WorkList
 from ..repo import MemoryRepo
 from .utils import make_commit
 
@@ -32,9 +32,14 @@ class FindMergeBaseTests(TestCase):
         def lookup_parents(commit_id):
             return dag[commit_id]
 
+        def lookup_stamp(commit_id):
+            # any constant timestamp value here will work to force
+            # this test to test the same behaviour as done previously
+            return 100
+
         c1 = inputs[0]
         c2s = inputs[1:]
-        return set(_find_lcas(lookup_parents, c1, c2s))
+        return set(_find_lcas(lookup_parents, c1, c2s, lookup_stamp))
 
     def test_multiple_lca(self):
         # two lowest common ancestors
@@ -146,12 +151,17 @@ class FindMergeBaseTests(TestCase):
         def lookup_parents(cid):
             return graph[cid]
 
+        def lookup_stamp(commit_id):
+            # any constant timestamp value here will work to force
+            # this test to test the same behaviour as done previously
+            return 100
+
         lcas = ["A"]
         others = ["B", "C"]
         for cmt in others:
             next_lcas = []
             for ca in lcas:
-                res = _find_lcas(lookup_parents, cmt, [ca])
+                res = _find_lcas(lookup_parents, cmt, [ca], lookup_stamp)
                 next_lcas.extend(res)
             lcas = next_lcas[:]
         self.assertEqual(set(lcas), {"2"})
@@ -180,3 +190,19 @@ class CanFastForwardTests(TestCase):
         self.assertTrue(can_fast_forward(r, c1.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2a.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2b.id, c2a.id))
+
+
+class WorkListTest(TestCase):
+    """Tests for the WorkList max-priority queue."""
+
+    def test_WorkList(self):
+        # tuples of (timestamp, value) are stored in a max-priority queue;
+        # repeated calls to get() must return the largest timestamp (most
+        # recent in time) first and then progressively earlier/older ones,
+        # including when a new item is added between calls
+        wlst = WorkList()
+        wlst.add((100, "Test Value 1"))
+        wlst.add((50, "Test Value 2"))
+        wlst.add((200, "Test Value 3"))
+        # assertEqual reports both operands on failure, which
+        # assertTrue(a == b) cannot do
+        self.assertEqual(wlst.get(), (200, "Test Value 3"))
+        self.assertEqual(wlst.get(), (100, "Test Value 1"))
+        wlst.add((150, "Test Value 4"))
+        self.assertEqual(wlst.get(), (150, "Test Value 4"))
+        self.assertEqual(wlst.get(), (50, "Test Value 2"))