2
0
Эх сурвалжийг харах

Add Pack.iterobjects_subset.

Jelmer Vernooij 2 жил өмнө
parent
commit
4031620b5c

+ 104 - 34
dulwich/pack.py

@@ -35,6 +35,7 @@ a pointer in to the corresponding packfile.
 from collections import defaultdict
 from collections import defaultdict
 
 
 import binascii
 import binascii
+from contextlib import suppress
 from io import BytesIO, UnsupportedOperation
 from io import BytesIO, UnsupportedOperation
 from collections import (
 from collections import (
     deque,
     deque,
@@ -49,7 +50,7 @@ from itertools import chain
 
 
 import os
 import os
 import sys
 import sys
-from typing import Optional, Callable, Tuple, List, Deque, Union, Protocol, Iterable, Iterator, Dict
+from typing import Optional, Callable, Tuple, List, Deque, Union, Protocol, Iterable, Iterator, Dict, TypeVar, Generic
 import warnings
 import warnings
 
 
 from hashlib import sha1
 from hashlib import sha1
@@ -438,13 +439,17 @@ class PackIndex:
         raise NotImplementedError(self.get_pack_checksum)
         raise NotImplementedError(self.get_pack_checksum)
 
 
     def object_index(self, sha: bytes) -> int:
     def object_index(self, sha: bytes) -> int:
-        """Return the index in to the corresponding packfile for the object.
+        warnings.warn('Please use object_offset instead', DeprecationWarning, stacklevel=2)
+        return self.object_offset(sha)
+
+    def object_offset(self, sha: bytes) -> int:
+        """Return the offset in to the corresponding packfile for the object.
 
 
         Given the name of an object it will return the offset that object
         Given the name of an object it will return the offset that object
         lives at within the corresponding pack file. If the pack file doesn't
         lives at within the corresponding pack file. If the pack file doesn't
         have the object then None will be returned.
         have the object then None will be returned.
         """
         """
-        raise NotImplementedError(self.object_index)
+        raise NotImplementedError(self.object_offset)
 
 
     def object_sha1(self, index: int) -> bytes:
     def object_sha1(self, index: int) -> bytes:
         """Return the SHA1 corresponding to the index in the pack file."""
         """Return the SHA1 corresponding to the index in the pack file."""
@@ -455,13 +460,13 @@ class PackIndex:
         else:
         else:
             raise KeyError(index)
             raise KeyError(index)
 
 
-    def _object_index(self, sha: bytes) -> int:
-        """See object_index.
+    def _object_offset(self, sha: bytes) -> int:
+        """See object_offset.
 
 
         Args:
         Args:
           sha: A *binary* SHA string. (20 characters long)_
           sha: A *binary* SHA string. (20 characters long)_
         """
         """
-        raise NotImplementedError(self._object_index)
+        raise NotImplementedError(self._object_offset)
 
 
     def objects_sha1(self) -> bytes:
     def objects_sha1(self) -> bytes:
         """Return the hex SHA1 over all the shas of all objects in this pack.
         """Return the hex SHA1 over all the shas of all objects in this pack.
@@ -492,10 +497,10 @@ class MemoryPackIndex(PackIndex):
           pack_checksum: Optional pack checksum
           pack_checksum: Optional pack checksum
         """
         """
         self._by_sha = {}
         self._by_sha = {}
-        self._by_index = {}
-        for name, idx, crc32 in entries:
-            self._by_sha[name] = idx
-            self._by_index[idx] = name
+        self._by_offset = {}
+        for name, offset, crc32 in entries:
+            self._by_sha[name] = offset
+            self._by_offset[offset] = name
         self._entries = entries
         self._entries = entries
         self._pack_checksum = pack_checksum
         self._pack_checksum = pack_checksum
 
 
@@ -505,13 +510,13 @@ class MemoryPackIndex(PackIndex):
     def __len__(self):
     def __len__(self):
         return len(self._entries)
         return len(self._entries)
 
 
-    def object_index(self, sha):
+    def object_offset(self, sha):
         if len(sha) == 40:
         if len(sha) == 40:
             sha = hex_to_sha(sha)
             sha = hex_to_sha(sha)
-        return self._by_sha[sha][0]
+        return self._by_sha[sha]
 
 
-    def object_sha1(self, index):
-        return self._by_index[index]
+    def object_sha1(self, offset):
+        return self._by_offset[offset]
 
 
     def _itersha(self):
     def _itersha(self):
         return iter(self._by_sha)
         return iter(self._by_sha)
@@ -519,6 +524,10 @@ class MemoryPackIndex(PackIndex):
     def iterentries(self):
     def iterentries(self):
         return iter(self._entries)
         return iter(self._entries)
 
 
+    @classmethod
+    def for_pack(cls, pack):
+        return MemoryPackIndex(pack.sorted_entries(), pack.calculate_checksum())
+
 
 
 class FilePackIndex(PackIndex):
 class FilePackIndex(PackIndex):
     """Pack index that is based on a file.
     """Pack index that is based on a file.
@@ -645,8 +654,8 @@ class FilePackIndex(PackIndex):
         """
         """
         return bytes(self._contents[-20:])
         return bytes(self._contents[-20:])
 
 
-    def object_index(self, sha: bytes) -> int:
-        """Return the index in to the corresponding packfile for the object.
+    def object_offset(self, sha: bytes) -> int:
+        """Return the offset in to the corresponding packfile for the object.
 
 
         Given the name of an object it will return the offset that object
         Given the name of an object it will return the offset that object
         lives at within the corresponding pack file. If the pack file doesn't
         lives at within the corresponding pack file. If the pack file doesn't
@@ -655,15 +664,15 @@ class FilePackIndex(PackIndex):
         if len(sha) == 40:
         if len(sha) == 40:
             sha = hex_to_sha(sha)
             sha = hex_to_sha(sha)
         try:
         try:
-            return self._object_index(sha)
+            return self._object_offset(sha)
         except ValueError as exc:
         except ValueError as exc:
             closed = getattr(self._contents, "closed", None)
             closed = getattr(self._contents, "closed", None)
             if closed in (None, True):
             if closed in (None, True):
                 raise PackFileDisappeared(self) from exc
                 raise PackFileDisappeared(self) from exc
             raise
             raise
 
 
-    def _object_index(self, sha: bytes) -> int:
-        """See object_index.
+    def _object_offset(self, sha: bytes) -> int:
+        """See object_offset.
 
 
         Args:
         Args:
           sha: A *binary* SHA string. (20 characters long)_
           sha: A *binary* SHA string. (20 characters long)_
@@ -1208,7 +1217,7 @@ class PackData:
             # Back up over unused data.
             # Back up over unused data.
             self._file.seek(-len(unused), SEEK_CUR)
             self._file.seek(-len(unused), SEEK_CUR)
 
 
-    def _iter_unpacked(self):
+    def iter_unpacked(self, *, include_comp: bool = False):
         # TODO(dborowitz): Merge this with iterobjects, if we can change its
         # TODO(dborowitz): Merge this with iterobjects, if we can change its
         # return type.
         # return type.
         self._file.seek(self._header_size)
         self._file.seek(self._header_size)
@@ -1218,7 +1227,7 @@ class PackData:
 
 
         for _ in range(self._num_objects):
         for _ in range(self._num_objects):
             offset = self._file.tell()
             offset = self._file.tell()
-            unpacked, unused = unpack_object(self._file.read, compute_crc32=False)
+            unpacked, unused = unpack_object(self._file.read, compute_crc32=False, include_comp=include_comp)
             unpacked.offset = offset
             unpacked.offset = offset
             yield unpacked
             yield unpacked
             # Back up over unused data.
             # Back up over unused data.
@@ -1311,6 +1320,7 @@ class PackData:
         assert offset >= self._header_size
         assert offset >= self._header_size
         self._file.seek(offset)
         self._file.seek(offset)
         unpacked, _ = unpack_object(self._file.read, include_comp=include_comp)
         unpacked, _ = unpack_object(self._file.read, include_comp=include_comp)
+        unpacked.offset = offset
         return unpacked
         return unpacked
 
 
     def get_compressed_data_at(self, offset):
     def get_compressed_data_at(self, offset):
@@ -1356,7 +1366,10 @@ class PackData:
         return (unpacked.pack_type_num, unpacked._obj())
         return (unpacked.pack_type_num, unpacked._obj())
 
 
 
 
-class DeltaChainIterator:
+T = TypeVar('T')
+
+
+class DeltaChainIterator(Generic[T]):
     """Abstract iterator over pack data based on delta chains.
     """Abstract iterator over pack data based on delta chains.
 
 
     Each object in the pack is guaranteed to be inflated exactly once,
     Each object in the pack is guaranteed to be inflated exactly once,
@@ -1392,8 +1405,39 @@ class DeltaChainIterator:
     def for_pack_data(cls, pack_data: PackData, resolve_ext_ref=None):
     def for_pack_data(cls, pack_data: PackData, resolve_ext_ref=None):
         walker = cls(None, resolve_ext_ref=resolve_ext_ref)
         walker = cls(None, resolve_ext_ref=resolve_ext_ref)
         walker.set_pack_data(pack_data)
         walker.set_pack_data(pack_data)
-        for unpacked in pack_data._iter_unpacked():
+        for unpacked in pack_data.iter_unpacked():
+            walker.record(unpacked)
+        return walker
+
+    @classmethod
+    def for_pack_subset(
+            cls, pack: "Pack", shas: Iterable[bytes], *,
+            allow_missing: bool = False, resolve_ext_ref=None):
+        walker = cls(None, resolve_ext_ref=resolve_ext_ref)
+        walker.set_pack_data(pack.data)
+        todo = set()
+        for sha in shas:
+            assert isinstance(sha, bytes)
+            try:
+                off = pack.index.object_offset(sha)
+            except KeyError:
+                if not allow_missing:
+                    raise
+            todo.add(off)
+        done = set()
+        while todo:
+            off = todo.pop()
+            unpacked = pack.data.get_unpacked_object_at(off)
             walker.record(unpacked)
             walker.record(unpacked)
+            done.add(off)
+            base_ofs = None
+            if unpacked.pack_type_num == OFS_DELTA:
+                base_ofs = unpacked.offset - unpacked.delta_base
+            elif unpacked.pack_type_num == REF_DELTA:
+                with suppress(KeyError):
+                    base_ofs = pack.index.object_index(unpacked.delta_base)
+            if base_ofs is not None and base_ofs not in done:
+                todo.add(base_ofs)
         return walker
         return walker
 
 
     def record(self, unpacked: UnpackedObject) -> None:
     def record(self, unpacked: UnpackedObject) -> None:
@@ -1414,7 +1458,7 @@ class DeltaChainIterator:
         for offset, type_num in self._full_ofs:
         for offset, type_num in self._full_ofs:
             yield from self._follow_chain(offset, type_num, None)
             yield from self._follow_chain(offset, type_num, None)
         yield from self._walk_ref_chains()
         yield from self._walk_ref_chains()
-        assert not self._pending_ofs
+        assert not self._pending_ofs, repr(self._pending_ofs)
 
 
     def _ensure_no_pending(self) -> None:
     def _ensure_no_pending(self) -> None:
         if self._pending_ref:
         if self._pending_ref:
@@ -1442,8 +1486,8 @@ class DeltaChainIterator:
 
 
         self._ensure_no_pending()
         self._ensure_no_pending()
 
 
-    def _result(self, unpacked):
-        return unpacked
+    def _result(self, unpacked: UnpackedObject) -> T:
+        raise NotImplementedError
 
 
     def _resolve_object(self, offset: int, obj_type_num: int, base_chunks: List[bytes]) -> UnpackedObject:
     def _resolve_object(self, offset: int, obj_type_num: int, base_chunks: List[bytes]) -> UnpackedObject:
         self._file.seek(offset)
         self._file.seek(offset)
@@ -1479,14 +1523,21 @@ class DeltaChainIterator:
                 for new_offset in unblocked
                 for new_offset in unblocked
             )
             )
 
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[T]:
         return self._walk_all_chains()
         return self._walk_all_chains()
 
 
     def ext_refs(self):
     def ext_refs(self):
         return self._ext_refs
         return self._ext_refs
 
 
 
 
