
Merge pull request #1117 from jelmer/pack-fixes

Refactor pack handling
Jelmer Vernooij 2 years ago
commit c18863d65e

+ 5 - 6
.github/workflows/pythontest.yml

@@ -31,19 +31,17 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install -U pip coverage flake8 fastimport paramiko urllib3
+          pip install -U ".[fastimport,paramiko,https]"
       - name: Install gpg on supported platforms
-        run: pip install -U gpg
+        run: pip install -U ".[pgp]"
         if: "matrix.os != 'windows-latest' && matrix.python-version != 'pypy3'"
-      - name: Install mypy
-        run: |
-          pip install -U mypy types-paramiko types-requests
-        if: "matrix.python-version != 'pypy3'"
       - name: Style checks
         run: |
+          pip install -U flake8
           python -m flake8
       - name: Typing checks
         run: |
+          pip install -U mypy types-paramiko types-requests
           python -m mypy dulwich
         if: "matrix.python-version != 'pypy3'"
       - name: Build
@@ -51,4 +49,5 @@ jobs:
           python setup.py build_ext -i
       - name: Coverage test suite run
         run: |
+          pip install -U coverage
           python -m coverage run -p -m unittest dulwich.tests.test_suite

+ 9 - 1
NEWS

@@ -1,4 +1,12 @@
-0.20.51	UNRELEASED
+0.21.0	UNRELEASED
+
+ * Pack internals have been significantly refactored, including
+   significant low-level API changes.
+
+   As a consequence of this, Dulwich now reuses pack deltas
+   when communicating with remote servers, which brings a
+   big boost to network performance.
+   (Jelmer Vernooij)
 
 0.20.50	2022-10-30
 
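The delta reuse described in this entry surfaces through the reworked
generate_pack_data API, which now returns a record count plus an iterator of
UnpackedObject entries instead of fully materialized objects. A rough sketch
of driving it directly (the repository path is hypothetical):

    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")  # hypothetical local repository
    # have: SHAs the receiving side already has; want: SHAs to send.
    count, records = repo.object_store.generate_pack_data(
        have=[], want=[repo.head()], ofs_delta=True)
    # Each record is an UnpackedObject; for pack-based stores, deltas
    # already present in local packs are reused instead of recomputed.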

+ 1 - 0
docs/tutorial/remote.txt

@@ -55,6 +55,7 @@ which claims that the client doesn't have any objects::
 
    >>> class DummyGraphWalker(object):
    ...     def ack(self, sha): pass
+   ...     def nak(self): pass
    ...     def next(self): pass
    ...     def __next__(self): pass
 
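The new nak() method is required because the client now forwards the
server's NAK packet to the graph walker (see the _handle_upload_pack_tail
hunk in dulwich/client.py below). A rough sketch of how such a walker is
consumed, with a hypothetical repository path::

   >>> from io import BytesIO
   >>> from dulwich.client import LocalGitClient
   >>> f = BytesIO()
   >>> result = LocalGitClient().fetch_pack(   # doctest: +SKIP
   ...     "/path/to/remote",                  # hypothetical path
   ...     lambda refs: list(refs.values()),   # want everything advertised
   ...     DummyGraphWalker(),
   ...     pack_data=f.write)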

+ 1 - 1
dulwich/__init__.py

@@ -22,4 +22,4 @@
 
 """Python implementation of the Git file formats and protocols."""
 
-__version__ = (0, 20, 50)
+__version__ = (0, 21, 0)

+ 7 - 1
dulwich/bundle.py

@@ -34,6 +34,12 @@ class Bundle:
     references: Dict[str, bytes] = {}
     pack_data: Union[PackData, Sequence[bytes]] = []
 
+    def __repr__(self):
+        return (f"<{type(self).__name__}(version={self.version}, "
+                f"capabilities={self.capabilities}, "
+                f"prerequisites={self.prerequisites}, "
+                f"references={self.references})>")
+
     def __eq__(self, other):
         if not isinstance(other, type(self)):
             return False
@@ -119,4 +125,4 @@ def write_bundle(f, bundle):
     for ref, obj_id in bundle.references.items():
         f.write(b"%s %s\n" % (obj_id, ref))
     f.write(b"\n")
-    write_pack_data(f.write, records=bundle.pack_data)
+    write_pack_data(f.write, num_records=len(bundle.pack_data), records=bundle.pack_data.iter_unpacked())
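
With this change write_bundle expects bundle.pack_data to behave like a
PackData object: the record count comes from len() and the records from
iter_unpacked(). A rough sketch of writing a bundle under that assumption
(file names and head_sha are hypothetical):

    from dulwich.bundle import Bundle, write_bundle
    from dulwich.pack import PackData

    bundle = Bundle()
    bundle.version = 3
    bundle.capabilities = {}
    bundle.prerequisites = []
    bundle.references = {b"refs/heads/master": head_sha}  # hypothetical sha
    bundle.pack_data = PackData("objects.pack")  # an existing pack file
    with open("repo.bundle", "wb") as f:
        write_bundle(f, bundle)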

+ 6 - 3
dulwich/cli.py

@@ -37,7 +37,7 @@ import signal
 from typing import Dict, Type, Optional
 
 from dulwich import porcelain
-from dulwich.client import get_transport_and_path
+from dulwich.client import get_transport_and_path, GitProtocolError
 from dulwich.errors import ApplyDeltaError
 from dulwich.index import Index
 from dulwich.objectspec import parse_commit
@@ -263,8 +263,11 @@ class cmd_clone(Command):
         else:
             target = None
 
-        porcelain.clone(source, target, bare=options.bare, depth=options.depth,
-                        branch=options.branch)
+        try:
+            porcelain.clone(source, target, bare=options.bare, depth=options.depth,
+                            branch=options.branch)
+        except GitProtocolError as e:
+            print("%s" % e)
 
 
 class cmd_commit(Command):

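The cmd_clone change above surfaces protocol failures as a plain message
instead of a traceback. The same pattern applies when calling porcelain
directly; a minimal sketch with a hypothetical URL:

    from dulwich import porcelain
    from dulwich.client import GitProtocolError

    try:
        porcelain.clone("https://example.com/repo.git")  # hypothetical URL
    except GitProtocolError as e:
        print(e)
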
+ 18 - 9
dulwich/client.py

@@ -51,6 +51,8 @@ from typing import (
     Callable,
     Dict,
     List,
+    Iterable,
+    Iterator,
     Optional,
     Set,
     Tuple,
@@ -117,7 +119,8 @@ from dulwich.protocol import (
     pkt_line,
 )
 from dulwich.pack import (
-    write_pack_objects,
+    write_pack_from_container,
+    UnpackedObject,
     PackChunkGenerator,
 )
 from dulwich.refs import (
@@ -494,7 +497,7 @@ class _v1ReceivePackHeader:
         yield None
 
 
-def _read_side_band64k_data(pkt_seq, channel_callbacks):
+def _read_side_band64k_data(pkt_seq: Iterable[bytes], channel_callbacks: Dict[int, Callable[[bytes], None]]) -> None:
     """Read per-channel data.
 
     This requires the side-band-64k capability.
@@ -587,9 +590,9 @@ def _handle_upload_pack_head(
 
 def _handle_upload_pack_tail(
     proto,
-    capabilities,
+    capabilities: Set[bytes],
     graph_walker,
-    pack_data,
+    pack_data: Callable[[bytes], None],
     progress=None,
     rbufsize=_RBUFSIZE,
 ):
@@ -611,6 +614,8 @@ def _handle_upload_pack_tail(
         parts = pkt.rstrip(b"\n").split(b" ")
         if parts[0] == b"ACK":
             graph_walker.ack(parts[1])
+        if parts[0] == b"NAK":
+            graph_walker.nak()
         if len(parts) < 3 or parts[2] not in (
             b"ready",
             b"continue",
@@ -699,7 +704,7 @@ class GitClient:
         """
         raise NotImplementedError(cls.from_parsedurl)
 
-    def send_pack(self, path, update_refs, generate_pack_data, progress=None):
+    def send_pack(self, path, update_refs, generate_pack_data: Callable[[Set[bytes], Set[bytes], bool], Tuple[int, Iterator[UnpackedObject]]], progress=None):
         """Upload a pack to a remote repository.
 
         Args:
@@ -855,6 +860,7 @@ class GitClient:
         determine_wants,
         graph_walker,
         pack_data,
+        *,
         progress=None,
         depth=None,
     ):
@@ -1106,10 +1112,11 @@ class TraditionalGitClient(GitClient):
                 header_handler.have,
                 header_handler.want,
                 ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities),
+                progress=progress,
             )
 
             if self._should_send_pack(new_refs):
-                for chunk in PackChunkGenerator(pack_data_count, pack_data):
+                for chunk in PackChunkGenerator(pack_data_count, pack_data, progress=progress):
                     proto.write(chunk)
 
             ref_status = self._handle_receive_pack_tail(
@@ -1533,17 +1540,19 @@ class LocalGitClient(GitClient):
 
         """
         with self._open_repo(path) as r:
-            objects_iter = r.fetch_objects(
+            missing_objects = r.find_missing_objects(
                 determine_wants, graph_walker, progress=progress, depth=depth
             )
+            other_haves = missing_objects.get_remote_has()
+            object_ids = list(missing_objects)
             symrefs = r.refs.get_symrefs()
             agent = agent_string()
 
             # Did the process short-circuit (e.g. in a stateless RPC call)?
             # Note that the client still expects a 0-object pack in most cases.
-            if objects_iter is None:
+            if object_ids is None:
                 return FetchPackResult(None, symrefs, agent)
-            write_pack_objects(pack_data, objects_iter, reuse_pack=r.object_store)
+            write_pack_from_container(pack_data, r.object_store, object_ids, other_haves=other_haves)
             return FetchPackResult(r.get_refs(), symrefs, agent)
 
     def get_refs(self, path):
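
Taken together, the LocalGitClient.fetch_pack hunk above splits the old
single fetch_objects call into three explicit steps; condensed, the new
flow is:

    # 1. Walk the graph once to decide which objects are missing.
    missing_objects = r.find_missing_objects(
        determine_wants, graph_walker, progress=progress, depth=depth)
    # 2. Record what the other side already has, so deltas against those
    #    objects can be reused as-is.
    other_haves = missing_objects.get_remote_has()
    object_ids = list(missing_objects)
    # 3. Write the pack straight out of the object store.
    write_pack_from_container(
        pack_data, r.object_store, object_ids, other_haves=other_haves)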

+ 0 - 23
dulwich/contrib/swift.py

@@ -40,7 +40,6 @@ from geventhttpclient import HTTPClient
 
 from dulwich.greenthreads import (
     GreenThreadsMissingObjectFinder,
-    GreenThreadsObjectStoreIterator,
 )
 
 from dulwich.lru_cache import LRUSizeCache
@@ -119,15 +118,6 @@ cache_length = 20
 """
 
 
-class PackInfoObjectStoreIterator(GreenThreadsObjectStoreIterator):
-    def __len__(self):
-        while self.finder.objects_to_send:
-            for _ in range(0, len(self.finder.objects_to_send)):
-                sha = self.finder.next()
-                self._shas.append(sha)
-        return len(self._shas)
-
-
 class PackInfoMissingObjectFinder(GreenThreadsMissingObjectFinder):
     def next(self):
         while True:
@@ -681,19 +671,6 @@ class SwiftObjectStore(PackBasedObjectStore):
         """Loose objects are not supported by this repository"""
         return []
 
-    def iter_shas(self, finder):
-        """An iterator over pack's ObjectStore.
-
-        Returns: a `ObjectStoreIterator` or `GreenThreadsObjectStoreIterator`
-                 instance if gevent is enabled
-        """
-        shas = iter(finder.next, None)
-        return PackInfoObjectStoreIterator(self, shas, finder, self.scon.concurrency)
-
-    def find_missing_objects(self, *args, **kwargs):
-        kwargs["concurrency"] = self.scon.concurrency
-        return PackInfoMissingObjectFinder(self, *args, **kwargs)
-
     def pack_info_get(self, sha):
         for pack in self.packs:
             if sha in pack:

+ 0 - 34
dulwich/greenthreads.py

@@ -33,7 +33,6 @@ from dulwich.object_store import (
     MissingObjectFinder,
     _collect_ancestors,
     _collect_filetree_revs,
-    ObjectStoreIterator,
 )
 
 
@@ -111,36 +110,3 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
         else:
             self.progress = progress
         self._tagged = get_tagged and get_tagged() or {}
-
-
-class GreenThreadsObjectStoreIterator(ObjectStoreIterator):
-    """ObjectIterator that works on top of an ObjectStore.
-
-    Same implementation as object_store.ObjectStoreIterator
-    except we use gevent to parallelize object retrieval.
-    """
-
-    def __init__(self, store, shas, finder, concurrency=1):
-        self.finder = finder
-        self.p = pool.Pool(size=concurrency)
-        super().__init__(store, shas)
-
-    def retrieve(self, args):
-        sha, path = args
-        return self.store[sha], path
-
-    def __iter__(self):
-        yield from self.p.imap_unordered(self.retrieve, self.itershas())
-
-    def __len__(self):
-        if len(self._shas) > 0:
-            return len(self._shas)
-        while self.finder.objects_to_send:
-            jobs = []
-            for _ in range(0, len(self.finder.objects_to_send)):
-                jobs.append(self.p.spawn(self.finder.next))
-            gevent.joinall(jobs)
-            for j in jobs:
-                if j.value is not None:
-                    self._shas.append(j.value)
-        return len(self._shas)

+ 210 - 218
dulwich/object_store.py

@@ -28,7 +28,12 @@ import stat
 import sys
 import warnings
 
-from typing import Callable, Dict, List, Optional, Tuple, Protocol, Union, Iterator, Set
+from typing import Callable, Dict, List, Optional, Tuple, Iterator, Set, Iterable, Sequence, cast
+
+try:
+    from typing import Protocol
+except ImportError:  # python << 3.8
+    from typing_extensions import Protocol  # type: ignore
 
 from dulwich.errors import (
     NotTreeError,
@@ -40,6 +45,7 @@ from dulwich.objects import (
     ShaFile,
     Tag,
     Tree,
+    Blob,
     ZERO_SHA,
     hex_to_sha,
     sha_to_hex,
@@ -53,10 +59,14 @@ from dulwich.pack import (
     ObjectContainer,
     Pack,
     PackData,
+    PackHint,
     PackInflater,
     PackFileDisappeared,
+    UnpackedObject,
     load_pack_index_file,
     iter_sha1,
+    full_unpacked_object,
+    generate_unpacked_objects,
     pack_objects_to_data,
     write_pack_header,
     write_pack_index_v2,
@@ -65,6 +75,7 @@ from dulwich.pack import (
     compute_file_sha,
     PackIndexer,
     PackStreamCopier,
+    PackedObjectContainer,
 )
 from dulwich.protocol import DEPTH_INFINITE
 from dulwich.refs import ANNOTATED_TAG_SUFFIX, Ref
@@ -109,29 +120,16 @@ class BaseObjectStore:
             and not sha == ZERO_SHA
         ]
 
-    def iter_shas(self, shas):
-        """Iterate over the objects for the specified shas.
-
-        Args:
-          shas: Iterable object with SHAs
-        Returns: Object iterator
-        """
-        return ObjectStoreIterator(self, shas)
-
     def contains_loose(self, sha):
         """Check if a particular object is present by SHA1 and is loose."""
         raise NotImplementedError(self.contains_loose)
 
-    def contains_packed(self, sha):
-        """Check if a particular object is present by SHA1 and is packed."""
-        raise NotImplementedError(self.contains_packed)
-
-    def __contains__(self, sha):
+    def __contains__(self, sha1: bytes) -> bool:
         """Check if a particular object is present by SHA1.
 
         This method makes no distinction between loose and packed objects.
         """
-        return self.contains_packed(sha) or self.contains_loose(sha)
+        return self.contains_loose(sha1)
 
     @property
     def packs(self):
@@ -147,21 +145,15 @@
         """
         raise NotImplementedError(self.get_raw)
 
-    def __getitem__(self, sha: ObjectID):
+    def __getitem__(self, sha1: ObjectID) -> ShaFile:
         """Obtain an object by SHA1."""
-        type_num, uncomp = self.get_raw(sha)
-        return ShaFile.from_raw_string(type_num, uncomp, sha=sha)
+        type_num, uncomp = self.get_raw(sha1)
+        return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
 
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)
 
-    def add_pack(
-        self
-    ) -> Tuple[BytesIO, Callable[[], None], Callable[[], None]]:
-        """Add a new pack to this object store."""
-        raise NotImplementedError(self.add_pack)
-
     def add_object(self, obj):
         """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)
@@ -174,31 +166,6 @@
         """
         raise NotImplementedError(self.add_objects)
 
-    def add_pack_data(self, count, pack_data, progress=None):
-        """Add pack data to this object store.
-
-        Args:
-          count: Number of items to add
-          pack_data: Iterator over pack data tuples
-        """
-        if count == 0:
-            # Don't bother writing an empty pack file
-            return
-        f, commit, abort = self.add_pack()
-        try:
-            write_pack_data(
-                f.write,
-                count,
-                pack_data,
-                progress,
-                compression_level=self.pack_compression_level,
-            )
-        except BaseException:
-            abort()
-            raise
-        else:
-            return commit()
-
     def tree_changes(
         self,
         source,
@@ -253,41 +220,13 @@
             DeprecationWarning, stacklevel=2)
         return iter_tree_contents(self, tree_id, include_trees=include_trees)
 
-    def find_missing_objects(
-        self,
-        haves,
-        wants,
-        shallow=None,
-        progress=None,
-        get_tagged=None,
-        get_parents=lambda commit: commit.parents,
-    ):
-        """Find the missing objects required for a set of revisions.
-
-        Args:
-          haves: Iterable over SHAs already in common.
-          wants: Iterable over SHAs of objects to fetch.
-          shallow: Set of shallow commit SHA1s to skip
-          progress: Simple progress function that will be called with
-            updated progress strings.
-          get_tagged: Function that returns a dict of pointed-to sha ->
-            tag sha for including tags.
-          get_parents: Optional function for getting the parents of a
-            commit.
-        Returns: Iterator over (sha, path) pairs.
-        """
-        warnings.warn(
-            'Please use MissingObjectFinder(store)', DeprecationWarning)
-        finder = MissingObjectFinder(
-            self,
-            haves=haves,
-            wants=wants,
-            shallow=shallow,
-            progress=progress,
-            get_tagged=get_tagged,
-            get_parents=get_parents,
-        )
-        return iter(finder)
+    def iterobjects_subset(self, shas: Iterable[bytes], *, allow_missing: bool = False) -> Iterator[ShaFile]:
+        for sha in shas:
+            try:
+                yield self[sha]
+            except KeyError:
+                if not allow_missing:
+                    raise
 
     def find_common_revisions(self, graphwalker):
         """Find which revisions this store has in common using graphwalker.
@@ -305,22 +244,10 @@
             sha = next(graphwalker)
         return haves
 
-    def generate_pack_contents(self, have, want, shallow=None, progress=None):
-        """Iterate over the contents of a pack file.
-
-        Args:
-          have: List of SHA1s of objects that should not be sent
-          want: List of SHA1s of objects that should be sent
-          shallow: Set of shallow commit SHA1s to skip
-          progress: Optional progress reporting method
-        """
-        missing = MissingObjectFinder(
-            self, haves=have, wants=want, shallow=shallow, progress=progress)
-        return self.iter_shas(missing)
-
     def generate_pack_data(
-        self, have, want, shallow=None, progress=None, ofs_delta=True
-    ):
+        self, have, want, shallow=None, progress=None,
+        ofs_delta=True
+    ) -> Tuple[int, Iterator[UnpackedObject]]:
         """Generate pack data objects for a set of wants/haves.
 
         Args:
@@ -330,10 +257,14 @@
           ofs_delta: Whether OFS deltas can be included
           progress: Optional progress reporting method
         """
-        # TODO(jelmer): More efficient implementation
+        # Note that the pack-specific implementation below is more efficient,
+        # as it reuses deltas
+        missing_objects = MissingObjectFinder(
+            self, haves=have, wants=want, shallow=shallow, progress=progress)
+        object_ids = list(missing_objects)
         return pack_objects_to_data(
-            self.generate_pack_contents(have, want, shallow, progress)
-        )
+            [(self[oid], path) for oid, path in object_ids], ofs_delta=ofs_delta,
+            progress=progress)
 
     def peel_sha(self, sha):
         """Peel all tags from a SHA.
@@ -389,6 +320,37 @@ class PackBasedObjectStore(BaseObjectStore):
         self._pack_cache = {}
         self.pack_compression_level = pack_compression_level
 
+    def add_pack(
+        self
+    ) -> Tuple[BytesIO, Callable[[], None], Callable[[], None]]:
+        """Add a new pack to this object store."""
+        raise NotImplementedError(self.add_pack)
+
+    def add_pack_data(self, count: int, unpacked_objects: Iterator[UnpackedObject], progress=None) -> None:
+        """Add pack data to this object store.
+
+        Args:
+          count: Number of items to add
+          pack_data: Iterator over pack data tuples
+        """
+        if count == 0:
+            # Don't bother writing an empty pack file
+            return
+        f, commit, abort = self.add_pack()
+        try:
+            write_pack_data(
+                f.write,
+                unpacked_objects,
+                num_records=count,
+                progress=progress,
+                compression_level=self.pack_compression_level,
+            )
+        except BaseException:
+            abort()
+            raise
+        else:
+            return commit()
+
     @property
     def alternates(self):
         return []
@@ -426,6 +388,30 @@
             if prev_pack:
                 prev_pack.close()
 
+    def generate_pack_data(
+        self, have, want, shallow=None, progress=None,
+        ofs_delta=True
+    ) -> Tuple[int, Iterator[UnpackedObject]]:
+        """Generate pack data objects for a set of wants/haves.
+
+        Args:
+          have: List of SHA1s of objects that should not be sent
+          want: List of SHA1s of objects that should be sent
+          shallow: Set of shallow commit SHA1s to skip
+          ofs_delta: Whether OFS deltas can be included
+          progress: Optional progress reporting method
+        """
+        missing_objects = MissingObjectFinder(
+            self, haves=have, wants=want, shallow=shallow, progress=progress)
+        remote_has = missing_objects.get_remote_has()
+        object_ids = list(missing_objects)
+        return len(object_ids), generate_unpacked_objects(
+            cast(PackedObjectContainer, self),
+            object_ids,
+            progress=progress,
+            ofs_delta=ofs_delta,
+            other_haves=remote_has)
+
     def _clear_cached_packs(self):
         pack_cache = self._pack_cache
         self._pack_cache = {}
@@ -565,47 +551,89 @@
                 pass
         raise KeyError(hexsha)
 
-    def get_raw_unresolved(self, name: bytes) -> Tuple[int, Union[bytes, None], List[bytes]]:
-        """Obtain the unresolved data for an object.
+    def iter_unpacked_subset(self, shas, *, include_comp=False, allow_missing: bool = False, convert_ofs_delta: bool = True) -> Iterator[ShaFile]:
+        todo: Set[bytes] = set(shas)
+        for p in self._iter_cached_packs():
+            for unpacked in p.iter_unpacked_subset(todo, include_comp=include_comp, allow_missing=True, convert_ofs_delta=convert_ofs_delta):
+                yield unpacked
+                hexsha = sha_to_hex(unpacked.sha())
+                todo.remove(hexsha)
+        # Maybe something else has added a pack with the object
+        # in the mean time?
+        for p in self._update_pack_cache():
+            for unpacked in p.iter_unpacked_subset(todo, include_comp=include_comp, allow_missing=True, convert_ofs_delta=convert_ofs_delta):
+                yield unpacked
+                hexsha = sha_to_hex(unpacked.sha())
+                todo.remove(hexsha)
+        for alternate in self.alternates:
+            for unpacked in alternate.iter_unpacked_subset(todo, include_comp=include_comp, allow_missing=True, convert_ofs_delta=convert_ofs_delta):
+                yield unpacked
+                hexsha = sha_to_hex(unpacked.sha())
+                todo.remove(hexsha)
+
+    def iterobjects_subset(self, shas: Iterable[bytes], *, allow_missing: bool = False) -> Iterator[ShaFile]:
+        todo: Set[bytes] = set(shas)
+        for p in self._iter_cached_packs():
+            for o in p.iterobjects_subset(todo, allow_missing=True):
+                yield o
+                todo.remove(o.id)
+        # Maybe something else has added a pack with the object
+        # in the mean time?
+        for p in self._update_pack_cache():
+            for o in p.iterobjects_subset(todo, allow_missing=True):
+                yield o
+                todo.remove(o.id)
+        for alternate in self.alternates:
+            for o in alternate.iterobjects_subset(todo, allow_missing=True):
+                yield o
+                todo.remove(o.id)
+        for oid in todo:
+            o = self._get_loose_object(oid)
+            if o is not None:
+                yield o
+            elif not allow_missing:
+                raise KeyError(oid)
+
+    def get_unpacked_object(self, sha1: bytes, *, include_comp: bool = False) -> UnpackedObject:
+        """Obtain the unpacked object.
 
         Args:
-          name: sha for the object.
+          sha1: sha for the object.
         """
-        if name == ZERO_SHA:
-            raise KeyError(name)
-        if len(name) == 40:
-            sha = hex_to_sha(name)
-            hexsha = name
-        elif len(name) == 20:
-            sha = name
+        if sha1 == ZERO_SHA:
+            raise KeyError(sha1)
+        if len(sha1) == 40:
+            sha = hex_to_sha(sha1)
+            hexsha = sha1
+        elif len(sha1) == 20:
+            sha = sha1
             hexsha = None
         else:
-            raise AssertionError("Invalid object name {!r}".format(name))
+            raise AssertionError("Invalid object sha1 {!r}".format(sha1))
         for pack in self._iter_cached_packs():
            try:
-                return pack.get_raw_unresolved(sha)
+                return pack.get_unpacked_object(sha, include_comp=include_comp)
            except (KeyError, PackFileDisappeared):
                pass
         if hexsha is None:
-            hexsha = sha_to_hex(name)
-        ret = self._get_loose_object(hexsha)
-        if ret is not None:
-            return ret.type_num, None, ret.as_raw_chunks()
+            hexsha = sha_to_hex(sha1)
         # Maybe something else has added a pack with the object
         # in the mean time?
         for pack in self._update_pack_cache():
            try:
-                return pack.get_raw_unresolved(sha)
+                return pack.get_unpacked_object(sha, include_comp=include_comp)
            except KeyError:
                pass
         for alternate in self.alternates:
            try:
-                return alternate.get_raw_unresolved(hexsha)
+                return alternate.get_unpacked_object(hexsha, include_comp=include_comp)
            except KeyError:
                pass
         raise KeyError(hexsha)
 
-    def add_objects(self, objects, progress=None):
+    def add_objects(
+            self, objects: Sequence[Tuple[ShaFile, Optional[str]]],
+            progress: Optional[Callable[[str], None]] = None) -> None:
         """Add a set of objects to this object store.
 
         Args:
@@ -613,7 +641,9 @@ class PackBasedObjectStore(BaseObjectStore):
             __len__.
         Returns: Pack object of the objects written.
         """
-        return self.add_pack_data(*pack_objects_to_data(objects), progress=progress)
+        count = len(objects)
+        record_iter = (full_unpacked_object(o) for (o, p) in objects)
+        return self.add_pack_data(count, record_iter, progress=progress)
 
 
 class DiskObjectStore(PackBasedObjectStore):
@@ -1103,79 +1133,6 @@ class ObjectIterator(Protocol):
         raise NotImplementedError(self.iterobjects)
 
 
-class ObjectStoreIterator(ObjectIterator):
-    """ObjectIterator that works on top of an ObjectStore."""
-
-    def __init__(self, store, sha_iter):
-        """Create a new ObjectIterator.
-
-        Args:
-          store: Object store to retrieve from
-          sha_iter: Iterator over (sha, path) tuples
-        """
-        self.store = store
-        self.sha_iter = sha_iter
-        self._shas = []
-
-    def __iter__(self):
-        """Yield tuple with next object and path."""
-        for sha, path in self.itershas():
-            yield self.store[sha], path
-
-    def iterobjects(self):
-        """Iterate over just the objects."""
-        for o, path in self:
-            yield o
-
-    def itershas(self):
-        """Iterate over the SHAs."""
-        for sha in self._shas:
-            yield sha
-        for sha in self.sha_iter:
-            self._shas.append(sha)
-            yield sha
-
-    def __contains__(self, needle):
-        """Check if an object is present.
-
-        Note: This checks if the object is present in
-            the underlying object store, not if it would
-            be yielded by the iterator.
-
-        Args:
-          needle: SHA1 of the object to check for
-        """
-        if needle == ZERO_SHA:
-            return False
-        return needle in self.store
-
-    def __getitem__(self, key):
-        """Find an object by SHA1.
-
-        Note: This retrieves the object from the underlying
-            object store. It will also succeed if the object would
-            not be returned by the iterator.
-        """
-        return self.store[key]
-
-    def __len__(self):
-        """Return the number of objects."""
-        return len(list(self.itershas()))
-
-    def _empty(self):
-        it = self.itershas()
-        try:
-            next(it)
-        except StopIteration:
-            return True
-        else:
-            return False
-
-    def __bool__(self):
-        """Indicate whether this object has contents."""
-        return not self._empty()
-
-
 def tree_lookup_path(lookup_obj, root_sha, path):
     """Look up an object in a Git tree.
 
@@ -1306,27 +1263,33 @@ class MissingObjectFinder:
             shallow=shallow,
             get_parents=self._get_parents,
         )
-        self.sha_done = set()
+        self.remote_has: Set[bytes] = set()
         # Now, fill sha_done with commits and revisions of
         # files and directories known to be both locally
         # and on target. Thus these commits and files
         # won't get selected for fetch
         for h in common_commits:
-            self.sha_done.add(h)
+            self.remote_has.add(h)
             cmt = object_store[h]
-            _collect_filetree_revs(object_store, cmt.tree, self.sha_done)
+            _collect_filetree_revs(object_store, cmt.tree, self.remote_has)
         # record tags we have as visited, too
         for t in have_tags:
-            self.sha_done.add(t)
+            self.remote_has.add(t)
+        self.sha_done = set(self.remote_has)
 
-        missing_tags = want_tags.difference(have_tags)
-        missing_others = want_others.difference(have_others)
         # in fact, what we 'want' is commits, tags, and others
         # we've found missing
-        wants = missing_commits.union(missing_tags)
-        wants = wants.union(missing_others)
-
-        self.objects_to_send = {(w, None, False) for w in wants}
+        self.objects_to_send = {
+            (w, None, Commit.type_num, False)
+            for w in missing_commits}
+        missing_tags = want_tags.difference(have_tags)
+        self.objects_to_send.update(
+            {(w, None, Tag.type_num, False)
+             for w in missing_tags})
+        missing_others = want_others.difference(have_others)
+        self.objects_to_send.update(
+            {(w, None, None, False)
+             for w in missing_others})
 
         if progress is None:
             self.progress = lambda x: None
@@ -1334,38 +1297,44 @@
             self.progress = progress
         self._tagged = get_tagged and get_tagged() or {}
 
-    def add_todo(self, entries):
+    def get_remote_has(self):
+        return self.remote_has
+
+    def add_todo(self, entries: Iterable[Tuple[ObjectID, Optional[bytes], Optional[int], bool]]):
         self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
 
-    def __next__(self):
+    def __next__(self) -> Tuple[bytes, PackHint]:
         while True:
             if not self.objects_to_send:
-                return None
-            (sha, name, leaf) = self.objects_to_send.pop()
+                self.progress(("counting objects: %d, done.\n" % len(self.sha_done)).encode("ascii"))
+                raise StopIteration
+            (sha, name, type_num, leaf) = self.objects_to_send.pop()
             if sha not in self.sha_done:
                 break
         if not leaf:
             o = self.object_store[sha]
             if isinstance(o, Commit):
-                self.add_todo([(o.tree, b"", False)])
+                self.add_todo([(o.tree, b"", Tree.type_num, False)])
             elif isinstance(o, Tree):
                 self.add_todo(
                     [
-                        (s, n, not stat.S_ISDIR(m))
+                        (s, n, (Blob.type_num if stat.S_ISREG(m) else Tree.type_num),
+                         not stat.S_ISDIR(m))
                         for n, m, s in o.iteritems()
                        if not S_ISGITLINK(m)
                    ]
                )
            elif isinstance(o, Tag):
-                self.add_todo([(o.object[1], None, False)])
+                self.add_todo([(o.object[1], None, o.object[0].type_num, False)])
        if sha in self._tagged:
-            self.add_todo([(self._tagged[sha], None, True)])
+            self.add_todo([(self._tagged[sha], None, None, True)])
        self.sha_done.add(sha)
-        self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
-        return (sha, name)
+        if len(self.sha_done) % 1000 == 0:
+            self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
+        return (sha, (type_num, name))
 
     def __iter__(self):
-        return iter(self.__next__, None)
+        return self
 
 
 class ObjectStoreGraphWalker:
@@ -1390,6 +1359,9 @@ class ObjectStoreGraphWalker:
             shallow = set()
         self.shallow = shallow
 
+    def nak(self):
+        """Nothing in common was found."""
+
     def ack(self, sha):
         """Ack that a revision and its ancestors are present in the source."""
         if len(sha) != 40:
@@ -1512,6 +1484,24 @@ class OverlayObjectStore(BaseObjectStore):
                     yield o_id
                     done.add(o_id)
 
+    def iterobjects_subset(self, shas: Iterable[bytes], *, allow_missing: bool = False) -> Iterator[ShaFile]:
+        todo = set(shas)
+        for b in self.bases:
+            for o in b.iterobjects_subset(todo, allow_missing=True):
+                yield o
+                todo.remove(o.id)
+        if todo and not allow_missing:
+            raise KeyError(o.id)
+
+    def iter_unpacked_subset(self, shas: Iterable[bytes], *, include_comp=False, allow_missing: bool = False, convert_ofs_delta=True) -> Iterator[ShaFile]:
+        todo = set(shas)
+        for b in self.bases:
+            for o in b.iter_unpacked_subset(todo, include_comp=include_comp, allow_missing=True, convert_ofs_delta=convert_ofs_delta):
+                yield o
+                todo.remove(o.id)
+        if todo and not allow_missing:
+            raise KeyError(o.id)
+
     def get_raw(self, sha_id):
         for b in self.bases:
             try:
@@ -1663,7 +1653,7 @@
 
 
 def iter_tree_contents(
-        store: ObjectContainer, tree_id: bytes, *, include_trees: bool = False):
+        store: ObjectContainer, tree_id: Optional[ObjectID], *, include_trees: bool = False):
     """Iterate the contents of a tree and all subtrees.
 
     Iteration is depth-first pre-order, as in e.g. os.walk.
@@ -1674,6 +1664,8 @@
     Returns: Iterator over TreeEntry namedtuples for all the objects in a
         tree.
     """
+    if tree_id is None:
+        return
     # This could be fairly easily generalized to >2 trees if we find a use
     # case.
     todo = [TreeEntry(b"", stat.S_IFDIR, tree_id)]
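
After this refactoring MissingObjectFinder is an ordinary iterator: it
raises StopIteration once the queue is drained and yields (sha, pack_hint)
pairs, where pack_hint is the (type_num, name) tuple used to group similar
objects during deltification. A minimal sketch of driving it by hand (store
and the two SHAs are hypothetical):

    from dulwich.object_store import MissingObjectFinder

    finder = MissingObjectFinder(
        store,                 # any object store
        haves=[remote_tip],    # hypothetical SHA the other side has
        wants=[local_tip],     # hypothetical SHA it asked for
    )
    remote_has = finder.get_remote_has()
    for sha, (type_num, name) in finder:
        pass  # feed each (sha, pack_hint) pair to the pack writer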

+ 1 - 4
dulwich/objects.py

@@ -528,10 +528,7 @@ class ShaFile:
 
     def raw_length(self) -> int:
         """Returns the length of the raw string of this object."""
-        ret = 0
-        for chunk in self.as_raw_chunks():
-            ret += len(chunk)
-        return ret
+        return sum(map(len, self.as_raw_chunks()))
 
     def sha(self):
         """The SHA1 object that is the name of this object."""

File diff suppressed because it is too large
+ 361 - 213
dulwich/pack.py
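
Although the dulwich/pack.py diff is suppressed, its new surface can be read
off the imports in the other files of this commit; the descriptions below
are inferred from the call sites in this diff:

    from dulwich.pack import (
        PackHint,                   # (type_num, name) hint for deltification
        UnpackedObject,             # record yielded by the new iterators
        PackedObjectContainer,      # store protocol serving UnpackedObjects
        full_unpacked_object,       # wrap a ShaFile as an UnpackedObject
        generate_unpacked_objects,  # produce records for a set of object IDs
        write_pack_from_container,  # write a pack straight from a store
        write_pack_data,            # now keyword-based: num_records=, records=
    )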


+ 5 - 19
dulwich/porcelain.py

@@ -134,7 +134,7 @@ from dulwich.objectspec import (
 )
 from dulwich.pack import (
     write_pack_index,
-    write_pack_objects,
+    write_pack_from_container,
 )
 from dulwich.patch import write_tree_diff
 from dulwich.protocol import (
@@ -1753,18 +1753,6 @@ def repack(repo):
         r.object_store.pack_loose_objects()
 
 
-def find_pack_for_reuse(repo):
-    reuse_pack = None
-    max_pack_len = 0
-    # The pack file which contains the largest number of objects
-    # will be most suitable for object reuse.
-    for p in repo.object_store.packs:
-        if len(p) > max_pack_len:
-            reuse_pack = p
-            max_pack_len = len(reuse_pack)
-    return reuse_pack
-
-
 def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None, deltify=None, reuse_deltas=True):
     """Pack objects into a file.
 
@@ -1779,15 +1767,13 @@ def pack_objects(repo, object_ids, packf, idxf, delta_window_size=None, deltify=
       reuse_deltas: Allow reuse of existing deltas while deltifying
     """
     with open_repo_closing(repo) as r:
-        reuse_pack = None
-        if deltify and reuse_deltas:
-            reuse_pack = find_pack_for_reuse(r)
-        entries, data_sum = write_pack_objects(
+        entries, data_sum = write_pack_from_container(
             packf.write,
-            r.object_store.iter_shas((oid, None) for oid in object_ids),
+            r.object_store,
+            [(oid, None) for oid in object_ids],
             deltify=deltify,
             delta_window_size=delta_window_size,
-            reuse_pack=reuse_pack
+            reuse_deltas=reuse_deltas,
         )
     if idxf is not None:
         entries = sorted([(k, v[0], v[1]) for (k, v) in entries.items()])
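
At the porcelain level the reuse_deltas flag now flows directly into
write_pack_from_container instead of going through the removed
find_pack_for_reuse heuristic. A rough usage sketch (paths hypothetical):

    from dulwich import porcelain
    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")          # hypothetical repository
    object_ids = list(repo.object_store)  # e.g. pack every object
    with open("out.pack", "wb") as packf, open("out.idx", "wb") as idxf:
        porcelain.pack_objects(
            repo, object_ids, packf, idxf,
            deltify=True, reuse_deltas=True)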

+ 39 - 25
dulwich/repo.py

@@ -39,6 +39,7 @@ from typing import (
     Callable,
     Tuple,
     TYPE_CHECKING,
+    FrozenSet,
     List,
     Dict,
     Union,
@@ -70,10 +71,10 @@ from dulwich.file import (
 from dulwich.object_store import (
     DiskObjectStore,
     MemoryObjectStore,
-    BaseObjectStore,
+    MissingObjectFinder,
+    PackBasedObjectStore,
     ObjectStoreGraphWalker,
     peel_sha,
-    MissingObjectFinder,
 )
 from dulwich.objects import (
     check_hexsha,
@@ -86,7 +87,7 @@ from dulwich.objects import (
     ObjectID,
 )
 from dulwich.pack import (
-    pack_objects_to_data,
+    generate_unpacked_objects
 )
 
 from dulwich.hooks import (
@@ -362,7 +363,7 @@ class BaseRepo:
         repository
     """
 
-    def __init__(self, object_store: BaseObjectStore, refs: RefsContainer):
+    def __init__(self, object_store: PackBasedObjectStore, refs: RefsContainer):
         """Open a repository.
 
         This shouldn't be called directly, but rather through one of the
@@ -484,20 +485,23 @@ class BaseRepo:
           depth: Shallow fetch depth
         Returns: count and iterator over pack data
         """
-        # TODO(jelmer): Fetch pack data directly, don't create objects first.
-        objects = self.fetch_objects(
+        missing_objects = self.find_missing_objects(
             determine_wants, graph_walker, progress, get_tagged, depth=depth
         )
-        return pack_objects_to_data(objects)
+        remote_has = missing_objects.get_remote_has()
+        object_ids = list(missing_objects)
+        return len(object_ids), generate_unpacked_objects(
+            self.object_store, object_ids, progress=progress,
+            other_haves=remote_has)
 
-    def fetch_objects(
+    def find_missing_objects(
         self,
         determine_wants,
         graph_walker,
         progress,
         get_tagged=None,
         depth=None,
-    ):
+    ) -> Optional[MissingObjectFinder]:
         """Fetch the missing objects required for a set of revisions.
 
         Args:
@@ -536,8 +540,8 @@ class BaseRepo:
         if not isinstance(wants, list):
             raise TypeError("determine_wants() did not return a list")
 
-        shallows = getattr(graph_walker, "shallow", frozenset())
-        unshallows = getattr(graph_walker, "unshallow", frozenset())
+        shallows: FrozenSet[ObjectID] = getattr(graph_walker, "shallow", frozenset())
+        unshallows: FrozenSet[ObjectID] = getattr(graph_walker, "unshallow", frozenset())
 
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
@@ -547,7 +551,18 @@ class BaseRepo:
                 # Do not send a pack in shallow short-circuit path
                 return None
 
-            return []
+            class DummyMissingObjectFinder:
+
+                def get_remote_has(self):
+                    return None
+
+                def __len__(self):
+                    return 0
+
+                def __iter__(self):
+                    yield from []
+
+            return DummyMissingObjectFinder()  # type: ignore
 
         # If the graph walker is set up with an implementation that can
         # ACK/NAK to the wire, it will write data to the client through
@@ -566,17 +581,14 @@ class BaseRepo:
         def get_parents(commit):
             return parents_provider.get_parents(commit.id, commit)
 
-        return self.object_store.iter_shas(
-            MissingObjectFinder(
-                self.object_store,
-                haves=haves,
-                wants=wants,
-                shallow=self.get_shallow(),
-                progress=progress,
-                get_tagged=get_tagged,
-                get_parents=get_parents,
-            )
-        )
+        return MissingObjectFinder(
+            self.object_store,
+            haves=haves,
+            wants=wants,
+            shallow=self.get_shallow(),
+            progress=progress,
+            get_tagged=get_tagged,
+            get_parents=get_parents)
 
     def generate_pack_data(self, have: List[ObjectID], want: List[ObjectID],
                            progress: Optional[Callable[[str], None]] = None,
@@ -1116,7 +1128,7 @@ class Repo(BaseRepo):
     def __init__(
         self,
         root: str,
-        object_store: Optional[BaseObjectStore] = None,
+        object_store: Optional[PackBasedObjectStore] = None,
         bare: Optional[bool] = None
     ) -> None:
         self.symlink_fn = None
@@ -1433,7 +1445,9 @@ class Repo(BaseRepo):
         for fs_path in fs_paths:
             tree_path = _fs_to_tree_path(fs_path)
             try:
-                tree_entry = self.object_store[tree_id].lookup_path(
+                tree = self.object_store[tree_id]
+                assert isinstance(tree, Tree)
+                tree_entry = tree.lookup_path(
                     self.object_store.__getitem__, tree_path)
             except KeyError:
                 # if tree_entry didn't exist, this file was being added, so
+ 45 - 26
dulwich/server.py

@@ -1,6 +1,6 @@
 # server.py -- Implementation of the server side git protocols
 # server.py -- Implementation of the server side git protocols
 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
 # Copyright (C) 2008 John Carr <john.carr@unrouted.co.uk>
-# Coprygith (C) 2011-2012 Jelmer Vernooij <jelmer@jelmer.uk>
+# Copyright(C) 2011-2012 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
 # General Public License as public by the Free Software Foundation; version 2.0
 # General Public License as public by the Free Software Foundation; version 2.0
@@ -43,11 +43,18 @@ Currently supported capabilities:
 """
 """
 
 
 import collections
 import collections
+from functools import partial
 import os
 import os
 import socket
 import socket
 import sys
 import sys
 import time
 import time
 from typing import List, Tuple, Dict, Optional, Iterable, Set
 from typing import List, Tuple, Dict, Optional, Iterable, Set
+
+try:
+    from typing import Protocol as TypingProtocol
+except ImportError:  # python < 3.8
+    from typing_extensions import Protocol as TypingProtocol  # type: ignore
+
 import zlib
 import zlib
 
 
 import socketserver
 import socketserver
@@ -65,14 +72,16 @@ from dulwich.errors import (
 from dulwich import log_utils
 from dulwich import log_utils
 from dulwich.objects import (
 from dulwich.objects import (
     Commit,
     Commit,
+    ObjectID,
     valid_hexsha,
     valid_hexsha,
 )
 )
 from dulwich.object_store import (
 from dulwich.object_store import (
     peel_sha,
     peel_sha,
 )
 )
 from dulwich.pack import (
 from dulwich.pack import (
-    write_pack_objects,
+    write_pack_from_container,
     ObjectContainer,
     ObjectContainer,
+    PackedObjectContainer,
 )
 )
 from dulwich.protocol import (
 from dulwich.protocol import (
     BufferedPktLineWriter,
     BufferedPktLineWriter,
@@ -118,6 +127,7 @@ from dulwich.protocol import (
     NAK_LINE,
     NAK_LINE,
 )
 )
 from dulwich.refs import (
 from dulwich.refs import (
+    RefsContainer,
     ANNOTATED_TAG_SUFFIX,
     ANNOTATED_TAG_SUFFIX,
     write_info_refs,
     write_info_refs,
 )
 )
@@ -145,15 +155,15 @@ class Backend:
         raise NotImplementedError(self.open_repository)
         raise NotImplementedError(self.open_repository)
 
 
 
 
-class BackendRepo:
+class BackendRepo(TypingProtocol):
     """Repository abstraction used by the Git server.

     The methods required here are a subset of those provided by
     dulwich.repo.Repo.
     """

-    object_store = None
-    refs = None
+    object_store: PackedObjectContainer
+    refs: RefsContainer

     def get_refs(self) -> Dict[bytes, bytes]:
         """
@@ -175,7 +185,7 @@ class BackendRepo:
         """
         return None

-    def fetch_objects(self, determine_wants, graph_walker, progress, get_tagged=None):
+    def find_missing_objects(self, determine_wants, graph_walker, progress, get_tagged=None):
         """
         Yield the objects required for a list of commits.

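BackendRepo is now a structural type: with typing.Protocol, any object exposing a matching object_store, refs and method set satisfies the interface without inheriting from it, so concrete repository classes do not need to subclass it. A minimal sketch of the idea with hypothetical names (HasRefs, DictRepo):

    from typing import Dict, Protocol  # typing_extensions.Protocol on Python < 3.8

    class HasRefs(Protocol):
        def get_refs(self) -> Dict[bytes, bytes]: ...

    class DictRepo:  # note: no inheritance from HasRefs
        def get_refs(self) -> Dict[bytes, bytes]:
            return {b"HEAD": b"0" * 40}

    def advertise(repo: HasRefs) -> Dict[bytes, bytes]:
        # Static checkers accept DictRepo because its shape matches HasRefs;
        # nothing is enforced at runtime.
        return repo.get_refs()

    print(advertise(DictRepo()))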
@@ -326,12 +336,21 @@ class UploadPackHandler(PackHandler):
             CAPABILITY_OFS_DELTA,
         )

-    def progress(self, message):
-        if self.has_capability(CAPABILITY_NO_PROGRESS) or self._processing_have_lines:
-            return
-        self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
+    def progress(self, message: bytes):
+        pass

-    def get_tagged(self, refs=None, repo=None):
+    def _start_pack_send_phase(self):
+        if self.has_capability(CAPABILITY_SIDE_BAND_64K):
+            # The provided haves are processed, and it is safe to send side-
+            # band data now.
+            if not self.has_capability(CAPABILITY_NO_PROGRESS):
+                self.progress = partial(self.proto.write_sideband, SIDE_BAND_CHANNEL_PROGRESS)
+
+            self.write_pack_data = partial(self.proto.write_sideband, SIDE_BAND_CHANNEL_DATA)
+        else:
+            self.write_pack_data = self.proto.write
+
+    def get_tagged(self, refs=None, repo=None) -> Dict[ObjectID, ObjectID]:
         """Get a dict of peeled values of tags to their original tag shas.

         Args:
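_start_pack_send_phase() defers side-band output until the client's have lines are fully processed, then rebinds progress and write_pack_data to channel-specific writers via functools.partial. In the git side-band extension each pkt-line carries one channel byte before the payload (1 = pack data, 2 = progress, 3 = error). A rough sketch of that framing, with FakeProto standing in for dulwich.protocol.Protocol:

    from functools import partial

    SIDE_BAND_CHANNEL_DATA = 1
    SIDE_BAND_CHANNEL_PROGRESS = 2

    class FakeProto:
        def write_sideband(self, channel: int, blob: bytes) -> None:
            line = bytes([channel]) + blob
            # pkt-line framing: 4-byte hex length that includes itself
            print(b"%04x" % (len(line) + 4) + line)

    proto = FakeProto()
    progress = partial(proto.write_sideband, SIDE_BAND_CHANNEL_PROGRESS)
    progress(b"counting objects: 3, done.\n")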
@@ -355,7 +374,7 @@ class UploadPackHandler(PackHandler):
                 # TODO: fix behavior when missing
                 return {}
         # TODO(jelmer): Integrate this with the refs logic in
-        # Repo.fetch_objects
+        # Repo.find_missing_objects
         tagged = {}
         for name, sha in refs.items():
             peeled_sha = repo.get_peeled(name)
@@ -364,8 +383,10 @@ class UploadPackHandler(PackHandler):
         return tagged

     def handle(self):
-        def write(x):
-            return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
+        # Note that the client is only processing responses related
+        # to the have lines it sent, and any other data (including side-
+        # band) will be considered a fatal error.
+        self._processing_have_lines = True
 
 
         graph_walker = _ProtocolGraphWalker(
             self,
@@ -379,17 +400,14 @@ class UploadPackHandler(PackHandler):
             wants.extend(graph_walker.determine_wants(refs, **kwargs))
             return wants

-        objects_iter = self.repo.fetch_objects(
+        missing_objects = self.repo.find_missing_objects(
             wants_wrapper,
             graph_walker,
             self.progress,
             get_tagged=self.get_tagged,
         )

-        # Note the fact that client is only processing responses related
-        # to the have lines it sent, and any other data (including side-
-        # band) will be be considered a fatal error.
-        self._processing_have_lines = True
+        object_ids = list(missing_objects)
 
 
         # Did the process short-circuit (e.g. in a stateless RPC call)? Note
         # that the client still expects a 0-object pack in most cases.
@@ -400,19 +418,17 @@ class UploadPackHandler(PackHandler):
         if len(wants) == 0:
             return

-        # The provided haves are processed, and it is safe to send side-
-        # band data now.
-        self._processing_have_lines = False
-
         if not graph_walker.handle_done(
             not self.has_capability(CAPABILITY_NO_DONE), self._done_received
         ):
             return

+        self._start_pack_send_phase()
         self.progress(
-            ("counting objects: %d, done.\n" % len(objects_iter)).encode("ascii")
+            ("counting objects: %d, done.\n" % len(object_ids)).encode("ascii")
         )
-        write_pack_objects(write, objects_iter)
+
+        write_pack_from_container(self.write_pack_data, self.repo.object_store, object_ids)
         # we are done
         self.proto.write_pkt_line(None)

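handle() now materializes the missing object ids and hands the object store itself to write_pack_from_container(), so the pack is written straight out of the container. A hedged usage sketch, assuming the call shape seen above, where object_ids is a sequence of (sha, path hint) pairs such as find_missing_objects() yields:

    from dulwich.pack import write_pack_from_container
    from dulwich.repo import Repo

    repo = Repo(".")
    object_ids = [(sha, None) for sha in repo.object_store]  # all objects, no hints
    with open("out.pack", "wb") as f:
        write_pack_from_container(f.write, repo.object_store, object_ids)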
@@ -604,7 +620,7 @@ class _ProtocolGraphWalker:
                     peeled_sha = self.get_peeled(ref)
                 except KeyError:
                     # Skip refs that are inaccessible
-                    # TODO(jelmer): Integrate with Repo.fetch_objects refs
+                    # TODO(jelmer): Integrate with Repo.find_missing_objects refs
                     # logic.
                     continue
                 if i == 0:
@@ -663,6 +679,9 @@ class _ProtocolGraphWalker:
             value = str(value).encode("ascii")
         self.proto.unread_pkt_line(command + b" " + value)

+    def nak(self):
+        pass
+
     def ack(self, have_ref):
         if len(have_ref) != 40:
             raise ValueError("invalid sha %r" % have_ref)

+ 2 - 2
dulwich/tests/compat/test_client.py

@@ -329,7 +329,7 @@ class DulwichClientTestBase:
             sendrefs[b"refs/heads/abranch"] = b"00" * 20
             del sendrefs[b"HEAD"]

-            def gen_pack(have, want, ofs_delta=False):
+            def gen_pack(have, want, ofs_delta=False, progress=None):
                 return 0, []

             c = self._client()
@@ -344,7 +344,7 @@ class DulwichClientTestBase:
             dest.refs[b"refs/heads/abranch"] = dummy_commit
             sendrefs = {b"refs/heads/bbranch": dummy_commit}

-            def gen_pack(have, want, ofs_delta=False):
+            def gen_pack(have, want, ofs_delta=False, progress=None):
                 return 0, []

             c = self._client()

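Both gen_pack stubs above gained a progress keyword: pack generators handed to send_pack() are now invoked with progress as well. A minimal conforming callback, mirroring the empty-pack stub in these tests:

    def generate_pack_data(have, want, ofs_delta=False, progress=None):
        # Must return (number_of_records, iterable_of_records).
        if progress is not None:
            progress(b"nothing to send\n")
        return 0, []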
+ 5 - 5
dulwich/tests/compat/test_pack.py

@@ -84,7 +84,7 @@ class TestPack(PackTests):
             orig_blob = orig_pack[a_sha]
             new_blob = Blob()
             new_blob.data = orig_blob.data + b"x"
-            all_to_pack = list(orig_pack.pack_tuples()) + [(new_blob, None)]
+            all_to_pack = [(o, None) for o in orig_pack.iterobjects()] + [(new_blob, None)]
         pack_path = os.path.join(self._tempdir, "pack_with_deltas")
         write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(["verify-pack", "-v", pack_path])
@@ -115,8 +115,8 @@ class TestPack(PackTests):
                 (new_blob, None),
                 (new_blob_2, None),
             ]
-        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
-        write_pack(pack_path, all_to_pack, deltify=True)
+            pack_path = os.path.join(self._tempdir, "pack_with_deltas")
+            write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(["verify-pack", "-v", pack_path])
         self.assertEqual(
             {x[0].id for x in all_to_pack},
@@ -154,8 +154,8 @@ class TestPack(PackTests):
                 (new_blob, None),
                 (new_blob_2, None),
             ]
-        pack_path = os.path.join(self._tempdir, "pack_with_deltas")
-        write_pack(pack_path, all_to_pack, deltify=True)
+            pack_path = os.path.join(self._tempdir, "pack_with_deltas")
+            write_pack(pack_path, all_to_pack, deltify=True)
         output = run_git_or_fail(["verify-pack", "-v", pack_path])
         self.assertEqual(
             {x[0].id for x in all_to_pack},

+ 9 - 0
dulwich/tests/test_bundle.py

@@ -20,6 +20,7 @@

 """Tests for bundle support."""

+from io import BytesIO
 import os
 import tempfile

@@ -32,6 +33,10 @@ from dulwich.bundle import (
     read_bundle,
     write_bundle,
 )
+from dulwich.pack import (
+    PackData,
+    write_pack_objects,
+)


 class BundleTests(TestCase):
@@ -41,6 +46,10 @@ class BundleTests(TestCase):
         origbundle.capabilities = {"foo": None}
         origbundle.references = {b"refs/heads/master": b"ab" * 20}
         origbundle.prerequisites = [(b"cc" * 20, "comment")]
+        b = BytesIO()
+        write_pack_objects(b.write, [])
+        b.seek(0)
+        origbundle.pack_data = PackData.from_file(b)
         with tempfile.TemporaryDirectory() as td:
             with open(os.path.join(td, "foo"), "wb") as f:
                 write_bundle(f, origbundle)

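The test now attaches real pack data, since a bundle embeds a pack stream after its header. A sketch of the round trip under the same assumptions (empty pack, origbundle as set up in the test above):

    from io import BytesIO
    from dulwich.bundle import read_bundle, write_bundle
    from dulwich.pack import PackData, write_pack_objects

    buf = BytesIO()
    write_pack_objects(buf.write, [])        # empty pack body
    buf.seek(0)
    origbundle.pack_data = PackData.from_file(buf)
    out = BytesIO()
    write_bundle(out, origbundle)
    out.seek(0)
    assert read_bundle(out).references == origbundle.references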
+ 15 - 12
dulwich/tests/test_client.py

@@ -200,7 +200,7 @@ class GitClientTests(TestCase):
         self.assertEqual({}, ret.symrefs)
         self.assertEqual(self.rout.getvalue(), b"0000")

-    def test_send_pack_no_sideband64k_with_update_ref_error(self):
+    def test_send_pack_no_sideband64k_with_update_ref_error(self) -> None:
         # No side-bank-64k reported by server shouldn't try to parse
         # side band data
         pkts = [
@@ -233,11 +233,11 @@
                 b"refs/foo/bar": commit.id,
             }

-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return pack_objects_to_data(
                 [
                     (commit, None),
-                    (tree, ""),
+                    (tree, b""),
                 ]
             )

@@ -260,7 +260,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"310ca9477129b8586fa2afc779c1f57cf64bba6c"}

-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []

         self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -280,7 +280,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"0" * 40}

-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []

         self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -304,7 +304,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"0" * 40}

-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []

         self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -331,11 +331,11 @@
                 b"refs/heads/master": b"310ca9477129b8586fa2afc779c1f57cf64bba6c",
             }

-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []

         f = BytesIO()
-        write_pack_objects(f.write, {})
+        write_pack_objects(f.write, [])
         self.client.send_pack("/", update_refs, generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
@@ -371,7 +371,7 @@
                 b"refs/heads/master": b"310ca9477129b8586fa2afc779c1f57cf64bba6c",
             }

-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return pack_objects_to_data(
                 [
                     (commit, None),
@@ -380,7 +380,8 @@ class GitClientTests(TestCase):
             )

         f = BytesIO()
-        write_pack_data(f.write, *generate_pack_data(None, None))
+        count, records = generate_pack_data(None, None)
+        write_pack_data(f.write, records, num_records=count)
         self.client.send_pack(b"/", update_refs, generate_pack_data)
         self.assertEqual(
             self.rout.getvalue(),
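write_pack_data() no longer takes the (count, records) pair splatted positionally; the record count moves to the num_records keyword, presumably so the writer can emit the object-count field of the pack header before consuming the record iterator:

    f = BytesIO()
    count, records = generate_pack_data(None, None)
    write_pack_data(f.write, records, num_records=count)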
@@ -407,7 +408,7 @@ class GitClientTests(TestCase):
         def update_refs(refs):
             return {b"refs/heads/master": b"0" * 40}

-        def generate_pack_data(have, want, ofs_delta=False):
+        def generate_pack_data(have, want, ofs_delta=False, progress=None):
             return 0, []

         result = self.client.send_pack(b"/", update_refs, generate_pack_data)
@@ -861,7 +862,9 @@ class LocalGitClientTests(TestCase):

     def test_fetch_into_empty(self):
         c = LocalGitClient()
-        t = MemoryRepo()
+        target = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, target)
+        t = Repo.init_bare(target)
         s = open_repo("a.git")
         self.addCleanup(tear_down_repo, s)
         self.assertEqual(s.get_refs(), c.fetch(s.path, t).refs)

+ 0 - 38
dulwich/tests/test_greenthreads.py

@@ -28,7 +28,6 @@ from dulwich.tests import (
 )
 from dulwich.object_store import (
     MemoryObjectStore,
-    MissingObjectFinder,
 )
 from dulwich.objects import (
     Commit,
@@ -46,7 +45,6 @@

 if gevent_support:
     from dulwich.greenthreads import (
-        GreenThreadsObjectStoreIterator,
         GreenThreadsMissingObjectFinder,
     )

@@ -77,42 +75,6 @@ def init_store(store, count=1):
     return ret


-@skipIf(not gevent_support, skipmsg)
-class TestGreenThreadsObjectStoreIterator(TestCase):
-    def setUp(self):
-        super().setUp()
-        self.store = MemoryObjectStore()
-        self.cmt_amount = 10
-        self.objs = init_store(self.store, self.cmt_amount)
-
-    def test_len(self):
-        wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
-        finder = MissingObjectFinder(self.store, (), wants)
-        iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder
-        )
-        # One commit refers one tree and one blob
-        self.assertEqual(len(iterator), self.cmt_amount * 3)
-        haves = wants[0 : self.cmt_amount - 1]
-        finder = MissingObjectFinder(self.store, haves, wants)
-        iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder
-        )
-        self.assertEqual(len(iterator), 3)
-
-    def test_iter(self):
-        wants = [sha.id for sha in self.objs if isinstance(sha, Commit)]
-        finder = MissingObjectFinder(self.store, (), wants)
-        iterator = GreenThreadsObjectStoreIterator(
-            self.store, iter(finder.next, None), finder
-        )
-        objs = []
-        for sha, path in iterator:
-            self.assertIn(sha, self.objs)
-            objs.append(sha)
-        self.assertEqual(len(objs), len(self.objs))
-
-
 @skipIf(not gevent_support, skipmsg)
 class TestGreenThreadsMissingObjectFinder(TestCase):
     def setUp(self):

+ 2 - 1
dulwich/tests/test_object_store.py

@@ -223,6 +223,7 @@ class ObjectStoreTests:
             [TreeEntry(p, m, h) for (p, h, m) in blobs],
             list(iter_tree_contents(self.store, tree_id)),
         )
+        self.assertEqual([], list(iter_tree_contents(self.store, None)))

     def test_iter_tree_contents_include_trees(self):
         blob_a = make_object(Blob, data=b"a")
@@ -304,7 +305,7 @@ class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
     def test_add_pack_emtpy(self):
         o = MemoryObjectStore()
         f, commit, abort = o.add_pack()
-        commit()
+        self.assertRaises(AssertionError, commit)

     def test_add_thin_pack(self):
         o = MemoryObjectStore()

+ 52 - 32
dulwich/tests/test_pack.py

@@ -122,13 +122,13 @@ class PackTests(TestCase):
 class PackIndexTests(PackTests):
     """Class that tests the index of packfiles"""

-    def test_object_index(self):
+    def test_object_offset(self):
         """Tests that the correct object offset is returned from the index."""
         p = self.get_pack_index(pack1_sha)
-        self.assertRaises(KeyError, p.object_index, pack1_sha)
-        self.assertEqual(p.object_index(a_sha), 178)
-        self.assertEqual(p.object_index(tree_sha), 138)
-        self.assertEqual(p.object_index(commit_sha), 12)
+        self.assertRaises(KeyError, p.object_offset, pack1_sha)
+        self.assertEqual(p.object_offset(a_sha), 178)
+        self.assertEqual(p.object_offset(tree_sha), 138)
+        self.assertEqual(p.object_offset(commit_sha), 12)
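The rename from object_index to object_offset matches what the method returns: the byte offset of the object inside the .pack file, not a position in the index. Hedged usage, reusing a_sha from the test above (the .idx file name is illustrative):

    from dulwich.pack import load_pack_index

    idx = load_pack_index("pack-example.idx")  # hypothetical index file
    offset = idx.object_offset(a_sha)          # byte offset within the .pack
    # object_offset() raises KeyError for ids not present in this pack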

     def test_object_sha1(self):
         """Tests that the correct object offset is returned from the index."""
@@ -284,7 +284,7 @@ class TestPackData(PackTests):
         with self.get_pack_data(pack1_sha) as p:
             self.assertSucceeds(p.check)

-    def test_iterobjects(self):
+    def test_iter_unpacked(self):
         with self.get_pack_data(pack1_sha) as p:
             commit_data = (
                 b"tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n"
@@ -297,14 +297,12 @@ class TestPackData(PackTests):
             )
             blob_sha = b"6f670c0fb53f9463760b7295fbb814e965fb20c8"
             tree_data = b"100644 a\0" + hex_to_sha(blob_sha)
-            actual = []
-            for offset, type_num, chunks, crc32 in p.iterobjects():
-                actual.append((offset, type_num, b"".join(chunks), crc32))
+            actual = list(p.iter_unpacked())
             self.assertEqual(
                 [
-                    (12, 1, commit_data, 3775879613),
-                    (138, 2, tree_data, 912998690),
-                    (178, 3, b"test 1\n", 1373561701),
+                    UnpackedObject(offset=12, pack_type_num=1, decomp_chunks=[commit_data], crc32=None),
+                    UnpackedObject(offset=138, pack_type_num=2, decomp_chunks=[tree_data], crc32=None),
+                    UnpackedObject(offset=178, pack_type_num=3, decomp_chunks=[b"test 1\n"], crc32=None),
                 ],
                 actual,
             )
@@ -578,24 +576,28 @@ class TestThinPack(PackTests):
         with self.make_pack(True) as p:
             self.assertEqual((3, b"foo1234"), p.get_raw(self.blobs[b"foo1234"].id))

-    def test_get_raw_unresolved(self):
+    def test_get_unpacked_object(self):
+        self.maxDiff = None
         with self.make_pack(False) as p:
-            self.assertEqual(
-                (
-                    7,
-                    b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
-                    [b"x\x9ccf\x9f\xc0\xccbhdl\x02\x00\x06f\x01l"],
-                ),
-                p.get_raw_unresolved(self.blobs[b"foo1234"].id),
+            expected = UnpackedObject(
+                7,
+                delta_base=b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
+                decomp_chunks=[b'\x03\x07\x90\x03\x041234'],
             )
+            expected.offset = 12
+            got = p.get_unpacked_object(self.blobs[b"foo1234"].id)
+            self.assertEqual(expected, got)
         with self.make_pack(True) as p:
+            expected = UnpackedObject(
+                7,
+                delta_base=b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
+                decomp_chunks=[b'\x03\x07\x90\x03\x041234'],
+            )
+            expected.offset = 12
+            got = p.get_unpacked_object(self.blobs[b"foo1234"].id)
             self.assertEqual(
-                (
-                    7,
-                    b"\x19\x10(\x15f=#\xf8\xb7ZG\xe7\xa0\x19e\xdc\xdc\x96F\x8c",
-                    [b"x\x9ccf\x9f\xc0\xccbhdl\x02\x00\x06f\x01l"],
-                ),
-                p.get_raw_unresolved(self.blobs[b"foo1234"].id),
+                expected,
+                got,
             )

     def test_iterobjects(self):
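get_raw_unresolved() is replaced by get_unpacked_object(), which returns an UnpackedObject with named fields instead of a positional tuple. A sketch, assuming p is an open thin Pack as in this test and blob_id is an id stored in it:

    unpacked = p.get_unpacked_object(blob_id)
    if unpacked.delta_base is not None:
        # stored as a delta; decomp_chunks then holds the delta body
        print("delta against", unpacked.delta_base)
    print(unpacked.offset, b"".join(unpacked.decomp_chunks))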
@@ -801,7 +803,7 @@ class ReadZlibTests(TestCase):
     def setUp(self):
         super().setUp()
         self.read = BytesIO(self.comp + self.extra).read
-        self.unpacked = UnpackedObject(Tree.type_num, None, len(self.decomp), 0)
+        self.unpacked = UnpackedObject(Tree.type_num, decomp_len=len(self.decomp), crc32=0)

     def test_decompress_size(self):
         good_decomp_len = len(self.decomp)
@@ -820,7 +822,7 @@ class ReadZlibTests(TestCase):
         self.assertRaises(zlib.error, read_zlib_chunks, read, self.unpacked)

     def test_decompress_empty(self):
-        unpacked = UnpackedObject(Tree.type_num, None, 0, None)
+        unpacked = UnpackedObject(Tree.type_num, decomp_len=0)
         comp = zlib.compress(b"")
         read = BytesIO(comp + self.extra).read
         unused = read_zlib_chunks(read, unpacked)
@@ -872,7 +874,7 @@ class DeltifyTests(TestCase):
     def test_single(self):
         b = Blob.from_string(b"foo")
         self.assertEqual(
-            [(b.type_num, b.sha().digest(), None, b.as_raw_chunks())],
+            [UnpackedObject(b.type_num, sha=b.sha().digest(), delta_base=None, decomp_chunks=b.as_raw_chunks())],
             list(deltify_pack_objects([(b, b"")])),
         )

@@ -882,8 +884,8 @@ class DeltifyTests(TestCase):
         delta = list(create_delta(b1.as_raw_chunks(), b2.as_raw_chunks()))
         self.assertEqual(
             [
-                (b1.type_num, b1.sha().digest(), None, b1.as_raw_chunks()),
-                (b2.type_num, b2.sha().digest(), b1.sha().digest(), delta),
+                UnpackedObject(b1.type_num, sha=b1.sha().digest(), delta_base=None, decomp_chunks=b1.as_raw_chunks()),
+                UnpackedObject(b2.type_num, sha=b2.sha().digest(), delta_base=b1.sha().digest(), decomp_chunks=delta),
             ],
             list(deltify_pack_objects([(b1, b""), (b2, b"")])),
         )
@@ -943,7 +945,7 @@ class TestPackStreamReader(TestCase):

     def test_read_objects_empty(self):
         reader = PackStreamReader(BytesIO().read)
-        self.assertEqual([], list(reader.read_objects()))
+        self.assertRaises(AssertionError, list, reader.read_objects())


 class TestPackIterator(DeltaChainIterator):
@@ -1006,6 +1008,16 @@ class DeltaChainIteratorTests(TestCase):
         data = PackData("test.pack", file=f)
         return TestPackIterator.for_pack_data(data, resolve_ext_ref=resolve_ext_ref)

+    def make_pack_iter_subset(self, f, subset, thin=None):
+        if thin is None:
+            thin = bool(list(self.store))
+        resolve_ext_ref = thin and self.get_raw_no_repeat or None
+        data = PackData("test.pack", file=f)
+        assert data
+        index = MemoryPackIndex.for_pack(data)
+        pack = Pack.from_objects(data, index)
+        return TestPackIterator.for_pack_subset(pack, subset, resolve_ext_ref=resolve_ext_ref)
+
     def assertEntriesMatch(self, expected_indexes, entries, pack_iter):
         expected = [entries[i] for i in expected_indexes]
         self.assertEqual(expected, list(pack_iter._walk_all_chains()))
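for_pack_subset() walks only the requested shas (plus whatever delta bases they need), and serving it requires a full Pack, so the helper above wraps the raw PackData with an in-memory index. The same recipe outside the test, under the names used above (wanted_shas is a hypothetical iterable of hex ids):

    data = PackData("test.pack", file=f)
    index = MemoryPackIndex.for_pack(data)
    pack = Pack.from_objects(data, index)
    iterator = TestPackIterator.for_pack_subset(pack, wanted_shas)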
@@ -1021,6 +1033,10 @@ class DeltaChainIteratorTests(TestCase):
             ],
         )
         self.assertEntriesMatch([0, 1, 2], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch([], entries, self.make_pack_iter_subset(f, []))
+        f.seek(0)
+        self.assertEntriesMatch([1, 0], entries, self.make_pack_iter_subset(f, [entries[0][3], entries[1][3]]))

     def test_ofs_deltas(self):
         f = BytesIO()
@@ -1034,6 +1050,10 @@ class DeltaChainIteratorTests(TestCase):
         )
         # Delta resolution changed to DFS
         self.assertEntriesMatch([0, 2, 1], entries, self.make_pack_iter(f))
+        f.seek(0)
+        self.assertEntriesMatch(
+            [0, 2, 1], entries,
+            self.make_pack_iter_subset(f, [entries[1][3], entries[2][3]]))

     def test_ofs_deltas_chain(self):
         f = BytesIO()

+ 1 - 1
dulwich/tests/test_repository.py

@@ -539,7 +539,7 @@ class RepositoryRootTests(TestCase):
         This test demonstrates that ``find_common_revisions()`` actually
         returns common heads, not revisions; dulwich already uses
         ``find_common_revisions()`` in such a manner (see
-        ``Repo.fetch_objects()``).
+        ``Repo.find_missing_objects()``).
         """

         expected_shas = {b"60dacdc733de308bb77bb76ce0fb0f9b44c9769e"}

+ 13 - 5
dulwich/tests/test_server.py

@@ -165,7 +165,10 @@ class HandlerTestCase(TestCase):
 class UploadPackHandlerTestCase(TestCase):
     def setUp(self):
         super().setUp()
-        self._repo = MemoryRepo.init_bare([], {})
+        self.path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.path)
+        self.repo = Repo.init(self.path)
+        self._repo = Repo.init_bare(self.path)
         backend = DictBackend({b"/": self._repo})
         self._handler = UploadPackHandler(
             backend, [b"/", b"host=lolcathost"], TestProto()
@@ -174,6 +177,7 @@ class UploadPackHandlerTestCase(TestCase):
     def test_progress(self):
         caps = self._handler.required_capabilities()
         self._handler.set_client_capabilities(caps)
+        self._handler._start_pack_send_phase()
         self._handler.progress(b"first message")
         self._handler.progress(b"second message")
         self.assertEqual(b"first message", self._handler.proto.get_received_line(2))
@@ -195,12 +199,14 @@ class UploadPackHandlerTestCase(TestCase):
         }
         # repo needs to peel this object
         self._repo.object_store.add_object(make_commit(id=FOUR))
-        self._repo.refs._update(refs)
+        for name, sha in refs.items():
+            self._repo.refs[name] = sha
         peeled = {
             b"refs/tags/tag1": b"1234" * 10,
             b"refs/tags/tag2": b"5678" * 10,
         }
-        self._repo.refs._update_peeled(peeled)
+        self._repo.refs._peeled_refs = peeled
+        self._repo.refs.add_packed_refs(refs)

         caps = list(self._handler.required_capabilities()) + [b"include-tag"]
         self._handler.set_client_capabilities(caps)
@@ -221,7 +227,8 @@ class UploadPackHandlerTestCase(TestCase):
         tree = Tree()
         self._repo.object_store.add_object(tree)
         self._repo.object_store.add_object(make_commit(id=ONE, tree=tree))
-        self._repo.refs._update(refs)
+        for name, sha in refs.items():
+            self._repo.refs[name] = sha
         self._handler.proto.set_output(
             [
                 b"want " + ONE + b" side-band-64k thin-pack ofs-delta",
@@ -241,7 +248,8 @@ class UploadPackHandlerTestCase(TestCase):
         tree = Tree()
         self._repo.object_store.add_object(tree)
         self._repo.object_store.add_object(make_commit(id=ONE, tree=tree))
-        self._repo.refs._update(refs)
+        for ref, sha in refs.items():
+            self._repo.refs[ref] = sha
         self._handler.proto.set_output([None])
         self._handler.handle()
         # The server should not send a pack, since the client didn't ask for

+ 2 - 2
dulwich/walk.py

@@ -24,7 +24,7 @@
 import collections
 import heapq
 from itertools import chain
-from typing import List, Tuple, Set, Deque, Literal, Optional
+from typing import List, Tuple, Set, Deque, Optional

 from dulwich.diff_tree import (
     RENAME_CHANGE_TYPES,
@@ -244,7 +244,7 @@ class Walker:
         store,
         include: List[bytes],
         exclude: Optional[List[bytes]] = None,
-        order: Literal["date", "topo"] = 'date',
+        order: str = 'date',
         reverse: bool = False,
         max_entries: Optional[int] = None,
         paths: Optional[List[bytes]] = None,

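Dropping typing.Literal keeps walk.py importable on Python 3.7, where Literal is not yet in the standard library; the setup.cfg change below pulls in typing_extensions for those interpreters instead. The guarded-import alternative would look like this:

    try:
        from typing import Literal  # Python >= 3.8
    except ImportError:
        from typing_extensions import Literal  # type: ignore

    Order = Literal["date", "topo"]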
+ 1 - 0
setup.cfg

@@ -52,6 +52,7 @@ packages =
 include_package_data = True
 install_requires =
     urllib3>=1.25
+    typing_extensions;python_version<="3.7"
 zip_safe = False
 scripts =
     bin/dul-receive-pack

Some files were not shown because too many files changed in this diff