
Merge pull request #1113 from jelmer/pack-refactor

Various bits of pack refactoring; add more typing
Jelmer Vernooij 2 years ago
parent
commit
f40a609689

+ 2 - 2
dulwich/client.py

@@ -1515,7 +1515,7 @@ class LocalGitClient(GitClient):
         pack_data,
         progress=None,
         depth=None,
-    ):
+    ) -> FetchPackResult:
         """Retrieve a pack from a git smart server.

         Args:
@@ -1543,7 +1543,7 @@ class LocalGitClient(GitClient):
             # Note that the client still expects a 0-object pack in most cases.
             if objects_iter is None:
                 return FetchPackResult(None, symrefs, agent)
-            write_pack_objects(pack_data, objects_iter)
+            write_pack_objects(pack_data, objects_iter, reuse_pack=r.object_store)
             return FetchPackResult(r.get_refs(), symrefs, agent)

     def get_refs(self, path):

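Since fetch_pack only gained a return-type annotation and now passes the local object store as reuse_pack, existing callers are unaffected. A minimal sketch of exercising it through the higher-level fetch() wrapper (the "source" path is a hypothetical existing repository):

    from dulwich.client import LocalGitClient
    from dulwich.repo import Repo

    target = Repo.init("target", mkdir=True)
    client = LocalGitClient()
    # fetch() drives fetch_pack() internally and returns the
    # FetchPackResult that the method is now annotated with
    result = client.fetch("source", target)
    print(result.refs)
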
+ 4 - 1
dulwich/fastexport.py

@@ -30,6 +30,9 @@ from dulwich.objects import (
     Tag,
     ZERO_SHA,
 )
+from dulwich.object_store import (
+    iter_tree_contents,
+)
 from fastimport import (
     commands,
     errors as fastimport_errors,
@@ -232,7 +235,7 @@ class GitImportProcessor(processor.ImportProcessor):
                 path,
                 mode,
                 hexsha,
-            ) in self.repo.object_store.iter_tree_contents(tree_id):
+            ) in iter_tree_contents(self.repo.object_store, tree_id):
                 self._contents[path] = (mode, hexsha)

     def reset_handler(self, cmd):

+ 7 - 6
dulwich/greenthreads.py

@@ -31,12 +31,13 @@ from dulwich.objects import (
 )
 from dulwich.object_store import (
     MissingObjectFinder,
+    _collect_ancestors,
     _collect_filetree_revs,
     ObjectStoreIterator,
 )


-def _split_commits_and_tags(obj_store, lst, ignore_unknown=False, pool=None):
+def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
     """Split object id list into two list with commit SHA1s and tag SHA1s.

     Same implementation as object_store._split_commits_and_tags
@@ -90,11 +91,11 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
         self.object_store = object_store
         p = pool.Pool(size=concurrency)

-        have_commits, have_tags = _split_commits_and_tags(object_store, haves, True, p)
-        want_commits, want_tags = _split_commits_and_tags(object_store, wants, False, p)
-        all_ancestors = object_store._collect_ancestors(have_commits)[0]
-        missing_commits, common_commits = object_store._collect_ancestors(
-            want_commits, all_ancestors
+        have_commits, have_tags = _split_commits_and_tags(object_store, haves, ignore_unknown=True, pool=p)
+        want_commits, want_tags = _split_commits_and_tags(object_store, wants, ignore_unknown=False, pool=p)
+        all_ancestors = _collect_ancestors(object_store, have_commits)[0]
+        missing_commits, common_commits = _collect_ancestors(
+            object_store, want_commits, all_ancestors
         )

         self.sha_done = set()

+ 11 - 12
dulwich/index.py

@@ -32,16 +32,12 @@ from typing import (
     Dict,
     List,
     Optional,
-    TYPE_CHECKING,
     Iterable,
     Iterator,
     Tuple,
     Union,
 )

-if TYPE_CHECKING:
-    from dulwich.object_store import BaseObjectStore
-
 from dulwich.file import GitFile
 from dulwich.objects import (
     Blob,
@@ -52,9 +48,11 @@ from dulwich.objects import (
     sha_to_hex,
     ObjectID,
 )
+from dulwich.object_store import iter_tree_contents
 from dulwich.pack import (
     SHA1Reader,
     SHA1Writer,
+    ObjectContainer,
 )


@@ -451,7 +449,7 @@ class Index:


 def commit_tree(
-    object_store: "BaseObjectStore", blobs: Iterable[Tuple[bytes, bytes, int]]
+    object_store: ObjectContainer, blobs: Iterable[Tuple[bytes, bytes, int]]
 ) -> bytes:
     """Commit a new tree.

@@ -494,7 +492,7 @@ def commit_tree(
     return build_tree(b"")


-def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
+def commit_index(object_store: ObjectContainer, index: Index) -> bytes:
     """Create a new tree from an index.

     Args:
@@ -509,7 +507,7 @@ def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
 def changes_from_tree(
     names: Iterable[bytes],
     lookup_entry: Callable[[bytes], Tuple[bytes, int]],
-    object_store: "BaseObjectStore",
+    object_store: ObjectContainer,
     tree: Optional[bytes],
     want_unchanged=False,
 ) -> Iterable[
@@ -535,7 +533,7 @@ def changes_from_tree(
     other_names = set(names)

     if tree is not None:
-        for (name, mode, sha) in object_store.iter_tree_contents(tree):
+        for (name, mode, sha) in iter_tree_contents(object_store, tree):
             try:
                 (other_sha, other_mode) = lookup_entry(name)
             except KeyError:
@@ -686,7 +684,7 @@ def validate_path(path: bytes,
 def build_index_from_tree(
     root_path: Union[str, bytes],
     index_path: Union[str, bytes],
-    object_store: "BaseObjectStore",
+    object_store: ObjectContainer,
     tree_id: bytes,
     honor_filemode: bool = True,
     validate_path_element=validate_path_element_default,
@@ -711,7 +709,7 @@ def build_index_from_tree(
     if not isinstance(root_path, bytes):
         root_path = os.fsencode(root_path)

-    for entry in object_store.iter_tree_contents(tree_id):
+    for entry in iter_tree_contents(object_store, tree_id):
         if not validate_path(entry.path, validate_path_element):
             continue
         full_path = _tree_to_fs_path(root_path, entry.path)
@@ -727,6 +725,7 @@ def build_index_from_tree(
             # TODO(jelmer): record and return submodule paths
         else:
             obj = object_store[entry.sha]
+            assert isinstance(obj, Blob)
             st = build_file_from_blob(
                 obj, entry.mode, full_path,
                 honor_filemode=honor_filemode,
@@ -927,7 +926,7 @@ def index_entry_from_directory(st, path: bytes) -> Optional[IndexEntry]:


 def index_entry_from_path(
-        path: bytes, object_store: Optional["BaseObjectStore"] = None
+        path: bytes, object_store: Optional[ObjectContainer] = None
 ) -> Optional[IndexEntry]:
     """Create an index from a filesystem path.

@@ -957,7 +956,7 @@ def index_entry_from_path(

 def iter_fresh_entries(
     paths: Iterable[bytes], root_path: bytes,
-    object_store: Optional["BaseObjectStore"] = None
+    object_store: Optional[ObjectContainer] = None
 ) -> Iterator[Tuple[bytes, Optional[IndexEntry]]]:
     """Iterate over current versions of index entries on disk.


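The annotations above swap the concrete "BaseObjectStore" for the structural ObjectContainer protocol from dulwich.pack, so anything implementing add_object/__getitem__/__contains__ qualifies. A minimal sketch of commit_tree against an in-memory store (file name and contents are made up):

    from dulwich.index import commit_tree
    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello\n")
    store.add_object(blob)
    # blobs is an iterable of (path, hexsha, mode) triples
    tree_id = commit_tree(store, [(b"hello.txt", blob.id, 0o100644)])
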
+ 2 - 1
dulwich/line_ending.py

@@ -136,6 +136,7 @@ Sources:
 - https://adaptivepatchwork.com/2012/03/01/mind-the-end-of-your-line/
 """

+from dulwich.object_store import iter_tree_contents
 from dulwich.objects import Blob
 from dulwich.patch import is_binary

@@ -290,7 +291,7 @@ class TreeBlobNormalizer(BlobNormalizer):
         if tree:
             self.existing_paths = {
                 name
-                for name, _, _ in object_store.iter_tree_contents(tree)
+                for name, _, _ in iter_tree_contents(object_store, tree)
             }
         else:
             self.existing_paths = set()

+ 160 - 64
dulwich/object_store.py

@@ -26,13 +26,10 @@ from io import BytesIO
 import os
 import stat
 import sys
+import warnings

-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Protocol, Union, Iterator, Set

-from dulwich.diff_tree import (
-    tree_changes,
-    walk_trees,
-)
 from dulwich.errors import (
     NotTreeError,
 )
@@ -48,10 +45,12 @@ from dulwich.objects import (
     sha_to_hex,
     hex_to_filename,
     S_ISGITLINK,
+    TreeEntry,
     object_class,
     valid_hexsha,
 )
 from dulwich.pack import (
+    ObjectContainer,
     Pack,
     PackData,
     PackInflater,
@@ -79,6 +78,14 @@ PACKDIR = "pack"
 PACK_MODE = 0o444 if sys.platform != "win32" else 0o644


+class PackContainer(Protocol):
+
+    def add_pack(
+        self
+    ) -> Tuple[BytesIO, Callable[[], None], Callable[[], None]]:
+        """Add a new pack."""
+
+
 class BaseObjectStore:
     """Object store interface."""

@@ -213,6 +220,8 @@ class BaseObjectStore:
         Returns: Iterator over tuples with
             (oldpath, newpath), (oldmode, newmode), (oldsha, newsha)
         """
+
+        from dulwich.diff_tree import tree_changes
         for change in tree_changes(
             self,
             source,
@@ -239,11 +248,10 @@ class BaseObjectStore:
         Returns: Iterator over TreeEntry namedtuples for all the objects in a
             tree.
         """
-        for entry, _ in walk_trees(self, tree_id, None):
-            if (
-                entry.mode is not None and not stat.S_ISDIR(entry.mode)
-            ) or include_trees:
-                yield entry
+        warnings.warn(
+            "Please use dulwich.object_store.iter_tree_contents",
+            DeprecationWarning, stacklevel=2)
+        return iter_tree_contents(self, tree_id, include_trees=include_trees)

     def find_missing_objects(
         self,
@@ -334,47 +342,10 @@ class BaseObjectStore:
             intermediate tags; if the original ref does not point to a tag,
             this will equal the original SHA1.
         """
-        obj = self[sha]
-        obj_class = object_class(obj.type_name)
-        while obj_class is Tag:
-            obj_class, sha = obj.object
-            obj = self[sha]
-        return obj
-
-    def _collect_ancestors(
-        self,
-        heads,
-        common=frozenset(),
-        shallow=frozenset(),
-        get_parents=lambda commit: commit.parents,
-    ):
-        """Collect all ancestors of heads up to (excluding) those in common.
-
-        Args:
-          heads: commits to start from
-          common: commits to end at, or empty set to walk repository
-            completely
-          get_parents: Optional function for getting the parents of a
-            commit.
-        Returns: a tuple (A, B) where A - all commits reachable
-            from heads but not present in common, B - common (shared) elements
-            that are directly reachable from heads
-        """
-        bases = set()
-        commits = set()
-        queue = []
-        queue.extend(heads)
-        while queue:
-            e = queue.pop(0)
-            if e in common:
-                bases.add(e)
-            elif e not in commits:
-                commits.add(e)
-                if e in shallow:
-                    continue
-                cmt = self[e]
-                queue.extend(get_parents(cmt))
-        return (commits, bases)
+        warnings.warn(
+            "Please use dulwich.object_store.peel_sha()",
+            DeprecationWarning, stacklevel=2)
+        return peel_sha(self, sha)

     def _get_depth(
         self, head, get_parents=lambda commit: commit.parents, max_depth=None,
@@ -592,6 +563,46 @@ class PackBasedObjectStore(BaseObjectStore):
                 pass
         raise KeyError(hexsha)

+    def get_raw_unresolved(self, name: bytes) -> Tuple[int, Union[bytes, None], List[bytes]]:
+        """Obtain the unresolved data for an object.
+
+        Args:
+          name: sha for the object.
+        """
+        if name == ZERO_SHA:
+            raise KeyError(name)
+        if len(name) == 40:
+            sha = hex_to_sha(name)
+            hexsha = name
+        elif len(name) == 20:
+            sha = name
+            hexsha = None
+        else:
+            raise AssertionError("Invalid object name {!r}".format(name))
+        for pack in self._iter_cached_packs():
+            try:
+                return pack.get_raw_unresolved(sha)
+            except (KeyError, PackFileDisappeared):
+                pass
+        if hexsha is None:
+            hexsha = sha_to_hex(name)
+        ret = self._get_loose_object(hexsha)
+        if ret is not None:
+            return ret.type_num, None, ret.as_raw_chunks()
+        # Maybe something else has added a pack with the object
+        # in the mean time?
+        for pack in self._update_pack_cache():
+            try:
+                return pack.get_raw_unresolved(sha)
+            except KeyError:
+                pass
+        for alternate in self.alternates:
+            try:
+                return alternate.get_raw_unresolved(hexsha)
+            except KeyError:
+                pass
+        raise KeyError(hexsha)
+
     def add_objects(self, objects, progress=None):
         """Add a set of objects to this object store.

@@ -1083,10 +1094,10 @@ class MemoryObjectStore(BaseObjectStore):
             commit()


-class ObjectIterator:
+class ObjectIterator(Protocol):
     """Interface for iterating over objects."""

-    def iterobjects(self):
+    def iterobjects(self) -> Iterator[ShaFile]:
         raise NotImplementedError(self.iterobjects)


@@ -1178,7 +1189,7 @@ def tree_lookup_path(lookup_obj, root_sha, path):
     return tree.lookup_path(lookup_obj, path)


-def _collect_filetree_revs(obj_store, tree_sha, kset):
+def _collect_filetree_revs(obj_store: ObjectContainer, tree_sha: ObjectID, kset: Set[ObjectID]) -> None:
     """Collect SHA1s of files and directories for specified tree.

     Args:
@@ -1187,6 +1198,7 @@ def _collect_filetree_revs(obj_store, tree_sha, kset):
       kset: set to fill with references to files and directories
     """
     filetree = obj_store[tree_sha]
+    assert isinstance(filetree, Tree)
     for name, mode, sha in filetree.iteritems():
         if not S_ISGITLINK(mode) and sha not in kset:
             kset.add(sha)
@@ -1194,7 +1206,7 @@ def _collect_filetree_revs(obj_store, tree_sha, kset):
                 _collect_filetree_revs(obj_store, sha, kset)


-def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
+def _split_commits_and_tags(obj_store: ObjectContainer, lst, *, ignore_unknown=False) -> Tuple[Set[bytes], Set[bytes], Set[bytes]]:
     """Split object id list into three lists with commit, tag, and other SHAs.

     Commits referenced by tags are included into commits
@@ -1209,9 +1221,9 @@ def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
         silently.
     Returns: A tuple of (commits, tags, others) SHA1s
     """
-    commits = set()
-    tags = set()
-    others = set()
+    commits: Set[bytes] = set()
+    tags: Set[bytes] = set()
+    others: Set[bytes] = set()
     for e in lst:
         try:
             o = obj_store[e]
@@ -1224,12 +1236,12 @@ def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
             elif isinstance(o, Tag):
                 tags.add(e)
                 tagged = o.object[1]
-                c, t, o = _split_commits_and_tags(
+                c, t, os = _split_commits_and_tags(
                     obj_store, [tagged], ignore_unknown=ignore_unknown
                 )
                 commits |= c
                 tags |= t
-                others |= o
+                others |= os
             else:
                 others.add(e)
     return (commits, tags, others)
@@ -1270,20 +1282,22 @@ class MissingObjectFinder:
         # wants shall list only known SHAs, and otherwise
         # _split_commits_and_tags fails with KeyError
         have_commits, have_tags, have_others = _split_commits_and_tags(
-            object_store, haves, True
+            object_store, haves, ignore_unknown=True
         )
         want_commits, want_tags, want_others = _split_commits_and_tags(
-            object_store, wants, False
+            object_store, wants, ignore_unknown=False
         )
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
-        all_ancestors = object_store._collect_ancestors(
+        all_ancestors = _collect_ancestors(
+            object_store,
             have_commits, shallow=shallow, get_parents=self._get_parents
         )[0]
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
-        missing_commits, common_commits = object_store._collect_ancestors(
+        missing_commits, common_commits = _collect_ancestors(
+            object_store,
             want_commits,
             all_ancestors,
             shallow=shallow,
@@ -1606,3 +1620,85 @@ class BucketBasedObjectStore(PackBasedObjectStore):
             return final_pack

         return pf, commit, pf.close
+
+
+def _collect_ancestors(
+    store: ObjectContainer,
+    heads,
+    common=frozenset(),
+    shallow=frozenset(),
+    get_parents=lambda commit: commit.parents,
+):
+    """Collect all ancestors of heads up to (excluding) those in common.
+
+    Args:
+      heads: commits to start from
+      common: commits to end at, or empty set to walk repository
+        completely
+      get_parents: Optional function for getting the parents of a
+        commit.
+    Returns: a tuple (A, B) where A - all commits reachable
+        from heads but not present in common, B - common (shared) elements
+        that are directly reachable from heads
+    """
+    bases = set()
+    commits = set()
+    queue = []
+    queue.extend(heads)
+    while queue:
+        e = queue.pop(0)
+        if e in common:
+            bases.add(e)
+        elif e not in commits:
+            commits.add(e)
+            if e in shallow:
+                continue
+            cmt = store[e]
+            queue.extend(get_parents(cmt))
+    return (commits, bases)
+
+
+def iter_tree_contents(
+        store: ObjectContainer, tree_id: bytes, *, include_trees: bool = False):
+    """Iterate the contents of a tree and all subtrees.
+
+    Iteration is depth-first pre-order, as in e.g. os.walk.
+
+    Args:
+      tree_id: SHA1 of the tree.
+      include_trees: If True, include tree objects in the iteration.
+    Returns: Iterator over TreeEntry namedtuples for all the objects in a
+        tree.
+    """
+    # This could be fairly easily generalized to >2 trees if we find a use
+    # case.
+    todo = [TreeEntry(b"", stat.S_IFDIR, tree_id)]
+    while todo:
+        entry = todo.pop()
+        if stat.S_ISDIR(entry.mode):
+            extra = []
+            tree = store[entry.sha]
+            assert isinstance(tree, Tree)
+            for subentry in tree.iteritems(name_order=True):
+                extra.append(subentry.in_path(entry.path))
+            todo.extend(reversed(extra))
+        if not stat.S_ISDIR(entry.mode) or include_trees:
+            yield entry
+
+
+def peel_sha(store: ObjectContainer, sha: bytes) -> ShaFile:
+    """Peel all tags from a SHA.
+
+    Args:
+      sha: The object SHA to peel.
+    Returns: The fully-peeled SHA1 of a tag object, after peeling all
+        intermediate tags; if the original ref does not point to a tag,
+        this will equal the original SHA1.
+    """
+    obj = store[sha]
+    obj_class = object_class(obj.type_name)
+    while obj_class is Tag:
+        assert isinstance(obj, Tag)
+        obj_class, sha = obj.object
+        obj = store[sha]
+    return obj

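peel_sha, iter_tree_contents and _collect_ancestors are now module-level functions taking the store as their first argument; the old methods survive only as DeprecationWarning wrappers. A minimal sketch of the new calling convention (names and contents are made up):

    from dulwich.index import commit_tree
    from dulwich.object_store import (
        MemoryObjectStore, iter_tree_contents, peel_sha)
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    blob = Blob.from_string(b"data\n")
    store.add_object(blob)
    tree_id = commit_tree(store, [(b"a.txt", blob.id, 0o100644)])
    for entry in iter_tree_contents(store, tree_id):
        print(entry.path, entry.mode, entry.sha)
    # store.iter_tree_contents(tree_id) still works, but now warns
    assert peel_sha(store, blob.id) == blob  # nothing to peel for a non-tag
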
+ 89 - 45
dulwich/pack.py

@@ -49,7 +49,7 @@ from itertools import chain

 import os
 import sys
-from typing import Optional, Callable, Tuple, List, Deque, Union
+from typing import Optional, Callable, Tuple, List, Deque, Union, Protocol, Iterable, Iterator
 import warnings

 from hashlib import sha1
@@ -96,6 +96,34 @@ DELTA_TYPES = (OFS_DELTA, REF_DELTA)
 DEFAULT_PACK_DELTA_WINDOW_SIZE = 10


+class ObjectContainer(Protocol):
+
+    def add_object(self, obj: ShaFile) -> None:
+        """Add a single object to this object store."""
+
+    def add_objects(
+            self, objects: Iterable[Tuple[ShaFile, Optional[str]]],
+            progress: Optional[Callable[[str], None]] = None) -> None:
+        """Add a set of objects to this object store.
+
+        Args:
+          objects: Iterable over a list of (object, path) tuples
+        """
+
+    def __contains__(self, sha1: bytes) -> bool:
+        """Check if a hex sha is present."""
+
+    def __getitem__(self, sha1: bytes) -> ShaFile:
+        """Retrieve an object."""
+
+
+class PackedObjectContainer(ObjectContainer):
+
+    def get_raw_unresolved(self, sha1: bytes) -> Tuple[int, Union[bytes, None], List[bytes]]:
+        """Get a raw unresolved object."""
+        raise NotImplementedError(self.get_raw_unresolved)
+
+
 def take_msb_bytes(read: Callable[[int], bytes], crc32: Optional[int] = None) -> Tuple[List[int], Optional[int]]:
     """Read bytes marked with most significant bit.

@@ -513,7 +541,7 @@ class FilePackIndex(PackIndex):
             self._contents, self._size = (contents, size)

     @property
-    def path(self):
+    def path(self) -> str:
         return self._filename

     def __eq__(self, other):
@@ -526,16 +554,16 @@ class FilePackIndex(PackIndex):

         return super().__eq__(other)

-    def close(self):
+    def close(self) -> None:
         self._file.close()
         if getattr(self._contents, "close", None) is not None:
             self._contents.close()

-    def __len__(self):
+    def __len__(self) -> int:
         """Return the number of entries in this pack index."""
         return self._fan_out_table[-1]

-    def _unpack_entry(self, i):
+    def _unpack_entry(self, i: int) -> Tuple[bytes, int, Optional[int]]:
         """Unpack the i-th entry in the index file.

         Returns: Tuple with object name (SHA), offset in pack file and CRC32
@@ -555,11 +583,11 @@ class FilePackIndex(PackIndex):
         """Unpack the crc32 checksum for the ith object from the index file."""
         """Unpack the crc32 checksum for the ith object from the index file."""
         raise NotImplementedError(self._unpack_crc32_checksum)
         raise NotImplementedError(self._unpack_crc32_checksum)
 
 
-    def _itersha(self):
+    def _itersha(self) -> Iterator[bytes]:
         for i in range(len(self)):
         for i in range(len(self)):
             yield self._unpack_name(i)
             yield self._unpack_name(i)
 
 
-    def iterentries(self):
+    def iterentries(self) -> Iterator[Tuple[bytes, int, Optional[int]]]:
         """Iterate over the entries in this pack index.
         """Iterate over the entries in this pack index.
 
 
         Returns: iterator over tuples with object name, offset in packfile and
         Returns: iterator over tuples with object name, offset in packfile and
@@ -568,7 +596,7 @@ class FilePackIndex(PackIndex):
         for i in range(len(self)):
             yield self._unpack_entry(i)

-    def _read_fan_out_table(self, start_offset):
+    def _read_fan_out_table(self, start_offset: int):
         ret = []
         for i in range(0x100):
             fanout_entry = self._contents[
@@ -577,35 +605,35 @@ class FilePackIndex(PackIndex):
             ret.append(struct.unpack(">L", fanout_entry)[0])
         return ret

-    def check(self):
+    def check(self) -> None:
         """Check that the stored checksum matches the actual checksum."""
         actual = self.calculate_checksum()
         stored = self.get_stored_checksum()
         if actual != stored:
             raise ChecksumMismatch(stored, actual)

-    def calculate_checksum(self):
+    def calculate_checksum(self) -> bytes:
         """Calculate the SHA1 checksum over this pack index.

         Returns: This is a 20-byte binary digest
         """
         return sha1(self._contents[:-20]).digest()

-    def get_pack_checksum(self):
+    def get_pack_checksum(self) -> bytes:
         """Return the SHA1 checksum stored for the corresponding packfile.

         Returns: 20-byte binary digest
         """
         return bytes(self._contents[-40:-20])

-    def get_stored_checksum(self):
+    def get_stored_checksum(self) -> bytes:
         """Return the SHA1 checksum stored for this index.

         Returns: 20-byte binary digest
         """
         return bytes(self._contents[-20:])

-    def object_index(self, sha):
+    def object_index(self, sha: bytes) -> int:
         """Return the index in to the corresponding packfile for the object.

         Given the name of an object it will return the offset that object
@@ -644,7 +672,7 @@ class FilePackIndex(PackIndex):
 class PackIndex1(FilePackIndex):
     """Version 1 Pack Index file."""

-    def __init__(self, filename, file=None, contents=None, size=None):
+    def __init__(self, filename: str, file=None, contents=None, size=None):
         super().__init__(filename, file, contents, size)
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
@@ -669,7 +697,7 @@ class PackIndex1(FilePackIndex):
 class PackIndex2(FilePackIndex):
     """Version 2 Pack Index file."""

-    def __init__(self, filename, file=None, contents=None, size=None):
+    def __init__(self, filename: str, file=None, contents=None, size=None):
         super().__init__(filename, file, contents, size)
         if self._contents[:4] != b"\377tOc":
             raise AssertionError("Not a v2 pack index file")
@@ -707,7 +735,7 @@ class PackIndex2(FilePackIndex):
         return unpack_from(">L", self._contents, self._crc32_table_offset + i * 4)[0]


-def read_pack_header(read):
+def read_pack_header(read) -> Tuple[Optional[int], Optional[int]]:
     """Read the header of a pack file.

     Args:
@@ -727,7 +755,7 @@ def read_pack_header(read):
     return (version, num_objects)


-def chunks_length(chunks):
+def chunks_length(chunks: Union[bytes, Iterable[bytes]]) -> int:
     if isinstance(chunks, bytes):
         return len(chunks)
     else:
@@ -740,7 +768,7 @@ def unpack_object(
     compute_crc32=False,
     include_comp=False,
     zlib_bufsize=_ZLIB_BUFSIZE,
-):
+) -> Tuple[UnpackedObject, bytes]:
     """Unpack a Git object.

     Args:
@@ -1596,12 +1624,13 @@ def write_pack_object(write, type, object, sha=None, compression_level=-1):


 def write_pack(
-    filename,
-    objects,
-    deltify=None,
-    delta_window_size=None,
-    compression_level=-1,
-):
+        filename,
+        objects,
+        *,
+        deltify: Optional[bool] = None,
+        delta_window_size: Optional[int] = None,
+        compression_level: int = -1,
+        reuse_pack: Optional[PackedObjectContainer] = None):
     """Write a new pack data file.

     Args:
@@ -1619,6 +1648,7 @@ def write_pack(
             delta_window_size=delta_window_size,
             deltify=deltify,
             compression_level=compression_level,
+            reuse_pack=reuse_pack,
         )
     entries = sorted([(k, v[0], v[1]) for (k, v) in entries.items()])
     with GitFile(filename + ".idx", "wb") as f:
@@ -1643,7 +1673,10 @@ def write_pack_header(write, num_objects):
         write(chunk)


-def deltify_pack_objects(objects, window_size: Optional[int] = None, reuse_pack=None):
+def deltify_pack_objects(
+        objects: Iterable[Tuple[ShaFile, str]],
+        window_size: Optional[int] = None,
+        reuse_pack: Optional[PackedObjectContainer] = None):
     """Generate deltas for pack objects.

     Args:
@@ -1685,7 +1718,7 @@ def deltify_pack_objects(objects, window_size: Optional[int] = None, reuse_pack=
         magic.append((obj.type_num, path, -obj.raw_length(), obj))
     magic.sort()

-    possible_bases: Deque[Tuple[bytes, int, bytes]] = deque()
+    possible_bases: Deque[Tuple[bytes, int, List[bytes]]] = deque()

     for type_num, path, neg_length, o in magic:
         raw = o.as_raw_chunks()
@@ -1712,7 +1745,11 @@ def deltify_pack_objects(objects, window_size: Optional[int] = None, reuse_pack=
             possible_bases.pop()


-def pack_objects_to_data(objects):
+def pack_objects_to_data(
+        objects,
+        delta_window_size: Optional[int] = None,
+        deltify: Optional[bool] = None,
+        reuse_pack: Optional[PackedObjectContainer] = None):
     """Create pack data from objects

     Args:
@@ -1720,17 +1757,30 @@ def pack_objects_to_data(objects):
     Returns: Tuples with (type_num, hexdigest, delta base, object chunks)
     """
     count = len(objects)
-    return (
-        count,
-        (
-            (o.type_num, o.sha().digest(), None, o.as_raw_chunks())
-            for (o, path) in objects
-        ),
-    )
+    if deltify is None:
+        # PERFORMANCE/TODO(jelmer): This should be enabled but is *much* too
+        # slow at the moment.
+        deltify = False
+    if deltify:
+        pack_contents = deltify_pack_objects(
+            objects, window_size=delta_window_size, reuse_pack=reuse_pack)
+        return (count, pack_contents)
+    else:
+        return (
+            count,
+            (
+                (o.type_num, o.sha().digest(), None, o.as_raw_chunks())
+                for (o, path) in objects
+            ),
+        )


 def write_pack_objects(
-    write, objects, delta_window_size=None, deltify=None, reuse_pack=None, compression_level=-1
+        write, objects,
+        delta_window_size: Optional[int] = None,
+        deltify: Optional[bool] = None,
+        reuse_pack: Optional[PackedObjectContainer] = None,
+        compression_level: int = -1
 ):
     """Write a new pack data file.

@@ -1751,16 +1801,10 @@ def write_pack_objects(
             DeprecationWarning, stacklevel=2)
         write = write.write

-    if deltify is None:
-        # PERFORMANCE/TODO(jelmer): This should be enabled but is *much* too
-        # slow at the moment.
-        deltify = False
-    if deltify:
-        pack_contents = deltify_pack_objects(
-            objects, window_size=delta_window_size, reuse_pack=reuse_pack)
-        pack_contents_count = len(objects)
-    else:
-        pack_contents_count, pack_contents = pack_objects_to_data(objects)
+    pack_contents_count, pack_contents = pack_objects_to_data(
+        objects, delta_window_size=delta_window_size,
+        deltify=deltify,
+        reuse_pack=reuse_pack)

     return write_pack_data(
         write,

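With the deltify decision folded into pack_objects_to_data, write_pack_objects becomes a thin wrapper around it. A minimal sketch of writing a one-object pack into a buffer (deltification left off, matching the default noted in the PERFORMANCE comment above):

    from io import BytesIO
    from dulwich.objects import Blob
    from dulwich.pack import write_pack_objects

    blob = Blob.from_string(b"pack me\n")
    buf = BytesIO()
    # objects is a sized container of (object, path) pairs; reuse_pack may
    # optionally name a PackedObjectContainer whose packed data can be reused
    write_pack_objects(buf.write, [(blob, None)], deltify=False)
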
+ 2 - 1
dulwich/patch.py

@@ -34,6 +34,7 @@ from dulwich.objects import (
     Commit,
     S_ISGITLINK,
 )
+from dulwich.pack import ObjectContainer

 FIRST_FEW_BYTES = 8000

@@ -192,7 +193,7 @@ def patch_filename(p, root):
         return root + b"/" + p


-def write_object_diff(f, store, old_file, new_file, diff_binary=False):
+def write_object_diff(f, store: ObjectContainer, old_file, new_file, diff_binary=False):
     """Write the diff for an object.

     Args:

+ 5 - 2
dulwich/refs.py

@@ -37,6 +37,7 @@ from dulwich.objects import (
     Tag,
     ObjectID,
 )
+from dulwich.pack import ObjectContainer
 from dulwich.file import (
     GitFile,
     ensure_dir_exists,
@@ -1150,8 +1151,10 @@ def read_info_refs(f):
     return ret


-def write_info_refs(refs, store):
+def write_info_refs(refs, store: ObjectContainer):
     """Generate info refs."""
+    # Avoid recursive import :(
+    from dulwich.object_store import peel_sha
     for name, sha in sorted(refs.items()):
         # get_refs() includes HEAD as a special case, but we don't want to
         # advertise it
@@ -1161,7 +1164,7 @@ def write_info_refs(refs, store):
             o = store[sha]
         except KeyError:
             continue
-        peeled = store.peel_sha(sha)
+        peeled = peel_sha(store, sha)
         yield o.id + b"\t" + name + b"\n"
         if o.id != peeled.id:
             yield peeled.id + b"\t" + name + ANNOTATED_TAG_SUFFIX + b"\n"

+ 2 - 1
dulwich/repo.py

@@ -72,6 +72,7 @@ from dulwich.object_store import (
     MemoryObjectStore,
     BaseObjectStore,
     ObjectStoreGraphWalker,
+    peel_sha,
 )
 from dulwich.objects import (
     check_hexsha,
@@ -757,7 +758,7 @@ class BaseRepo:
         cached = self.refs.get_peeled(ref)
         if cached is not None:
             return cached
-        return self.object_store.peel_sha(self.refs[ref]).id
+        return peel_sha(self.object_store, self.refs[ref]).id

     def get_walker(self, include: Optional[List[bytes]] = None,
                    *args, **kwargs):

+ 22 - 16
dulwich/server.py

@@ -47,7 +47,7 @@ import os
 import socket
 import sys
 import time
-from typing import List, Tuple, Dict, Optional, Iterable
+from typing import List, Tuple, Dict, Optional, Iterable, Set
 import zlib

 import socketserver
@@ -67,8 +67,12 @@ from dulwich.objects import (
     Commit,
     valid_hexsha,
 )
+from dulwich.object_store import (
+    peel_sha,
+)
 from dulwich.pack import (
     write_pack_objects,
+    ObjectContainer,
 )
 from dulwich.protocol import (
     BufferedPktLineWriter,
@@ -456,7 +460,7 @@ def _split_proto_line(line, allowed):
     raise GitProtocolError("Received invalid line from client: %r" % line)


-def _find_shallow(store, heads, depth):
+def _find_shallow(store: ObjectContainer, heads, depth):
     """Find shallow commits according to a given depth.

     Args:
@@ -468,7 +472,7 @@ def _find_shallow(store, heads, depth):
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
     """
-    parents = {}
+    parents: Dict[bytes, List[bytes]] = {}

     def get_parents(sha):
         result = parents.get(sha, None)
@@ -479,7 +483,7 @@ def _find_shallow(store, heads, depth):

     todo = []  # stack of (sha, depth)
     for head_sha in heads:
-        obj = store.peel_sha(head_sha)
+        obj = peel_sha(store, head_sha)
         if isinstance(obj, Commit):
             todo.append((obj.id, 1))

@@ -497,7 +501,7 @@ def _find_shallow(store, heads, depth):
     return shallow, not_shallow


-def _want_satisfied(store, haves, want, earliest):
+def _want_satisfied(store: ObjectContainer, haves, want, earliest):
     o = store[want]
     pending = collections.deque([o])
     known = {want}
@@ -505,7 +509,7 @@ def _want_satisfied(store, haves, want, earliest):
         commit = pending.popleft()
         if commit.id in haves:
             return True
-        if commit.type_name != b"commit":
+        if not isinstance(commit, Commit):
             # non-commit wants are assumed to be satisfied
             continue
         for parent in commit.parents:
@@ -513,13 +517,14 @@ def _want_satisfied(store, haves, want, earliest):
                 continue
             known.add(parent)
             parent_obj = store[parent]
+            assert isinstance(parent_obj, Commit)
             # TODO: handle parents with later commit times than children
             if parent_obj.commit_time >= earliest:
                 pending.append(parent_obj)
     return False


-def _all_wants_satisfied(store, haves, wants):
+def _all_wants_satisfied(store: ObjectContainer, haves, wants):
     """Check whether all the current wants are satisfied by a set of haves.

     Args:
@@ -531,7 +536,8 @@ def _all_wants_satisfied(store, haves, wants):
     """
     """
     haves = set(haves)
     haves = set(haves)
     if haves:
     if haves:
-        earliest = min([store[h].commit_time for h in haves])
+        have_objs = [store[h] for h in haves]
+        earliest = min([h.commit_time for h in have_objs if isinstance(h, Commit)])
     else:
     else:
         earliest = 0
         earliest = 0
     for want in wants:
     for want in wants:
@@ -555,20 +561,20 @@ class _ProtocolGraphWalker:
     any calls to next() or ack() are made.
     """

-    def __init__(self, handler, object_store, get_peeled, get_symrefs):
+    def __init__(self, handler, object_store: ObjectContainer, get_peeled, get_symrefs):
         self.handler = handler
-        self.store = object_store
+        self.store: ObjectContainer = object_store
         self.get_peeled = get_peeled
         self.get_symrefs = get_symrefs
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
-        self._wants = []
-        self.shallow = set()
-        self.client_shallow = set()
-        self.unshallow = set()
+        self._wants: List[bytes] = []
+        self.shallow: Set[bytes] = set()
+        self.client_shallow: Set[bytes] = set()
+        self.unshallow: Set[bytes] = set()
         self._cached = False
-        self._cache = []
+        self._cache: List[bytes] = []
         self._cache_index = 0
         self._impl = None

@@ -1104,7 +1110,7 @@ class UploadArchiveHandler(Handler):
         prefix = b""
         format = "tar"
         i = 0
-        store = self.repo.object_store
+        store: ObjectContainer = self.repo.object_store
         while i < len(arguments):
             argument = arguments[i]
             if argument == b"--prefix":

+ 2 - 1
dulwich/submodule.py

@@ -22,6 +22,7 @@
 """
 """
 
 
 from typing import Iterator, Tuple
 from typing import Iterator, Tuple
+from .object_store import iter_tree_contents
 from .objects import S_ISGITLINK
 from .objects import S_ISGITLINK
 
 
 
 
@@ -35,6 +36,6 @@ def iter_cached_submodules(store, root_tree_id: bytes) -> Iterator[Tuple[str, by
     Returns:
       Iterator over over (path, sha) tuples
     """
-    for entry in store.iter_tree_contents(root_tree_id):
+    for entry in iter_tree_contents(store, root_tree_id):
         if S_ISGITLINK(entry.mode):
             yield entry.path, entry.sha

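iter_cached_submodules keeps its public interface and simply routes through the module-level iter_tree_contents. A minimal usage sketch (the repository path is hypothetical):

    from dulwich.repo import Repo
    from dulwich.submodule import iter_cached_submodules

    repo = Repo("path/to/repo")  # hypothetical checkout with submodules
    commit = repo[repo.head()]
    for path, sha in iter_cached_submodules(repo.object_store, commit.tree):
        print(path, sha)
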
+ 5 - 3
dulwich/tests/test_object_store.py

@@ -51,6 +51,8 @@ from dulwich.object_store import (
     OverlayObjectStore,
     ObjectStoreGraphWalker,
     commit_tree_changes,
+    iter_tree_contents,
+    peel_sha,
     read_packs_file,
     tree_lookup_path,
 )
@@ -219,7 +221,7 @@ class ObjectStoreTests:
         tree_id = commit_tree(self.store, blobs)
         self.assertEqual(
             [TreeEntry(p, m, h) for (p, h, m) in blobs],
-            list(self.store.iter_tree_contents(tree_id)),
+            list(iter_tree_contents(self.store, tree_id)),
         )

     def test_iter_tree_contents_include_trees(self):
@@ -247,7 +249,7 @@ class ObjectStoreTests:
             TreeEntry(b"ad/bd", 0o040000, tree_bd.id),
             TreeEntry(b"ad/bd/c", 0o100755, blob_c.id),
         ]
-        actual = self.store.iter_tree_contents(tree_id, include_trees=True)
+        actual = iter_tree_contents(self.store, tree_id, include_trees=True)
         self.assertEqual(expected, list(actual))

     def make_tag(self, name, obj):
@@ -261,7 +263,7 @@ class ObjectStoreTests:
         tag2 = self.make_tag(b"2", testobject)
         tag3 = self.make_tag(b"3", testobject)
         for obj in [testobject, tag1, tag2, tag3]:
-            self.assertEqual(testobject, self.store.peel_sha(obj.id))
+            self.assertEqual(testobject, peel_sha(self.store, obj.id))

     def test_get_raw(self):
         self.store.add_object(testobject)