-class PackIndexer(DeltaChainIterator):
+class UnpackedObjectIterator(DeltaChainIterator[UnpackedObject]):
+    """Delta chain iterator that yield unpacked objects."""
+
+    def _result(self, unpacked):
+        return unpacked
+
+
+class PackIndexer(DeltaChainIterator[PackIndexEntry]):
     """Delta chain iterator that yields index entries."""
     """Delta chain iterator that yields index entries."""
 
 
     _compute_crc32 = True
     _compute_crc32 = True
@@ -1495,7 +1546,7 @@ class PackIndexer(DeltaChainIterator):
         return unpacked.sha(), unpacked.offset, unpacked.crc32
         return unpacked.sha(), unpacked.offset, unpacked.crc32
 
 
 
 
-class PackInflater(DeltaChainIterator):
+class PackInflater(DeltaChainIterator[ShaFile]):
     """Delta chain iterator that yields ShaFile objects."""
     """Delta chain iterator that yields ShaFile objects."""
 
 
     def _result(self, unpacked):
     def _result(self, unpacked):
@@ -2263,7 +2314,7 @@ class Pack:
     def __contains__(self, sha1: bytes) -> bool:
     def __contains__(self, sha1: bytes) -> bool:
         """Check whether this pack contains a particular SHA1."""
         """Check whether this pack contains a particular SHA1."""
         try:
         try:
