Browse Source

Add Pack.iterobjects_subset.

Jelmer Vernooij 2 years ago
parent
commit
4031620b5c
2 changed files with 127 additions and 39 deletions
  1. 104 34
      dulwich/pack.py
  2. 23 5
      dulwich/tests/test_pack.py

+ 104 - 34
dulwich/pack.py

@@ -35,6 +35,7 @@ a pointer in to the corresponding packfile.
 from collections import defaultdict
 
 import binascii
+from contextlib import suppress
 from io import BytesIO, UnsupportedOperation
 from collections import (
     deque,
@@ -49,7 +50,7 @@ from itertools import chain
 
 import os
 import sys
-from typing import Optional, Callable, Tuple, List, Deque, Union, Protocol, Iterable, Iterator, Dict
+from typing import Optional, Callable, Tuple, List, Deque, Union, Protocol, Iterable, Iterator, Dict, TypeVar, Generic
 import warnings
 
 from hashlib import sha1
@@ -438,13 +439,17 @@ class PackIndex:
         raise NotImplementedError(self.get_pack_checksum)
 
     def object_index(self, sha: bytes) -> int:
-        """Return the index in to the corresponding packfile for the object.
+        warnings.warn('Please use object_offset instead', DeprecationWarning, stacklevel=2)
+        return self.object_offset(sha)
+
+    def object_offset(self, sha: bytes) -> int:
+        """Return the offset in to the corresponding packfile for the object.
 
         Given the name of an object it will return the offset that object
         lives at within the corresponding pack file. If the pack file doesn't
         have the object then None will be returned.
         """
-        raise NotImplementedError(self.object_index)
+        raise NotImplementedError(self.object_offset)
 
     def object_sha1(self, index: int) -> bytes:
         """Return the SHA1 corresponding to the index in the pack file."""
@@ -455,13 +460,13 @@ class PackIndex:
         else:
             raise KeyError(index)
 
-    def _object_index(self, sha: bytes) -> int:
-        """See object_index.
+    def _object_offset(self, sha: bytes) -> int:
+        """See object_offset.
 
         Args:
           sha: A *binary* SHA string. (20 characters long)_
         """
-        raise NotImplementedError(self._object_index)
+        raise NotImplementedError(self._object_offset)
 
     def objects_sha1(self) -> bytes:
         """Return the hex SHA1 over all the shas of all objects in this pack.
@@ -492,10 +497,10 @@ class MemoryPackIndex(PackIndex):
           pack_checksum: Optional pack checksum
         """
         self._by_sha = {}
-        self._by_index = {}
-        for name, idx, crc32 in entries:
-            self._by_sha[name] = idx
-            self._by_index[idx] = name
+        self._by_offset = {}
+        for name, offset, crc32 in entries:
+            self._by_sha[name] = offset
+            self._by_offset[offset] = name
         self._entries = entries
         self._pack_checksum = pack_checksum
 
@@ -505,13 +510,13 @@ class MemoryPackIndex(PackIndex):
     def __len__(self):
         return len(self._entries)
 
-    def object_index(self, sha):
+    def object_offset(self, sha):
         if len(sha) == 40:
             sha = hex_to_sha(sha)
-        return self._by_sha[sha][0]
+        return self._by_sha[sha]
 
-    def object_sha1(self, index):
-        return self._by_index[index]
+    def object_sha1(self, offset):
+        return self._by_offset[offset]
 
     def _itersha(self):
         return iter(self._by_sha)
@@ -519,6 +524,10 @@ class MemoryPackIndex(PackIndex):
     def iterentries(self):
         return iter(self._entries)
 
+    @classmethod
+    def for_pack(cls, pack):
+        return MemoryPackIndex(pack.sorted_entries(), pack.calculate_checksum())
+
 
 class FilePackIndex(PackIndex):
     """Pack index that is based on a file.
@@ -645,8 +654,8 @@ class FilePackIndex(PackIndex):
         """
         return bytes(self._contents[-20:])
 
-    def object_index(self, sha: bytes) -> int:
-        """Return the index in to the corresponding packfile for the object.
+    def object_offset(self, sha: bytes) -> int:
+        """Return the offset in to the corresponding packfile for the object.
 
         Given the name of an object it will return the offset that object
         lives at within the corresponding pack file. If the pack file doesn't
@@ -655,15 +664,15 @@ class FilePackIndex(PackIndex):
         if len(sha) == 40:
             sha = hex_to_sha(sha)
         try:
-            return self._object_index(sha)
+            return self._object_offset(sha)
         except ValueError as exc:
             closed = getattr(self._contents, "closed", None)
             if closed in (None, True):
                 raise PackFileDisappeared(self) from exc
             raise
 
-    def _object_index(self, sha: bytes) -> int:
-        """See object_index.
+    def _object_offset(self, sha: bytes) -> int:
+        """See object_offset.
 
         Args:
           sha: A *binary* SHA string. (20 characters long)_
@@ -1208,7 +1217,7 @@ class PackData:
             # Back up over unused data.
             self._file.seek(-len(unused), SEEK_CUR)
 
-    def _iter_unpacked(self):
+    def iter_unpacked(self, *, include_comp: bool = False):
         # TODO(dborowitz): Merge this with iterobjects, if we can change its
         # return type.
         self._file.seek(self._header_size)
@@ -1218,7 +1227,7 @@ class PackData:
 
         for _ in range(self._num_objects):
             offset = self._file.tell()
-            unpacked, unused = unpack_object(self._file.read, compute_crc32=False)
+            unpacked, unused = unpack_object(self._file.read, compute_crc32=False, include_comp=include_comp)
             unpacked.offset = offset
             yield unpacked
             # Back up over unused data.
@@ -1311,6 +1320,7 @@ class PackData:
         assert offset >= self._header_size
         self._file.seek(offset)
         unpacked, _ = unpack_object(self._file.read, include_comp=include_comp)
+        unpacked.offset = offset
         return unpacked
 
     def get_compressed_data_at(self, offset):
@@ -1356,7 +1366,10 @@ class PackData:
         return (unpacked.pack_type_num, unpacked._obj())
 
 
-class DeltaChainIterator:
+T = TypeVar('T')
+
+
+class DeltaChainIterator(Generic[T]):
     """Abstract iterator over pack data based on delta chains.
 
     Each object in the pack is guaranteed to be inflated exactly once,
@@ -1392,8 +1405,39 @@ class DeltaChainIterator:
     def for_pack_data(cls, pack_data: PackData, resolve_ext_ref=None):
         walker = cls(None, resolve_ext_ref=resolve_ext_ref)
         walker.set_pack_data(pack_data)
-        for unpacked in pack_data._iter_unpacked():
+        for unpacked in pack_data.iter_unpacked():
+            walker.record(unpacked)
+        return walker
+
+    @classmethod
+    def for_pack_subset(
+            cls, pack: "Pack", shas: Iterable[bytes], *,
+            allow_missing: bool = False, resolve_ext_ref=None):
+        walker = cls(None, resolve_ext_ref=resolve_ext_ref)
+        walker.set_pack_data(pack.data)
+        todo = set()
+        for sha in shas:
+            assert isinstance(sha, bytes)
+            try:
+                off = pack.index.object_offset(sha)
+            except KeyError:
+                if not allow_missing:
+                    raise
+            else:
+                todo.add(off)
+        done = set()
+        while todo:
+            off = todo.pop()
+            unpacked = pack.data.get_unpacked_object_at(off)
             walker.record(unpacked)
+            done.add(off)
+            base_ofs = None
+            if unpacked.pack_type_num == OFS_DELTA:
+                base_ofs = unpacked.offset - unpacked.delta_base
+            elif unpacked.pack_type_num == REF_DELTA:
+                with suppress(KeyError):
+                    base_ofs = pack.index.object_offset(unpacked.delta_base)
+            if base_ofs is not None and base_ofs not in done:
+                todo.add(base_ofs)
         return walker
 
     def record(self, unpacked: UnpackedObject) -> None:
@@ -1414,7 +1458,7 @@ class DeltaChainIterator:
         for offset, type_num in self._full_ofs:
             yield from self._follow_chain(offset, type_num, None)
         yield from self._walk_ref_chains()
-        assert not self._pending_ofs
+        assert not self._pending_ofs, repr(self._pending_ofs)
 
     def _ensure_no_pending(self) -> None:
         if self._pending_ref:
@@ -1442,8 +1486,8 @@ class DeltaChainIterator:
 
         self._ensure_no_pending()
 
-    def _result(self, unpacked):
-        return unpacked
+    def _result(self, unpacked: UnpackedObject) -> T:
+        raise NotImplementedError
 
     def _resolve_object(self, offset: int, obj_type_num: int, base_chunks: List[bytes]) -> UnpackedObject:
         self._file.seek(offset)
@@ -1479,14 +1523,21 @@ class DeltaChainIterator:
                 for new_offset in unblocked
             )
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[T]:
         return self._walk_all_chains()
 
     def ext_refs(self):
         return self._ext_refs
 
 
-class PackIndexer(DeltaChainIterator):
+class UnpackedObjectIterator(DeltaChainIterator[UnpackedObject]):
+    """Delta chain iterator that yields unpacked objects."""
+
+    def _result(self, unpacked):
+        return unpacked
+
+
+class PackIndexer(DeltaChainIterator[PackIndexEntry]):
     """Delta chain iterator that yields index entries."""
 
     _compute_crc32 = True
@@ -1495,7 +1546,7 @@ class PackIndexer(DeltaChainIterator):
         return unpacked.sha(), unpacked.offset, unpacked.crc32
 
 
-class PackInflater(DeltaChainIterator):
+class PackInflater(DeltaChainIterator[ShaFile]):
     """Delta chain iterator that yields ShaFile objects."""
 
     def _result(self, unpacked):
@@ -2263,7 +2314,7 @@ class Pack:
     def __contains__(self, sha1: bytes) -> bool:
         """Check whether this pack contains a particular SHA1."""
         try:
-            self.index.object_index(sha1)
+            self.index.object_offset(sha1)
             return True
         except KeyError:
             return False
@@ -2276,7 +2327,7 @@ class Pack:
         Returns: Tuple with pack object type, delta base (if applicable),
             list of data chunks
         """
-        offset = self.index.object_index(sha1)
+        offset = self.index.object_offset(sha1)
         (obj_type, delta_base, chunks) = self.data.get_compressed_data_at(offset)
         if obj_type == OFS_DELTA:
             delta_base = sha_to_hex(self.index.object_sha1(offset - delta_base))
@@ -2284,7 +2335,7 @@ class Pack:
         return (obj_type, delta_base, chunks)
 
     def get_raw(self, sha1: bytes) -> Tuple[int, bytes]:
-        offset = self.index.object_index(sha1)
+        offset = self.index.object_offset(sha1)
         obj_type, obj = self.data.get_object_at(offset)
         type_num, chunks = self.resolve_object(offset, obj_type, obj)
         return type_num, b"".join(chunks)
@@ -2294,12 +2345,21 @@ class Pack:
         type, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type, uncomp, sha=sha1)
 
