Browse Source

Improve graph performance (#1198)


Following the logic of actual git, this PR modifies graph.py to use a priority queue keyed on commit timestamps. This greatly speeds up "can fast forward" calculations, because commit time tells the search which parent to walk up first and when to stop.

The use of commit timestamps necessitates a tweak to test_graph.py.
Kevin Hendricks 1 year ago
parent
commit
2ea10938af
2 changed files with 126 additions and 25 deletions
  1. 96 21
      dulwich/graph.py
  2. 30 4
      dulwich/tests/test_graph.py

+ 96 - 21
dulwich/graph.py

@@ -20,11 +20,35 @@
 
 
 """Implementation of merge-base following the approach of git."""
 """Implementation of merge-base following the approach of git."""
 
 
-from collections import deque
-from typing import Deque
+from .lru_cache import LRUCache
 
 
+from heapq import heappush, heappop
 
 
-def _find_lcas(lookup_parents, c1, c2s):
+
+# Priority queue built on Python's heapq module, which only provides a min-heap.
+# Timestamps are integers, so negating them on insert (and negating again on
+# retrieval) gives max-heap ordering: the largest timestamp comes out first.
+class WorkList(object):
+    def __init__(self):
+        self.pq = []
+
+    def add(self, item):
+        dt, cmt = item
+        heappush(self.pq, (-dt, cmt))
+
+    def get(self):
+        item = heappop(self.pq)
+        if item:
+            pr, cmt = item
+            return -pr, cmt
+        return None
+
+    def iter(self):
+        for (pr, cmt) in self.pq:
+            yield (-pr, cmt)
+
+
+def _find_lcas(lookup_parents, c1, c2s, lookup_stamp, min_stamp=0):
     cands = []
     cands = []
     cstates = {}
     cstates = {}
 
 
@@ -35,7 +59,7 @@ def _find_lcas(lookup_parents, c1, c2s):
     _LCA = 8  # potential LCA (Lowest Common Ancestor)
     _LCA = 8  # potential LCA (Lowest Common Ancestor)
 
 
     def _has_candidates(wlst, cstates):
     def _has_candidates(wlst, cstates):
-        for cmt in wlst:
+        for dt, cmt in wlst.iter():
             if cmt in cstates:
             if cmt in cstates:
                 if not ((cstates[cmt] & _DNC) == _DNC):
                 if not ((cstates[cmt] & _DNC) == _DNC):
                     return True
                     return True
@@ -43,18 +67,18 @@ def _find_lcas(lookup_parents, c1, c2s):
 
 
     # initialize the working list states with ancestry info
     # initialize the working list states with ancestry info
     # note possibility of c1 being one of c2s should be handled
     # note possibility of c1 being one of c2s should be handled
-    wlst: Deque[bytes] = deque()
+    wlst = WorkList()
     cstates[c1] = _ANC_OF_1
     cstates[c1] = _ANC_OF_1
-    wlst.append(c1)
+    wlst.add((lookup_stamp(c1), c1))
     for c2 in c2s:
     for c2 in c2s:
         cflags = cstates.get(c2, 0)
         cflags = cstates.get(c2, 0)
         cstates[c2] = cflags | _ANC_OF_2
         cstates[c2] = cflags | _ANC_OF_2
-        wlst.append(c2)
-    
+        wlst.add((lookup_stamp(c2), c2))
+
     # loop while at least one working list commit is still viable (not marked as _DNC)
     # loop while at least one working list commit is still viable (not marked as _DNC)
     # adding any parents to the list in a breadth first manner
     # adding any parents to the list in a breadth first manner
     while _has_candidates(wlst, cstates):
     while _has_candidates(wlst, cstates):
-        cmt = wlst.popleft()
+        dt, cmt = wlst.get()
         # Look only at ANCESTRY and _DNC flags so that already
         # Look only at ANCESTRY and _DNC flags so that already
         # found _LCAs can still be marked _DNC by lower _LCAS
         # found _LCAs can still be marked _DNC by lower _LCAS
         cflags = cstates[cmt] & (_ANC_OF_1 | _ANC_OF_2 | _DNC)
         cflags = cstates[cmt] & (_ANC_OF_1 | _ANC_OF_2 | _DNC)
@@ -62,7 +86,7 @@ def _find_lcas(lookup_parents, c1, c2s):
             # potential common ancestor if not already in candidates add it
             # potential common ancestor if not already in candidates add it
             if not (cstates[cmt] & _LCA) == _LCA:
             if not (cstates[cmt] & _LCA) == _LCA:
                 cstates[cmt] = cstates[cmt] | _LCA
                 cstates[cmt] = cstates[cmt] | _LCA
-                cands.append(cmt)
+                cands.append((dt, cmt))
             # mark any parents of this node _DNC as all parents
             # mark any parents of this node _DNC as all parents
             # would be one generation further removed common ancestors
             # would be one generation further removed common ancestors
             cflags = cflags | _DNC
             cflags = cflags | _DNC
@@ -74,17 +98,24 @@ def _find_lcas(lookup_parents, c1, c2s):
                 # do not add it to the working list again
                 # do not add it to the working list again
                 if ((pflags & cflags) == cflags):
                 if ((pflags & cflags) == cflags):
                     continue
                     continue
+                pdt = lookup_stamp(pcmt)
+                if pdt < min_stamp:
+                    continue
                 cstates[pcmt] = pflags | cflags
                 cstates[pcmt] = pflags | cflags
-                wlst.append(pcmt)
+                wlst.add((pdt, pcmt))
 
 
     # walk final candidates removing any superseded by _DNC by later lower _LCAs
     # walk final candidates removing any superseded by _DNC by later lower _LCAs
+    # remove any duplicates and sort it so that earliest is first
     results = []
     results = []
-    for cmt in cands:
-        if not ((cstates[cmt] & _DNC) == _DNC):
-            results.append(cmt)
-    return results
+    for dt, cmt in cands:
+        if not ((cstates[cmt] & _DNC) == _DNC) and not (dt, cmt) in results:
+            results.append((dt, cmt))
+    results.sort(key=lambda x: x[0])
+    lcas = [cmt for dt, cmt in results]
+    return lcas
 
 
 
 
+# actual git sorts these based on commit times
 def find_merge_base(repo, commit_ids):
 def find_merge_base(repo, commit_ids):
     """Find lowest common ancestors of commit_ids[0] and *any* of commits_ids[1:].
     """Find lowest common ancestors of commit_ids[0] and *any* of commits_ids[1:].
 
 
@@ -94,6 +125,21 @@ def find_merge_base(repo, commit_ids):
     Returns:
     Returns:
       list of lowest common ancestor commit_ids
       list of lowest common ancestor commit_ids
     """
     """
+    cmtcache = LRUCache(max_cache=128)
+    parents_provider = repo.parents_provider()
+
+    def lookup_stamp(cmtid):
+        if cmtid not in cmtcache:
+            cmtcache[cmtid] = repo.object_store[cmtid]
+        return cmtcache[cmtid].commit_time
+
+    def lookup_parents(cmtid):
+        commit = None
+        if cmtid in cmtcache:
+            commit = cmtcache[cmtid]
+        # must use parents provider to handle grafts and shallow
+        return parents_provider.get_parents(cmtid, commit=commit)
+
     if not commit_ids:
     if not commit_ids:
         return []
         return []
     c1 = commit_ids[0]
     c1 = commit_ids[0]
@@ -102,8 +148,8 @@ def find_merge_base(repo, commit_ids):
     c2s = commit_ids[1:]
     c2s = commit_ids[1:]
     if c1 in c2s:
     if c1 in c2s:
         return [c1]
         return [c1]
-    parents_provider = repo.parents_provider()
-    return _find_lcas(parents_provider.get_parents, c1, c2s)
+    lcas = _find_lcas(lookup_parents, c1, c2s, lookup_stamp)
+    return lcas
 
 
 
 
 def find_octopus_base(repo, commit_ids):
 def find_octopus_base(repo, commit_ids):
@@ -115,17 +161,31 @@ def find_octopus_base(repo, commit_ids):
     Returns:
     Returns:
       list of lowest common ancestor commit_ids
       list of lowest common ancestor commit_ids
     """
     """
+    cmtcache = LRUCache(max_cache=128)
+    parents_provider = repo.parents_provider()
+
+    def lookup_stamp(cmtid):
+        if cmtid not in cmtcache:
+            cmtcache[cmtid] = repo.object_store[cmtid]
+        return cmtcache[cmtid].commit_time
+
+    def lookup_parents(cmtid):
+        commit = None
+        if cmtid in cmtcache:
+            commit = cmtcache[cmtid]
+        # must use parents provider to handle grafts and shallow
+        return parents_provider.get_parents(cmtid, commit=commit)
+
     if not commit_ids:
     if not commit_ids:
         return []
         return []
     if len(commit_ids) <= 2:
     if len(commit_ids) <= 2:
         return find_merge_base(repo, commit_ids)
         return find_merge_base(repo, commit_ids)
-    parents_provider = repo.parents_provider()
     lcas = [commit_ids[0]]
     lcas = [commit_ids[0]]
     others = commit_ids[1:]
     others = commit_ids[1:]
     for cmt in others:
     for cmt in others:
         next_lcas = []
         next_lcas = []
         for ca in lcas:
         for ca in lcas:
-            res = _find_lcas(parents_provider.get_parents, cmt, [ca])
+            res = _find_lcas(lookup_parents, cmt, [ca], lookup_stamp)
             next_lcas.extend(res)
             next_lcas.extend(res)
         lcas = next_lcas[:]
         lcas = next_lcas[:]
     return lcas
     return lcas
@@ -139,10 +199,25 @@ def can_fast_forward(repo, c1, c2):
       c1: Commit id for first commit
       c1: Commit id for first commit
       c2: Commit id for second commit
       c2: Commit id for second commit
     """
     """
+    cmtcache = LRUCache(max_cache=128)
+    parents_provider = repo.parents_provider()
+
+    def lookup_stamp(cmtid):
+        if cmtid not in cmtcache:
+            cmtcache[cmtid] = repo.object_store[cmtid]
+        return cmtcache[cmtid].commit_time
+
+    def lookup_parents(cmtid):
+        commit = None
+        if cmtid in cmtcache:
+            commit = cmtcache[cmtid]
+        # must use parents provider to handle grafts and shallow
+        return parents_provider.get_parents(cmtid, commit=commit)
+
     if c1 == c2:
     if c1 == c2:
         return True
         return True
 
 
     # Algorithm: Find the common ancestor
     # Algorithm: Find the common ancestor
-    parents_provider = repo.parents_provider()
-    lcas = _find_lcas(parents_provider.get_parents, c1, [c2])
+    min_stamp = lookup_stamp(c1)
+    lcas = _find_lcas(lookup_parents, c1, [c2], lookup_stamp, min_stamp=min_stamp)
     return lcas == [c1]
     return lcas == [c1]

+ 30 - 4
dulwich/tests/test_graph.py

@@ -1,4 +1,4 @@
-# test_index.py -- Tests for merge
+# test_graph.py -- Tests for merge base
 # Copyright (c) 2020 Kevin B. Hendricks, Stratford Ontario Canada
 # Copyright (c) 2020 Kevin B. Hendricks, Stratford Ontario Canada
 #
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
@@ -21,7 +21,7 @@
 
 
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
 
 
-from ..graph import _find_lcas, can_fast_forward
+from ..graph import _find_lcas, can_fast_forward, WorkList
 from ..repo import MemoryRepo
 from ..repo import MemoryRepo
 from .utils import make_commit
 from .utils import make_commit
 
 
@@ -32,9 +32,14 @@ class FindMergeBaseTests(TestCase):
         def lookup_parents(commit_id):
         def lookup_parents(commit_id):
             return dag[commit_id]
             return dag[commit_id]
 
 
+        def lookup_stamp(commit_id):
+            # any constant timestamp value here will work to force
+            # this test to test the same behaviour as done previously
+            return 100
+
         c1 = inputs[0]
         c1 = inputs[0]
         c2s = inputs[1:]
         c2s = inputs[1:]
-        return set(_find_lcas(lookup_parents, c1, c2s))
+        return set(_find_lcas(lookup_parents, c1, c2s, lookup_stamp))
 
 
     def test_multiple_lca(self):
     def test_multiple_lca(self):
         # two lowest common ancestors
         # two lowest common ancestors
@@ -146,12 +151,17 @@ class FindMergeBaseTests(TestCase):
         def lookup_parents(cid):
         def lookup_parents(cid):
             return graph[cid]
             return graph[cid]
 
 
+        def lookup_stamp(commit_id):
+            # any constant timestamp value here will work to force
+            # this test to test the same behaviour as done previously
+            return 100
+
         lcas = ["A"]
         lcas = ["A"]
         others = ["B", "C"]
         others = ["B", "C"]
         for cmt in others:
         for cmt in others:
             next_lcas = []
             next_lcas = []
             for ca in lcas:
             for ca in lcas:
-                res = _find_lcas(lookup_parents, cmt, [ca])
+                res = _find_lcas(lookup_parents, cmt, [ca], lookup_stamp)
                 next_lcas.extend(res)
                 next_lcas.extend(res)
             lcas = next_lcas[:]
             lcas = next_lcas[:]
         self.assertEqual(set(lcas), {"2"})
         self.assertEqual(set(lcas), {"2"})
@@ -180,3 +190,19 @@ class CanFastForwardTests(TestCase):
         self.assertTrue(can_fast_forward(r, c1.id, c2b.id))
         self.assertTrue(can_fast_forward(r, c1.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2a.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2a.id, c2b.id))
         self.assertFalse(can_fast_forward(r, c2b.id, c2a.id))
         self.assertFalse(can_fast_forward(r, c2b.id, c2a.id))
+
+
+class WorkListTest(TestCase):
+    def test_WorkList(self):
+        # tuples of (timestamp, value) are stored in a Priority MaxQueue
+        # repeated use of get should return them in maxheap timestamp
+        # order: largest time value (most recent in time) first then earlier/older
+        wlst = WorkList()
+        wlst.add((100, "Test Value 1"))
+        wlst.add((50, "Test Value 2"))
+        wlst.add((200, "Test Value 3"))
+        self.assertTrue(wlst.get() == (200, "Test Value 3"))
+        self.assertTrue(wlst.get() == (100, "Test Value 1"))
+        wlst.add((150, "Test Value 4"))
+        self.assertTrue(wlst.get() == (150, "Test Value 4"))
+        self.assertTrue(wlst.get() == (50, "Test Value 2"))