浏览代码

Add ObjectStore.iter_prefix (#1402)

This allow scanning an object store for a specific SHA prefix. The
default implementation is naive and scans the entire key space, but
specific implementations have more efficient approaches, improving
performance.,
Jelmer Vernooij 10 月之前
父节点
当前提交
8ccaacf7a3
共有 6 个文件被更改,包括 96 次插入8 次删除
  1. 2 0
      NEWS
  2. 43 0
      dulwich/object_store.py
  3. 8 8
      dulwich/objectspec.py
  4. 22 0
      dulwich/pack.py
  5. 11 0
      dulwich/tests/test_object_store.py
  6. 10 0
      tests/test_pack.py

+ 2 - 0
NEWS

@@ -3,6 +3,8 @@
  * Fix handling of symrefs with protocol v2.
  * Fix handling of symrefs with protocol v2.
    (Jelmer Vernooij, #1389)
    (Jelmer Vernooij, #1389)
 
 
+ * Add ``ObjectStore.iter_prefix``.  (Jelmer Vernooij)
+
 0.22.3	2024-10-15
 0.22.3	2024-10-15
 
 
  * Improve wheel building in CI, so we can upload wheels for the next release.
  * Improve wheel building in CI, so we can upload wheels for the next release.

+ 43 - 0
dulwich/object_store.py

@@ -22,6 +22,7 @@
 
 
 """Git object store interfaces and implementation."""
 """Git object store interfaces and implementation."""
 
 
+import binascii
 import os
 import os
 import stat
 import stat
 import sys
 import sys
@@ -358,6 +359,17 @@ class BaseObjectStore:
         """Close any files opened by this object store."""
         """Close any files opened by this object store."""
         # Default implementation is a NO-OP
         # Default implementation is a NO-OP
 
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[ObjectID]:
+        """Iterate over all SHA1s that start with a given prefix.
+
+        The default implementation is a naive iteration over all objects.
+        However, subclasses may override this method with more efficient
+        implementations.
+        """
+        for sha in self:
+            if sha.startswith(prefix):
+                yield sha
+
 
 
 class PackBasedObjectStore(BaseObjectStore):
 class PackBasedObjectStore(BaseObjectStore):
     def __init__(self, pack_compression_level=-1) -> None:
     def __init__(self, pack_compression_level=-1) -> None:
@@ -1027,6 +1039,37 @@ class DiskObjectStore(PackBasedObjectStore):
         os.mkdir(os.path.join(path, PACKDIR))
         os.mkdir(os.path.join(path, PACKDIR))
         return cls(path)
         return cls(path)
 
 
+    def iter_prefix(self, prefix):
+        if len(prefix) < 2:
+            yield from super().iter_prefix(prefix)
+            return
+        seen = set()
+        dir = prefix[:2].decode()
+        rest = prefix[2:].decode()
+        for name in os.listdir(os.path.join(self.path, dir)):
+            if name.startswith(rest):
+                sha = os.fsencode(dir + name)
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
+
+        for p in self.packs:
+            bin_prefix = (
+                binascii.unhexlify(prefix)
+                if len(prefix) % 2 == 0
+                else binascii.unhexlify(prefix[:-1])
+            )
+            for sha in p.index.iter_prefix(bin_prefix):
+                sha = sha_to_hex(sha)
+                if sha.startswith(prefix) and sha not in seen:
+                    seen.add(sha)
+                    yield sha
+        for alternate in self.alternates:
+            for sha in alternate.iter_prefix(prefix):
+                if sha not in seen:
+                    seen.add(sha)
+                    yield sha
+
 
 
 class MemoryObjectStore(BaseObjectStore):
 class MemoryObjectStore(BaseObjectStore):
     """Object store that keeps all objects in memory."""
     """Object store that keeps all objects in memory."""

+ 8 - 8
dulwich/objectspec.py

@@ -22,8 +22,9 @@
 
 
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
 
 
+from .objects import Commit, ShaFile, Tree
+
 if TYPE_CHECKING:
 if TYPE_CHECKING:
-    from .objects import Commit, ShaFile, Tree
     from .refs import Ref, RefsContainer
     from .refs import Ref, RefsContainer
     from .repo import Repo
     from .repo import Repo
 
 
@@ -209,14 +210,13 @@ class AmbiguousShortId(Exception):
         self.options = options
         self.options = options
 
 
 
 
-def scan_for_short_id(object_store, prefix):
+def scan_for_short_id(object_store, prefix, tp):
     """Scan an object store for a short id."""
     """Scan an object store for a short id."""
-    # TODO(jelmer): This could short-circuit looking for objects
-    # starting with a certain prefix.
     ret = []
     ret = []
-    for object_id in object_store:
-        if object_id.startswith(prefix):
-            ret.append(object_store[object_id])
+    for object_id in object_store.iter_prefix(prefix):
+        o = object_store[object_id]
+        if isinstance(o, tp):
+            ret.append(o)
     if not ret:
     if not ret:
         raise KeyError(prefix)
         raise KeyError(prefix)
     if len(ret) == 1:
     if len(ret) == 1:
@@ -251,7 +251,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes]) -> "Commit":
             pass
             pass
         else:
         else:
             try:
             try:
-                return scan_for_short_id(repo.object_store, committish)
+                return scan_for_short_id(repo.object_store, committish, Commit)
             except KeyError:
             except KeyError:
                 pass
                 pass
     raise KeyError(committish)
     raise KeyError(committish)

+ 22 - 0
dulwich/pack.py

@@ -746,6 +746,28 @@ class FilePackIndex(PackIndex):
             raise KeyError(sha)
             raise KeyError(sha)
         return self._unpack_offset(i)
         return self._unpack_offset(i)
 
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+        """Iterate over all SHA1s with the given prefix."""
+        start = ord(prefix[:1])
+        if start == 0:
+            start = 0
+        else:
+            start = self._fan_out_table[start - 1]
+        end = ord(prefix[:1]) + 1
+        if end == 0x100:
+            end = len(self)
+        else:
+            end = self._fan_out_table[end]
+        assert start <= end
+        started = False
+        for i in range(start, end):
+            name = self._unpack_name(i)
+            if name.startswith(prefix):
+                yield name
+                started = True
+            elif started:
+                break
+
 
 
 class PackIndex1(FilePackIndex):
 class PackIndex1(FilePackIndex):
     """Version 1 Pack Index file."""
     """Version 1 Pack Index file."""

+ 11 - 0
dulwich/tests/test_object_store.py

@@ -236,6 +236,17 @@ class ObjectStoreTests:
         self.store.add_object(testobject)
         self.store.add_object(testobject)
         self.store.close()
         self.store.close()
 
 
+    def test_iter_prefix(self):
+        self.store.add_object(testobject)
+        self.assertEqual([testobject.id], list(self.store.iter_prefix(testobject.id)))
+        self.assertEqual(
+            [testobject.id], list(self.store.iter_prefix(testobject.id[:10]))
+        )
+        self.assertEqual(
+            [testobject.id], list(self.store.iter_prefix(testobject.id[:4]))
+        )
+        self.assertEqual([testobject.id], list(self.store.iter_prefix(b"")))
+
 
 
 class PackBasedObjectStoreTests(ObjectStoreTests):
 class PackBasedObjectStoreTests(ObjectStoreTests):
     def tearDown(self):
     def tearDown(self):

+ 10 - 0
tests/test_pack.py

@@ -124,6 +124,16 @@ class PackIndexTests(PackTests):
         self.assertEqual(p.object_sha1(138), hex_to_sha(tree_sha))
         self.assertEqual(p.object_sha1(138), hex_to_sha(tree_sha))
         self.assertEqual(p.object_sha1(12), hex_to_sha(commit_sha))
         self.assertEqual(p.object_sha1(12), hex_to_sha(commit_sha))
 
 
+    def test_iter_prefix(self):
+        p = self.get_pack_index(pack1_sha)
+        self.assertEqual([p.object_sha1(178)], list(p.iter_prefix(hex_to_sha(a_sha))))
+        self.assertEqual(
+            [p.object_sha1(178)], list(p.iter_prefix(hex_to_sha(a_sha)[:5]))
+        )
+        self.assertEqual(
+            [p.object_sha1(178)], list(p.iter_prefix(hex_to_sha(a_sha)[:2]))
+        )
+
     def test_index_len(self):
     def test_index_len(self):
         p = self.get_pack_index(pack1_sha)
         p = self.get_pack_index(pack1_sha)
         self.assertEqual(3, len(p))
         self.assertEqual(3, len(p))