-            self.index.object_index(sha1)
+            self.index.object_offset(sha1)
             return True
             return True
         except KeyError:
         except KeyError:
             return False
             return False
@@ -2276,7 +2327,7 @@ class Pack:
         Returns: Tuple with pack object type, delta base (if applicable),
         Returns: Tuple with pack object type, delta base (if applicable),
             list of data chunks
             list of data chunks
         """
         """
-        offset = self.index.object_index(sha1)
+        offset = self.index.object_offset(sha1)
         (obj_type, delta_base, chunks) = self.data.get_compressed_data_at(offset)
         (obj_type, delta_base, chunks) = self.data.get_compressed_data_at(offset)
         if obj_type == OFS_DELTA:
         if obj_type == OFS_DELTA:
             delta_base = sha_to_hex(self.index.object_sha1(offset - delta_base))
             delta_base = sha_to_hex(self.index.object_sha1(offset - delta_base))
@@ -2284,7 +2335,7 @@ class Pack:
         return (obj_type, delta_base, chunks)
         return (obj_type, delta_base, chunks)
 
 
     def get_raw(self, sha1: bytes) -> Tuple[int, bytes]:
     def get_raw(self, sha1: bytes) -> Tuple[int, bytes]:
-        offset = self.index.object_index(sha1)
+        offset = self.index.object_offset(sha1)
         obj_type, obj = self.data.get_object_at(offset)
         obj_type, obj = self.data.get_object_at(offset)
         type_num, chunks = self.resolve_object(offset, obj_type, obj)
         type_num, chunks = self.resolve_object(offset, obj_type, obj)
         return type_num, b"".join(chunks)
         return type_num, b"".join(chunks)