-    def iterobjects(self):
+    def iterobjects(self) -> Iterator[ShaFile]:
         """Iterate over the objects in this pack."""
         return iter(
             PackInflater.for_pack_data(self.data, resolve_ext_ref=self.resolve_ext_ref)
         )
 
+    def iterobjects_subset(self, shas, *, allow_missing: bool = False) -> Iterator[ShaFile]:
+        return (
+            uo
+            for uo in
+            PackInflater.for_pack_subset(
+                self, shas, allow_missing=allow_missing,
+                resolve_ext_ref=self.resolve_ext_ref)
+            if uo.sha() in shas)
+
     def pack_tuples(self):
         """Provide an iterable for use with write_pack_objects.
 
@@ -2328,7 +2388,7 @@ class Pack:
         """Get the object for a ref SHA, only looking in this pack."""
         # TODO: cache these results
         try:
-            offset = self.index.object_index(sha)
+            offset = self.index.object_offset(sha)
         except KeyError:
             offset = None
         if offset:
@@ -2403,6 +2463,16 @@ class Pack:
         return self.data.sorted_entries(
             progress=progress, resolve_ext_ref=self.resolve_ext_ref)
 
+    def get_unpacked_object(self, sha: bytes, *, include_comp: bool = False) -> UnpackedObject:
+        """Get the unpacked object for a sha.
+
+        Args:
+          sha: SHA of object to fetch
+          include_comp: Whether to include compression data in UnpackedObject
+        """
+        offset = self.index.object_offset(sha)
+        return self.data.get_unpacked_object_at(offset, include_comp=include_comp)
+
 
 try:
     from dulwich._pack import (  # type: ignore # noqa: F811

+ 23 - 5
dulwich/tests/test_pack.py

@@ -122,13 +122,13 @@ class PackTests(TestCase):
 class PackIndexTests(PackTests):
     """Class that tests the index of packfiles"""
 
-    def test_object_index(self):
+    def test_object_offset(self):
         """Tests that the correct object offset is returned from the index."""
         p = self.get_pack_index(pack1_sha)
-        self.assertRaises(KeyError, p.object_index, pack1_sha)
-        self.assertEqual(p.object_index(a_sha), 178)
-        self.assertEqual(p.object_index(tree_sha), 138)
-        self.assertEqual(p.object_index(commit_sha), 12)
+        self.assertRaises(KeyError, p.object_offset, pack1_sha)
+        self.assertEqual(p.object_offset(a_sha), 178)
+        self.assertEqual(p.object_offset(tree_sha), 138)
+        self.assertEqual(p.object_offset(commit_sha), 12)
 
     def test_object_sha1(self):
         """Tests that the correct object offset is returned from the index."""
@@ -1006,6 +1006,16 @@ class DeltaChainIteratorTests(TestCase):
         data = PackData("test.pack", file=f)
         return TestPackIterator.for_pack_data(data, resolve_ext_ref=resolve_ext_ref)
 
+    def make_pack_iter_subset(self, f, subset, thin=None):
+        if thin is None:
+            thin = bool(list(self.store))
+        resolve_ext_ref = thin and self.get_raw_no_repeat or None
+        data = PackData("test.pack", file=f)
+        assert data
+        index = MemoryPackIndex.for_pack(data)
+        pack = Pack.from_objects(data, index)
+        return TestPackIterator.for_pack_subset(pack, subset, resolve_ext_ref=resolve_ext_ref)
+
     def assertEntriesMatch(self, expected_indexes, entries, pack_iter):
         expected = [entries[i] for i in expected_indexes]
         self.assertEqual(expected, list(pack_iter._walk_all_chains()))
@@ -1021,6 +1031,10 @@ class DeltaChainIteratorTests(TestCase):
             ],
         )
         self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch([], entries, self.make_pack_iter_subset(f, []))
+        f.seek(0)
+        self.assertEntriesMatch([1, 0], entries, self.make_pack_iter_subset(f, [entries[0][3], entries[1][3]]))
 
     def test_ofs_deltas(self):
         f = BytesIO()
@@ -1034,6 +1048,10 @@ class DeltaChainIteratorTests(TestCase):
         )
         # Delta resolution changed to DFS
         self.assertEntriesMatch([0, 2, 1], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch(
+            [0, 2, 1], entries,
+            self.make_pack_iter_subset(f, [entries[1][3], entries[2][3]]))
 
     def test_ofs_deltas_chain(self):
         f = BytesIO()