Browse Source

Add ObjectStore.iter_prefix (#1402)

This allows scanning an object store for a specific SHA prefix. The
default implementation is naive and scans the entire key space, but
specific implementations have more efficient approaches, improving
performance.
Jelmer Vernooij 10 months ago
parent
commit
8ccaacf7a3
6 changed files with 96 additions and 8 deletions
  1. 2 0
      NEWS
  2. 43 0
      dulwich/object_store.py
  3. 8 8
      dulwich/objectspec.py
  4. 22 0
      dulwich/pack.py
  5. 11 0
      dulwich/tests/test_object_store.py
  6. 10 0
      tests/test_pack.py

+ 2 - 0
NEWS

@@ -3,6 +3,8 @@
  * Fix handling of symrefs with protocol v2.
    (Jelmer Vernooij, #1389)
 
+ * Add ``ObjectStore.iter_prefix``.  (Jelmer Vernooij)
+
 0.22.3	2024-10-15
 
  * Improve wheel building in CI, so we can upload wheels for the next release.

+ 43 - 0
dulwich/object_store.py

@@ -22,6 +22,7 @@
 
 """Git object store interfaces and implementation."""
 
+import binascii
 import os
 import stat
 import sys
@@ -358,6 +359,17 @@ class BaseObjectStore:
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[ObjectID]:
+        """Yield every object id in this store that begins with ``prefix``.
+
+        This generic fallback walks all ids produced by iterating the store
+        and keeps the ones that match; subclasses are free to override it
+        with a faster lookup.
+
+        Args:
+          prefix: Leading bytes that a yielded id must start with.
+        Returns: Iterator over the matching ids, in store iteration order.
+        """
+        yield from (object_id for object_id in self if object_id.startswith(prefix))
+
 
 class PackBasedObjectStore(BaseObjectStore):
     def __init__(self, pack_compression_level=-1) -> None:
@@ -1027,6 +1039,37 @@ class DiskObjectStore(PackBasedObjectStore):
         os.mkdir(os.path.join(path, PACKDIR))
         return cls(path)
 
+    def iter_prefix(self, prefix):
+        """Iterate over all SHA1s in this store that start with ``prefix``.
+
+        Files under ``<path>/<first two characters>`` are checked first, then
+        every pack index, then the alternates. Each matching id is yielded
+        only once, even if several of those sources contain it.
+
+        Args:
+          prefix: Hex-encoded id prefix, as bytes.
+        Returns: Iterator over the matching hex ids.
+        """
+        if len(prefix) < 2:
+            # Too short to name a two-character directory; use the generic scan.
+            yield from super().iter_prefix(prefix)
+            return
+        seen = set()
+        dirname = prefix[:2].decode()
+        rest = prefix[2:].decode()
+        try:
+            names = os.listdir(os.path.join(self.path, dirname))
+        except FileNotFoundError:
+            # The directory only exists once an object with this two-character
+            # start has been written there; matches may still be found in
+            # packs or alternates, so keep going.
+            names = []
+        for name in names:
+            if name.startswith(rest):
+                sha = os.fsencode(dirname + name)
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
+
+        # Pack indexes are searched by binary prefix, so only whole bytes of
+        # the hex prefix can be used; an odd trailing hex digit is dropped
+        # here and the full hex prefix is re-checked on every candidate.
+        even_prefix = prefix if len(prefix) % 2 == 0 else prefix[:-1]
+        bin_prefix = binascii.unhexlify(even_prefix)
+        for p in self.packs:
+            for sha in p.index.iter_prefix(bin_prefix):
+                sha = sha_to_hex(sha)
+                if sha.startswith(prefix) and sha not in seen:
+                    seen.add(sha)
+                    yield sha
+        for alternate in self.alternates:
+            for sha in alternate.iter_prefix(prefix):
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
+
 
 class MemoryObjectStore(BaseObjectStore):
     """Object store that keeps all objects in memory."""

+ 8 - 8
dulwich/objectspec.py

@@ -22,8 +22,9 @@
 
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
+from .objects import Commit, ShaFile, Tree
+
 if TYPE_CHECKING:
-    from .objects import Commit, ShaFile, Tree
     from .refs import Ref, RefsContainer
     from .repo import Repo
 
@@ -209,14 +210,13 @@ class AmbiguousShortId(Exception):
         self.options = options
 
 
-def scan_for_short_id(object_store, prefix):
+def scan_for_short_id(object_store, prefix, tp):
     """Scan an object store for a short id."""
-    # TODO(jelmer): This could short-circuit looking for objects
-    # starting with a certain prefix.
     ret = []
-    for object_id in object_store:
-        if object_id.startswith(prefix):
-            ret.append(object_store[object_id])
+    for object_id in object_store.iter_prefix(prefix):
+        o = object_store[object_id]
+        if isinstance(o, tp):
+            ret.append(o)
     if not ret:
         raise KeyError(prefix)
     if len(ret) == 1:
@@ -251,7 +251,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes]) -> "Commit":
             pass
         else:
             try:
-                return scan_for_short_id(repo.object_store, committish)
+                return scan_for_short_id(repo.object_store, committish, Commit)
             except KeyError:
                 pass
     raise KeyError(committish)

+ 22 - 0
dulwich/pack.py

@@ -746,6 +746,28 @@ class FilePackIndex(PackIndex):
             raise KeyError(sha)
         return self._unpack_offset(i)
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+        """Iterate over all SHA1s with the given prefix.
+
+        Args:
+          prefix: Binary (not hex) prefix of the names to look for. An empty
+            prefix matches every name in the index.
+        Returns: Iterator over the matching binary names, in index order.
+        """
+        if not prefix:
+            # No first byte to select a fan-out bucket: every name matches.
+            start, end = 0, len(self)
+        else:
+            first = ord(prefix[:1])
+            # _fan_out_table[b] is the position where names starting with
+            # byte b + 1 begin, which is also where those starting with
+            # byte b end.
+            start = self._fan_out_table[first - 1] if first else 0
+            end = len(self) if first == 0xFF else self._fan_out_table[first]
+        assert start <= end
+        started = False
+        for i in range(start, end):
+            name = self._unpack_name(i)
+            if name.startswith(prefix):
+                yield name
+                started = True
+            elif started:
+                # Matches are expected to be contiguous, so the first miss
+                # after a hit ends the search.
+                break
+
 
 class PackIndex1(FilePackIndex):
     """Version 1 Pack Index file."""

+ 11 - 0
dulwich/tests/test_object_store.py

@@ -236,6 +236,17 @@ class ObjectStoreTests:
         self.store.add_object(testobject)
         self.store.close()
 
+    def test_iter_prefix(self):
+        """The stored object is found by its full id, by shorter prefixes and by b""."""
+        self.store.add_object(testobject)
+        object_id = testobject.id
+        for prefix in (object_id, object_id[:10], object_id[:4], b""):
+            self.assertEqual([object_id], list(self.store.iter_prefix(prefix)))
+
 
 class PackBasedObjectStoreTests(ObjectStoreTests):
     def tearDown(self):

+ 10 - 0
tests/test_pack.py

@@ -124,6 +124,16 @@ class PackIndexTests(PackTests):
         self.assertEqual(p.object_sha1(138), hex_to_sha(tree_sha))
         self.assertEqual(p.object_sha1(12), hex_to_sha(commit_sha))
 
+    def test_iter_prefix(self):
+        """The binary form of ``a_sha`` and its 5- and 2-byte prefixes each match one name."""
+        index = self.get_pack_index(pack1_sha)
+        binary_sha = hex_to_sha(a_sha)
+        for prefix in (binary_sha, binary_sha[:5], binary_sha[:2]):
+            self.assertEqual([index.object_sha1(178)], list(index.iter_prefix(prefix)))
+
     def test_index_len(self):
         p = self.get_pack_index(pack1_sha)
         self.assertEqual(3, len(p))