@@ -2294,12 +2345,21 @@ class Pack:
         type, uncomp = self.get_raw(sha1)
         type, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type, uncomp, sha=sha1)
         return ShaFile.from_raw_string(type, uncomp, sha=sha1)
 
 
-    def iterobjects(self):
+    def iterobjects(self) -> Iterator[ShaFile]:
         """Iterate over the objects in this pack."""
         """Iterate over the objects in this pack."""
         return iter(
         return iter(
             PackInflater.for_pack_data(self.data, resolve_ext_ref=self.resolve_ext_ref)
             PackInflater.for_pack_data(self.data, resolve_ext_ref=self.resolve_ext_ref)
         )
         )
 
 
+    def iterobjects_subset(self, shas, *, allow_missing: bool = False) -> Iterator[ShaFile]:
+        return (
+            uo
+            for uo in
+            PackInflater.for_pack_subset(
+                self, shas, allow_missing=allow_missing,
+                resolve_ext_ref=self.resolve_ext_ref)
+            if uo.sha() in shas)
+
     def pack_tuples(self):
     def pack_tuples(self):
         """Provide an iterable for use with write_pack_objects.
         """Provide an iterable for use with write_pack_objects.
 
 
@@ -2328,7 +2388,7 @@ class Pack:
         """Get the object for a ref SHA, only looking in this pack."""
         """Get the object for a ref SHA, only looking in this pack."""
         # TODO: cache these results
         # TODO: cache these results
         try:
         try:
-            offset = self.index.object_index(sha)
+            offset = self.index.object_offset(sha)
         except KeyError:
         except KeyError:
             offset = None
             offset = None
         if offset:
         if offset:
@@ -2403,6 +2463,16 @@ class Pack:
         return self.data.sorted_entries(
         return self.data.sorted_entries(
             progress=progress, resolve_ext_ref=self.resolve_ext_ref)
             progress=progress, resolve_ext_ref=self.resolve_ext_ref)
 
 
