Преглед изворног кода

Add ObjectStore.iter_prefix (#1402)

This allows scanning an object store for a specific SHA prefix. The
default implementation is naive and scans the entire key space, but
specific implementations have more efficient approaches, improving
performance.
Jelmer Vernooij пре 4 месеци
родитељ
комит
8ccaacf7a3
6 измењених фајлова са 96 додато и 8 уклоњено
  1. 2 0
      NEWS
  2. 43 0
      dulwich/object_store.py
  3. 8 8
      dulwich/objectspec.py
  4. 22 0
      dulwich/pack.py
  5. 11 0
      dulwich/tests/test_object_store.py
  6. 10 0
      tests/test_pack.py

+ 2 - 0
NEWS

@@ -3,6 +3,8 @@
  * Fix handling of symrefs with protocol v2.
    (Jelmer Vernooij, #1389)
 
+ * Add ``ObjectStore.iter_prefix``.  (Jelmer Vernooij)
+
 0.22.3	2024-10-15
 
  * Improve wheel building in CI, so we can upload wheels for the next release.

+ 43 - 0
dulwich/object_store.py

@@ -22,6 +22,7 @@
 
 """Git object store interfaces and implementation."""
 
+import binascii
 import os
 import stat
 import sys
@@ -358,6 +359,17 @@ class BaseObjectStore:
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[ObjectID]:
+        """Iterate over all SHA1s that start with a given prefix.
+
+        The default implementation is a naive iteration over all objects.
+        However, subclasses may override this method with more efficient
+        implementations.
+        """
+        for sha in self:
+            if sha.startswith(prefix):
+                yield sha
+
 
 class PackBasedObjectStore(BaseObjectStore):
     def __init__(self, pack_compression_level=-1) -> None:
@@ -1027,6 +1039,37 @@ class DiskObjectStore(PackBasedObjectStore):
         os.mkdir(os.path.join(path, PACKDIR))
         return cls(path)
 
+    def iter_prefix(self, prefix):
+        if len(prefix) < 2:
+            yield from super().iter_prefix(prefix)
+            return
+        seen = set()
+        dir = prefix[:2].decode()
+        rest = prefix[2:].decode()
+        for name in os.listdir(os.path.join(self.path, dir)):
+            if name.startswith(rest):
+                sha = os.fsencode(dir + name)
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
+
+        for p in self.packs:
+            bin_prefix = (
+                binascii.unhexlify(prefix)
+                if len(prefix) % 2 == 0
+                else binascii.unhexlify(prefix[:-1])
+            )
+            for sha in p.index.iter_prefix(bin_prefix):
+                sha = sha_to_hex(sha)
+                if sha.startswith(prefix) and sha not in seen:
+                    seen.add(sha)
+                    yield sha
+        for alternate in self.alternates:
+            for sha in alternate.iter_prefix(prefix):
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
+
 
 class MemoryObjectStore(BaseObjectStore):
     """Object store that keeps all objects in memory."""

+ 8 - 8
dulwich/objectspec.py

@@ -22,8 +22,9 @@
 
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
+from .objects import Commit, ShaFile, Tree
+
 if TYPE_CHECKING:
-    from .objects import Commit, ShaFile, Tree
     from .refs import Ref, RefsContainer
     from .repo import Repo
 
@@ -209,14 +210,13 @@ class AmbiguousShortId(Exception):
         self.options = options
 
 
-def scan_for_short_id(object_store, prefix):
+def scan_for_short_id(object_store, prefix, tp):
     """Scan an object store for a short id."""
-    # TODO(jelmer): This could short-circuit looking for objects
-    # starting with a certain prefix.
     ret = []
-    for object_id in object_store:
-        if object_id.startswith(prefix):
-            ret.append(object_store[object_id])
+    for object_id in object_store.iter_prefix(prefix):
+        o = object_store[object_id]
+        if isinstance(o, tp):
+            ret.append(o)
     if not ret:
         raise KeyError(prefix)
     if len(ret) == 1:
@@ -251,7 +251,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes]) -> "Commit":
             pass
         else:
             try:
-                return scan_for_short_id(repo.object_store, committish)
+                return scan_for_short_id(repo.object_store, committish, Commit)
             except KeyError:
                 pass
     raise KeyError(committish)

+ 22 - 0
dulwich/pack.py

@@ -746,6 +746,28 @@ class FilePackIndex(PackIndex):
             raise KeyError(sha)
         return self._unpack_offset(i)
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+        """Iterate over all SHA1s with the given prefix."""
+        start = ord(prefix[:1])
+        if start == 0:
+            start = 0
+        else:
+            start = self._fan_out_table[start - 1]
+        end = ord(prefix[:1]) + 1
+        if end == 0x100:
+            end = len(self)
+        else:
+            end = self._fan_out_table[end]
+        assert start <= end
+        started = False
+        for i in range(start, end):
+            name = self._unpack_name(i)
+            if name.startswith(prefix):
+                yield name
+                started = True
+            elif started:
+                break
+
 
 class PackIndex1(FilePackIndex):
     """Version 1 Pack Index file."""

+ 11 - 0
dulwich/tests/test_object_store.py

@@ -236,6 +236,17 @@ class ObjectStoreTests:
         self.store.add_object(testobject)
         self.store.close()
 
+    def test_iter_prefix(self):
+        self.store.add_object(testobject)
+        self.assertEqual([testobject.id], list(self.store.iter_prefix(testobject.id)))
+        self.assertEqual(
+            [testobject.id], list(self.store.iter_prefix(testobject.id[:10]))
+        )
+        self.assertEqual(
+            [testobject.id], list(self.store.iter_prefix(testobject.id[:4]))
+        )
+        self.assertEqual([testobject.id], list(self.store.iter_prefix(b"")))
+
 
 class PackBasedObjectStoreTests(ObjectStoreTests):
     def tearDown(self):

+ 10 - 0
tests/test_pack.py

@@ -124,6 +124,16 @@ class PackIndexTests(PackTests):
         self.assertEqual(p.object_sha1(138), hex_to_sha(tree_sha))
         self.assertEqual(p.object_sha1(12), hex_to_sha(commit_sha))
 
+    def test_iter_prefix(self):
+        p = self.get_pack_index(pack1_sha)
+        self.assertEqual([p.object_sha1(178)], list(p.iter_prefix(hex_to_sha(a_sha))))
+        self.assertEqual(
+            [p.object_sha1(178)], list(p.iter_prefix(hex_to_sha(a_sha)[:5]))
+        )
+        self.assertEqual(
+            [p.object_sha1(178)], list(p.iter_prefix(hex_to_sha(a_sha)[:2]))
+        )
+
     def test_index_len(self):
         p = self.get_pack_index(pack1_sha)
         self.assertEqual(3, len(p))