+    def get_unpacked_object(self, sha: bytes, *, include_comp: bool = False) -> UnpackedObject:
+        """Get the unpacked object for a sha.
+
+        Args:
+          sha: SHA of object to fetch
+          include_comp: Whether to include compression data in UnpackedObject
+        """
+        offset = self.index.object_offset(sha)
+        return self.data.get_unpacked_object_at(offset, include_comp=include_comp)
+
 
 
 try:
 try:
     from dulwich._pack import (  # type: ignore # noqa: F811
     from dulwich._pack import (  # type: ignore # noqa: F811

+ 23 - 5
dulwich/tests/test_pack.py

@@ -122,13 +122,13 @@ class PackTests(TestCase):
 class PackIndexTests(PackTests):
 class PackIndexTests(PackTests):
     """Class that tests the index of packfiles"""
     """Class that tests the index of packfiles"""
 
 
-    def test_object_index(self):
+    def test_object_offset(self):
         """Tests that the correct object offset is returned from the index."""
         """Tests that the correct object offset is returned from the index."""
         p = self.get_pack_index(pack1_sha)
         p = self.get_pack_index(pack1_sha)
-        self.assertRaises(KeyError, p.object_index, pack1_sha)
-        self.assertEqual(p.object_index(a_sha), 178)
-        self.assertEqual(p.object_index(tree_sha), 138)
-        self.assertEqual(p.object_index(commit_sha), 12)
+        self.assertRaises(KeyError, p.object_offset, pack1_sha)
+        self.assertEqual(p.object_offset(a_sha), 178)
+        self.assertEqual(p.object_offset(tree_sha), 138)
+        self.assertEqual(p.object_offset(commit_sha), 12)
 
 
     def test_object_sha1(self):
     def test_object_sha1(self):
         """Tests that the correct object offset is returned from the index."""
         """Tests that the correct object offset is returned from the index."""
@@ -1006,6 +1006,16 @@ class DeltaChainIteratorTests(TestCase):
         data = PackData("test.pack", file=f)
         data = PackData("test.pack", file=f)
         return TestPackIterator.for_pack_data(data, resolve_ext_ref=resolve_ext_ref)
         return TestPackIterator.for_pack_data(data, resolve_ext_ref=resolve_ext_ref)
 
 
+    def make_pack_iter_subset(self, f, subset, thin=None):
+        if thin is None:
+            thin = bool(list(self.store))
+        resolve_ext_ref = thin and self.get_raw_no_repeat or None
+        data = PackData("test.pack", file=f)
+        assert data
+        index = MemoryPackIndex.for_pack(data)
+        pack = Pack.from_objects(data, index)
+        return TestPackIterator.for_pack_subset(pack, subset, resolve_ext_ref=resolve_ext_ref)
+
     def assertEntriesMatch(self, expected_indexes, entries, pack_iter):
     def assertEntriesMatch(self, expected_indexes, entries, pack_iter):
         expected = [entries[i] for i in expected_indexes]
         expected = [entries[i] for i in expected_indexes]
         self.assertEqual(expected, list(pack_iter._walk_all_chains()))
         self.assertEqual(expected, list(pack_iter._walk_all_chains()))
@@ -1021,6 +1031,10 @@ class DeltaChainIteratorTests(TestCase):
             ],
             ],
         )
         )
         self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f))
         self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch([], entries, self.make_pack_iter_subset(f, []))
+        f.seek(0)
+        self.assertEntriesMatch([1, 0], entries, self.make_pack_iter_subset(f, [entries[0][3], entries[1][3]]))
 
 
     def test_ofs_deltas(self):
     def test_ofs_deltas(self):
         f = BytesIO()
         f = BytesIO()
@@ -1034,6 +1048,10 @@ class DeltaChainIteratorTests(TestCase):
         )
         )
         # Delta resolution changed to DFS
         # Delta resolution changed to DFS
         self.assertEntriesMatch([0, 2, 1], entries, self.make_pack_iter(f))
         self.assertEntriesMatch([0, 2, 1], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch(
+            [0, 2, 1], entries,
+            self.make_pack_iter_subset(f, [entries[1][3], entries[2][3]]))
 
 
     def test_ofs_deltas_chain(self):
     def test_ofs_deltas_chain(self):
         f = BytesIO()
         f = BytesIO()