
Add protocols for object containers.

Jelmer Vernooij, 2 years ago
commit 03cc1f5fd0

+ 4 - 1
dulwich/fastexport.py

@@ -30,6 +30,9 @@ from dulwich.objects import (
     Tag,
     ZERO_SHA,
 )
+from dulwich.object_store import (
+    iter_tree_contents,
+)
 from fastimport import (
     commands,
     errors as fastimport_errors,
@@ -232,7 +235,7 @@ class GitImportProcessor(processor.ImportProcessor):
                 path,
                 mode,
                 hexsha,
-            ) in self.repo.object_store.iter_tree_contents(tree_id):
+            ) in iter_tree_contents(self.repo.object_store, tree_id):
                 self._contents[path] = (mode, hexsha)

     def reset_handler(self, cmd):

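Note: this hunk shows the migration pattern the rest of the commit repeats: iter_tree_contents moves from a method on the object store to a module-level function that takes the container as its first argument. A minimal sketch, assuming an existing repo and tree_id:

    # old (still works, but now emits DeprecationWarning):
    #     for path, mode, sha in repo.object_store.iter_tree_contents(tree_id): ...
    # new:
    from dulwich.object_store import iter_tree_contents

    for path, mode, sha in iter_tree_contents(repo.object_store, tree_id):
        print(path, mode, sha)
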
+ 7 - 6
dulwich/greenthreads.py

@@ -31,12 +31,13 @@ from dulwich.objects import (
 )
 from dulwich.object_store import (
     MissingObjectFinder,
+    _collect_ancestors,
     _collect_filetree_revs,
     ObjectStoreIterator,
 )


-def _split_commits_and_tags(obj_store, lst, ignore_unknown=False, pool=None):
+def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
     """Split object id list into two list with commit SHA1s and tag SHA1s.
     """Split object id list into two list with commit SHA1s and tag SHA1s.
 
 
     Same implementation as object_store._split_commits_and_tags
     Same implementation as object_store._split_commits_and_tags
@@ -90,11 +91,11 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
         self.object_store = object_store
         p = pool.Pool(size=concurrency)

-        have_commits, have_tags = _split_commits_and_tags(object_store, haves, True, p)
-        want_commits, want_tags = _split_commits_and_tags(object_store, wants, False, p)
-        all_ancestors = object_store._collect_ancestors(have_commits)[0]
-        missing_commits, common_commits = object_store._collect_ancestors(
-            want_commits, all_ancestors
+        have_commits, have_tags = _split_commits_and_tags(object_store, haves, ignore_unknown=True, pool=p)
+        want_commits, want_tags = _split_commits_and_tags(object_store, wants, ignore_unknown=False, pool=p)
+        all_ancestors = _collect_ancestors(object_store, have_commits)[0]
+        missing_commits, common_commits = _collect_ancestors(
+            object_store, want_commits, all_ancestors
         )

         self.sha_done = set()

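The bare * added to the signature above makes ignore_unknown and pool keyword-only, and the call sites in this file are rewritten accordingly; the old positional calls would now raise TypeError. A minimal sketch of the pattern, with fetch as an invented stand-in:

    def fetch(store, shas, *, ignore_unknown=False, pool=None):
        # parameters after the bare * must be passed by keyword
        return [s for s in shas if s in store or ignore_unknown]

    fetch({b"a"}, [b"a", b"b"], ignore_unknown=True)  # OK
    # fetch({b"a"}, [b"a", b"b"], True)  -> TypeError: too many positional arguments
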
+ 11 - 11
dulwich/index.py

@@ -39,9 +39,6 @@ from typing import (
     Union,
 )

-if TYPE_CHECKING:
-    from dulwich.object_store import BaseObjectStore
-
 from dulwich.file import GitFile
 from dulwich.objects import (
     Blob,
@@ -52,9 +49,11 @@ from dulwich.objects import (
     sha_to_hex,
     ObjectID,
 )
+from dulwich.object_store import iter_tree_contents
 from dulwich.pack import (
     SHA1Reader,
     SHA1Writer,
+    ObjectContainer,
 )


@@ -451,7 +450,7 @@ class Index:


 def commit_tree(
-    object_store: "BaseObjectStore", blobs: Iterable[Tuple[bytes, bytes, int]]
+    object_store: ObjectContainer, blobs: Iterable[Tuple[bytes, bytes, int]]
 ) -> bytes:
     """Commit a new tree.

@@ -494,7 +493,7 @@ def commit_tree(
     return build_tree(b"")
     return build_tree(b"")
 
 
 
 
-def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
+def commit_index(object_store: ObjectContainer, index: Index) -> bytes:
     """Create a new tree from an index.
     """Create a new tree from an index.
 
 
     Args:
     Args:
@@ -509,7 +508,7 @@ def commit_index(object_store: "BaseObjectStore", index: Index) -> bytes:
 def changes_from_tree(
     names: Iterable[bytes],
     lookup_entry: Callable[[bytes], Tuple[bytes, int]],
-    object_store: "BaseObjectStore",
+    object_store: ObjectContainer,
     tree: Optional[bytes],
     want_unchanged=False,
 ) -> Iterable[
@@ -535,7 +534,7 @@ def changes_from_tree(
     other_names = set(names)

     if tree is not None:
-        for (name, mode, sha) in object_store.iter_tree_contents(tree):
+        for (name, mode, sha) in iter_tree_contents(object_store, tree):
             try:
                 (other_sha, other_mode) = lookup_entry(name)
             except KeyError:
@@ -686,7 +685,7 @@ def validate_path(path: bytes,
 def build_index_from_tree(
     root_path: Union[str, bytes],
     index_path: Union[str, bytes],
-    object_store: "BaseObjectStore",
+    object_store: ObjectContainer,
     tree_id: bytes,
     honor_filemode: bool = True,
     validate_path_element=validate_path_element_default,
@@ -711,7 +710,7 @@ def build_index_from_tree(
     if not isinstance(root_path, bytes):
         root_path = os.fsencode(root_path)

-    for entry in object_store.iter_tree_contents(tree_id):
+    for entry in iter_tree_contents(object_store, tree_id):
         if not validate_path(entry.path, validate_path_element):
             continue
         full_path = _tree_to_fs_path(root_path, entry.path)
@@ -727,6 +726,7 @@ def build_index_from_tree(
             # TODO(jelmer): record and return submodule paths
         else:
             obj = object_store[entry.sha]
+            assert isinstance(obj, Blob)
             st = build_file_from_blob(
                 obj, entry.mode, full_path,
                 honor_filemode=honor_filemode,
@@ -927,7 +927,7 @@ def index_entry_from_directory(st, path: bytes) -> Optional[IndexEntry]:


 def index_entry_from_path(
-        path: bytes, object_store: Optional["BaseObjectStore"] = None
+        path: bytes, object_store: Optional[ObjectContainer] = None
 ) -> Optional[IndexEntry]:
     """Create an index from a filesystem path.

@@ -957,7 +957,7 @@ def index_entry_from_path(

 def iter_fresh_entries(
     paths: Iterable[bytes], root_path: bytes,
-    object_store: Optional["BaseObjectStore"] = None
+    object_store: Optional[ObjectContainer] = None
 ) -> Iterator[Tuple[bytes, Optional[IndexEntry]]]:
     """Iterate over current versions of index entries on disk.


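The annotations above replace the concrete BaseObjectStore with the structural ObjectContainer protocol, so these functions accept any object that provides the container methods. A minimal sketch using dulwich's own MemoryObjectStore (paths and contents invented):

    from dulwich.index import commit_tree
    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    blob = Blob.from_string(b"print('hi')\n")
    store.add_object(blob)
    # commit_tree takes (path, hexsha, mode) triples and returns the root tree id
    tree_id = commit_tree(store, [(b"main.py", blob.id, 0o100644)])
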
+ 2 - 1
dulwich/line_ending.py

@@ -136,6 +136,7 @@ Sources:
 - https://adaptivepatchwork.com/2012/03/01/mind-the-end-of-your-line/
 """

+from dulwich.object_store import iter_tree_contents
 from dulwich.objects import Blob
 from dulwich.patch import is_binary

@@ -290,7 +291,7 @@ class TreeBlobNormalizer(BlobNormalizer):
         if tree:
             self.existing_paths = {
                 name
-                for name, _, _ in object_store.iter_tree_contents(tree)
+                for name, _, _ in iter_tree_contents(object_store, tree)
             }
         else:
             self.existing_paths = set()

+ 118 - 63
dulwich/object_store.py

@@ -26,13 +26,10 @@ from io import BytesIO
 import os
 import stat
 import sys
+import warnings
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Callable, Dict, List, Optional, Tuple, Protocol, Union, Iterator, Set
-from dulwich.diff_tree import (
-    tree_changes,
-    walk_trees,
-)
 from dulwich.errors import (
     NotTreeError,
 )
@@ -48,10 +45,12 @@ from dulwich.objects import (
     sha_to_hex,
     hex_to_filename,
     S_ISGITLINK,
+    TreeEntry,
     object_class,
     valid_hexsha,
 )
 from dulwich.pack import (
+    ObjectContainer,
     Pack,
     PackData,
     PackInflater,
@@ -79,6 +78,14 @@ PACKDIR = "pack"
 PACK_MODE = 0o444 if sys.platform != "win32" else 0o644


+class PackContainer(Protocol):
+
+    def add_pack(
+        self
+    ) -> Tuple[BytesIO, Callable[[], None], Callable[[], None]]:
+        """Add a new pack."""
+
+
 class BaseObjectStore:
     """Object store interface."""

@@ -213,6 +220,8 @@ class BaseObjectStore:
         Returns: Iterator over tuples with
             (oldpath, newpath), (oldmode, newmode), (oldsha, newsha)
         """
+
+        from dulwich.diff_tree import tree_changes
         for change in tree_changes(
             self,
             source,
@@ -239,11 +248,10 @@ class BaseObjectStore:
         Returns: Iterator over TreeEntry namedtuples for all the objects in a
             tree.
         """
-        for entry, _ in walk_trees(self, tree_id, None):
-            if (
-                entry.mode is not None and not stat.S_ISDIR(entry.mode)
-            ) or include_trees:
-                yield entry
+        warnings.warn(
+            "Please use dulwich.object_store.iter_tree_contents",
+            DeprecationWarning, stacklevel=2)
+        return iter_tree_contents(self, tree_id, include_trees=include_trees)

     def find_missing_objects(
         self,
@@ -334,47 +342,10 @@ class BaseObjectStore:
             intermediate tags; if the original ref does not point to a tag,
             this will equal the original SHA1.
         """
-        obj = self[sha]
-        obj_class = object_class(obj.type_name)
-        while obj_class is Tag:
-            obj_class, sha = obj.object
-            obj = self[sha]
-        return obj
-
-    def _collect_ancestors(
-        self,
-        heads,
-        common=frozenset(),
-        shallow=frozenset(),
-        get_parents=lambda commit: commit.parents,
-    ):
-        """Collect all ancestors of heads up to (excluding) those in common.
-
-        Args:
-          heads: commits to start from
-          common: commits to end at, or empty set to walk repository
-            completely
-          get_parents: Optional function for getting the parents of a
-            commit.
-        Returns: a tuple (A, B) where A - all commits reachable
-            from heads but not present in common, B - common (shared) elements
-            that are directly reachable from heads
-        """
-        bases = set()
-        commits = set()
-        queue = []
-        queue.extend(heads)
-        while queue:
-            e = queue.pop(0)
-            if e in common:
-                bases.add(e)
-            elif e not in commits:
-                commits.add(e)
-                if e in shallow:
-                    continue
-                cmt = self[e]
-                queue.extend(get_parents(cmt))
-        return (commits, bases)
+        warnings.warn(
+            "Please use dulwich.object_store.peel_sha()",
+            DeprecationWarning, stacklevel=2)
+        return peel_sha(self, sha)

     def _get_depth(
         self, head, get_parents=lambda commit: commit.parents, max_depth=None,
@@ -1083,10 +1054,10 @@ class MemoryObjectStore(BaseObjectStore):
             commit()


-class ObjectIterator:
+class ObjectIterator(Protocol):
     """Interface for iterating over objects."""
     """Interface for iterating over objects."""
 
 
-    def iterobjects(self):
+    def iterobjects(self) -> Iterator[ShaFile]:
         raise NotImplementedError(self.iterobjects)


@@ -1194,7 +1165,7 @@ def _collect_filetree_revs(obj_store, tree_sha, kset):
                 _collect_filetree_revs(obj_store, sha, kset)


-def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
+def _split_commits_and_tags(obj_store: ObjectContainer, lst, *, ignore_unknown=False) -> Tuple[Set[bytes], Set[bytes], Set[bytes]]:
     """Split object id list into three lists with commit, tag, and other SHAs.
     """Split object id list into three lists with commit, tag, and other SHAs.
 
 
     Commits referenced by tags are included into commits
     Commits referenced by tags are included into commits
@@ -1209,9 +1180,9 @@ def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
         silently.
     Returns: A tuple of (commits, tags, others) SHA1s
     """
-    commits = set()
-    tags = set()
-    others = set()
+    commits: Set[bytes] = set()
+    tags: Set[bytes] = set()
+    others: Set[bytes] = set()
     for e in lst:
         try:
             o = obj_store[e]
@@ -1224,12 +1195,12 @@ def _split_commits_and_tags(obj_store, lst, ignore_unknown=False):
             elif isinstance(o, Tag):
                 tags.add(e)
                 tagged = o.object[1]
-                c, t, o = _split_commits_and_tags(
+                c, t, os = _split_commits_and_tags(
                     obj_store, [tagged], ignore_unknown=ignore_unknown
                 )
                 commits |= c
                 tags |= t
-                others |= o
+                others |= os
             else:
                 others.add(e)
     return (commits, tags, others)
@@ -1270,20 +1241,22 @@ class MissingObjectFinder:
         # wants shall list only known SHAs, and otherwise
         # _split_commits_and_tags fails with KeyError
         have_commits, have_tags, have_others = _split_commits_and_tags(
-            object_store, haves, True
+            object_store, haves, ignore_unknown=True
         )
         want_commits, want_tags, want_others = _split_commits_and_tags(
-            object_store, wants, False
+            object_store, wants, ignore_unknown=False
         )
         # all_ancestors is a set of commits that shall not be sent
         # (complete repository up to 'haves')
-        all_ancestors = object_store._collect_ancestors(
+        all_ancestors = _collect_ancestors(
+            object_store,
             have_commits, shallow=shallow, get_parents=self._get_parents
         )[0]
         # all_missing - complete set of commits between haves and wants
         # common - commits from all_ancestors we hit into while
         # traversing parent hierarchy of wants
-        missing_commits, common_commits = object_store._collect_ancestors(
+        missing_commits, common_commits = _collect_ancestors(
+            object_store,
             want_commits,
             all_ancestors,
             shallow=shallow,
@@ -1606,3 +1579,85 @@ class BucketBasedObjectStore(PackBasedObjectStore):
             return final_pack

         return pf, commit, pf.close
+
+
+def _collect_ancestors(
+    store: ObjectContainer,
+    heads,
+    common=frozenset(),
+    shallow=frozenset(),
+    get_parents=lambda commit: commit.parents,
+):
+    """Collect all ancestors of heads up to (excluding) those in common.
+
+    Args:
+      heads: commits to start from
+      common: commits to end at, or empty set to walk repository
+        completely
+      get_parents: Optional function for getting the parents of a
+        commit.
+    Returns: a tuple (A, B) where A - all commits reachable
+        from heads but not present in common, B - common (shared) elements
+        that are directly reachable from heads
+    """
+    bases = set()
+    commits = set()
+    queue = []
+    queue.extend(heads)
+    while queue:
+        e = queue.pop(0)
+        if e in common:
+            bases.add(e)
+        elif e not in commits:
+            commits.add(e)
+            if e in shallow:
+                continue
+            cmt = store[e]
+            queue.extend(get_parents(cmt))
+    return (commits, bases)
+
+
+def iter_tree_contents(
+        store: ObjectContainer, tree_id: bytes, *, include_trees: bool = False):
+    """Iterate the contents of a tree and all subtrees.
+
+    Iteration is depth-first pre-order, as in e.g. os.walk.
+
+    Args:
+      tree_id: SHA1 of the tree.
+      include_trees: If True, include tree objects in the iteration.
+    Returns: Iterator over TreeEntry namedtuples for all the objects in a
+        tree.
+    """
+    # This could be fairly easily generalized to >2 trees if we find a use
+    # case.
+    todo = [TreeEntry(b"", stat.S_IFDIR, tree_id)]
+    while todo:
+        entry = todo.pop()
+        if stat.S_ISDIR(entry.mode):
+            extra = []
+            tree = store[entry.sha]
+            assert isinstance(tree, Tree)
+            for subentry in tree.iteritems(name_order=True):
+                extra.append(subentry.in_path(entry.path))
+            todo.extend(reversed(extra))
+        if not stat.S_ISDIR(entry.mode) or include_trees:
+            yield entry
+
+
+def peel_sha(store: ObjectContainer, sha: bytes) -> ShaFile:
+    """Peel all tags from a SHA.
+
+    Args:
+      sha: The object SHA to peel.
+    Returns: The fully-peeled SHA1 of a tag object, after peeling all
+        intermediate tags; if the original ref does not point to a tag,
+        this will equal the original SHA1.
+    """
+    obj = store[sha]
+    obj_class = object_class(obj.type_name)
+    while obj_class is Tag:
+        assert isinstance(obj, Tag)
+        obj_class, sha = obj.object
+        obj = store[sha]
+    return obj

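These module-level functions replace the old BaseObjectStore methods; the public ones survive on the class only as DeprecationWarning shims that delegate here. A minimal usage sketch for iter_tree_contents and peel_sha (objects and paths invented):

    from dulwich.index import commit_tree
    from dulwich.object_store import MemoryObjectStore, iter_tree_contents, peel_sha
    from dulwich.objects import Blob, Tag

    store = MemoryObjectStore()
    blob = Blob.from_string(b"hello\n")
    store.add_object(blob)
    tree_id = commit_tree(store, [(b"dir/hello.txt", blob.id, 0o100644)])

    # depth-first pre-order walk; include_trees=True also yields tree entries
    for entry in iter_tree_contents(store, tree_id, include_trees=True):
        print(entry.path, oct(entry.mode), entry.sha)

    # peel_sha follows tag objects until it reaches a non-tag
    tag = Tag()
    tag.name, tag.message = b"v1", b"tag the blob"
    tag.tagger, tag.tag_time, tag.tag_timezone = b"a <a@example.com>", 0, 0
    tag.object = (Blob, blob.id)
    store.add_object(tag)
    assert peel_sha(store, tag.id).id == blob.id
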
+ 43 - 4
dulwich/pack.py

@@ -49,7 +49,7 @@ from itertools import chain

 import os
 import sys
-from typing import Optional, Callable, Tuple, List, Deque, Union
+from typing import Optional, Callable, Tuple, List, Deque, Union, Protocol, Iterable
 import warnings

 from hashlib import sha1
@@ -96,6 +96,38 @@ DELTA_TYPES = (OFS_DELTA, REF_DELTA)
 DEFAULT_PACK_DELTA_WINDOW_SIZE = 10


+class ObjectContainer(Protocol):
+
+    def add_object(self, obj: ShaFile) -> None:
+        """Add a single object to this object store."""
+
+    def add_objects(
+            self, objects: Iterable[Tuple[ShaFile, Optional[str]]],
+            progress: Optional[Callable[[str], None]] = None) -> None:
+        """Add a set of objects to this object store.
+
+        Args:
+          objects: Iterable over a list of (object, path) tuples
+        """
+
+    def __contains__(self, sha1: bytes) -> bool:
+        """Check if a hex sha is present."""
+
+    def __getitem__(self, sha1: bytes) -> ShaFile:
+        """Retrieve an object."""
+
+
+class PackedObjectContainer(ObjectContainer):
+
+    def get_raw_unresolved(self, sha1: bytes) -> Tuple[int, Union[bytes, None], List[bytes]]:
+        """Get a raw unresolved object."""
+        raise NotImplementedError(self.get_raw_unresolved)
+
+    def get_raw_delta(self, sha1: bytes) -> Tuple[int, Union[bytes, None], List[bytes]]:
+        """Get a raw delta text."""
+        raise NotImplementedError(self.get_raw_delta)
+
+
 def take_msb_bytes(read: Callable[[int], bytes], crc32: Optional[int] = None) -> Tuple[List[int], Optional[int]]:
     """Read bytes marked with most significant bit.

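ObjectContainer above is a typing.Protocol, so typing is structural: any class providing these four methods type-checks as a container without inheriting from dulwich classes. A minimal sketch; DictContainer is an invented example, not part of dulwich:

    from typing import Callable, Dict, Iterable, Optional, Tuple
    from dulwich.objects import ShaFile

    class DictContainer:
        """Satisfies ObjectContainer structurally; no base class needed."""

        def __init__(self) -> None:
            self._objects: Dict[bytes, ShaFile] = {}

        def add_object(self, obj: ShaFile) -> None:
            self._objects[obj.id] = obj

        def add_objects(self, objects: Iterable[Tuple[ShaFile, Optional[str]]],
                        progress: Optional[Callable[[str], None]] = None) -> None:
            for obj, _path in objects:
                self.add_object(obj)

        def __contains__(self, sha1: bytes) -> bool:
            return sha1 in self._objects

        def __getitem__(self, sha1: bytes) -> ShaFile:
            return self._objects[sha1]

Such a container can then be passed to any function annotated with ObjectContainer, e.g. iter_tree_contents.
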
@@ -1643,7 +1675,10 @@ def write_pack_header(write, num_objects):
         write(chunk)


-def deltify_pack_objects(objects, window_size: Optional[int] = None, reuse_pack=None):
+def deltify_pack_objects(
+        objects: Iterable[Tuple[ShaFile, str]],
+        window_size: Optional[int] = None,
+        reuse_pack: Optional[PackedObjectContainer] = None):
     """Generate deltas for pack objects.
     """Generate deltas for pack objects.
 
 
     Args:
     Args:
@@ -1685,7 +1720,7 @@ def deltify_pack_objects(objects, window_size: Optional[int] = None, reuse_pack=
         magic.append((obj.type_num, path, -obj.raw_length(), obj))
     magic.sort()

-    possible_bases: Deque[Tuple[bytes, int, bytes]] = deque()
+    possible_bases: Deque[Tuple[bytes, int, List[bytes]]] = deque()

     for type_num, path, neg_length, o in magic:
         raw = o.as_raw_chunks()
@@ -1730,7 +1765,11 @@ def pack_objects_to_data(objects):


 def write_pack_objects(
-    write, objects, delta_window_size=None, deltify=None, reuse_pack=None, compression_level=-1
+        write, objects,
+        delta_window_size: Optional[int] = None,
+        deltify: Optional[bool] = None,
+        reuse_pack: Optional[PackedObjectContainer] = None,
+        compression_level: int = -1
 ):
     """Write a new pack data file.


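A minimal sketch of the retyped write_pack_objects in use, writing a single-object pack to an in-memory buffer; behavior details beyond the signature above (return value, exact stream layout) are assumptions:

    import io
    from dulwich.objects import Blob
    from dulwich.pack import write_pack_objects

    buf = io.BytesIO()
    blob = Blob.from_string(b"content")
    # write is any callable accepting bytes; objects is an iterable of (object, path)
    write_pack_objects(buf.write, [(blob, None)], deltify=False)
    assert buf.getvalue().startswith(b"PACK")  # pack streams start with this magic
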
+ 2 - 1
dulwich/patch.py

@@ -34,6 +34,7 @@ from dulwich.objects import (
     Commit,
     S_ISGITLINK,
 )
+from dulwich.pack import ObjectContainer

 FIRST_FEW_BYTES = 8000

@@ -192,7 +193,7 @@ def patch_filename(p, root):
         return root + b"/" + p
         return root + b"/" + p
 
 
 
 
-def write_object_diff(f, store, old_file, new_file, diff_binary=False):
+def write_object_diff(f, store: ObjectContainer, old_file, new_file, diff_binary=False):
     """Write the diff for an object.
     """Write the diff for an object.
 
 
     Args:
     Args:

+ 5 - 2
dulwich/refs.py

@@ -37,6 +37,7 @@ from dulwich.objects import (
     Tag,
     ObjectID,
 )
+from dulwich.pack import ObjectContainer
 from dulwich.file import (
     GitFile,
     ensure_dir_exists,
@@ -1150,8 +1151,10 @@ def read_info_refs(f):
     return ret


-def write_info_refs(refs, store):
+def write_info_refs(refs, store: ObjectContainer):
     """Generate info refs."""
     """Generate info refs."""
+    # Avoid recursive import :(
+    from dulwich.object_store import peel_sha
     for name, sha in sorted(refs.items()):
         # get_refs() includes HEAD as a special case, but we don't want to
         # advertise it
@@ -1161,7 +1164,7 @@ def write_info_refs(refs, store):
             o = store[sha]
         except KeyError:
             continue
-        peeled = store.peel_sha(sha)
+        peeled = peel_sha(store, sha)
         yield o.id + b"\t" + name + b"\n"
         if o.id != peeled.id:
             yield peeled.id + b"\t" + name + ANNOTATED_TAG_SUFFIX + b"\n"

+ 2 - 1
dulwich/repo.py

@@ -72,6 +72,7 @@ from dulwich.object_store import (
     MemoryObjectStore,
     BaseObjectStore,
     ObjectStoreGraphWalker,
+    peel_sha,
 )
 from dulwich.objects import (
     check_hexsha,
@@ -757,7 +758,7 @@ class BaseRepo:
         cached = self.refs.get_peeled(ref)
         if cached is not None:
             return cached
-        return self.object_store.peel_sha(self.refs[ref]).id
+        return peel_sha(self.object_store, self.refs[ref]).id

     def get_walker(self, include: Optional[List[bytes]] = None,
                    *args, **kwargs):

+ 22 - 16
dulwich/server.py

@@ -47,7 +47,7 @@ import os
 import socket
 import sys
 import time
-from typing import List, Tuple, Dict, Optional, Iterable
+from typing import List, Tuple, Dict, Optional, Iterable, Set
 import zlib

 import socketserver
@@ -67,8 +67,12 @@ from dulwich.objects import (
     Commit,
     valid_hexsha,
 )
+from dulwich.object_store import (
+    peel_sha,
+)
 from dulwich.pack import (
     write_pack_objects,
+    ObjectContainer,
 )
 from dulwich.protocol import (
     BufferedPktLineWriter,
@@ -456,7 +460,7 @@ def _split_proto_line(line, allowed):
     raise GitProtocolError("Received invalid line from client: %r" % line)


-def _find_shallow(store, heads, depth):
+def _find_shallow(store: ObjectContainer, heads, depth):
     """Find shallow commits according to a given depth.
     """Find shallow commits according to a given depth.
 
 
     Args:
     Args:
@@ -468,7 +472,7 @@ def _find_shallow(store, heads, depth):
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
     """
-    parents = {}
+    parents: Dict[bytes, List[bytes]] = {}

     def get_parents(sha):
         result = parents.get(sha, None)
@@ -479,7 +483,7 @@ def _find_shallow(store, heads, depth):

     todo = []  # stack of (sha, depth)
     for head_sha in heads:
-        obj = store.peel_sha(head_sha)
+        obj = peel_sha(store, head_sha)
         if isinstance(obj, Commit):
             todo.append((obj.id, 1))

@@ -497,7 +501,7 @@ def _find_shallow(store, heads, depth):
     return shallow, not_shallow


-def _want_satisfied(store, haves, want, earliest):
+def _want_satisfied(store: ObjectContainer, haves, want, earliest):
     o = store[want]
     pending = collections.deque([o])
     known = {want}
@@ -505,7 +509,7 @@ def _want_satisfied(store, haves, want, earliest):
         commit = pending.popleft()
         if commit.id in haves:
             return True
-        if commit.type_name != b"commit":
+        if not isinstance(commit, Commit):
             # non-commit wants are assumed to be satisfied
             continue
         for parent in commit.parents:
@@ -513,13 +517,14 @@ def _want_satisfied(store, haves, want, earliest):
                 continue
             known.add(parent)
             parent_obj = store[parent]
+            assert isinstance(parent_obj, Commit)
             # TODO: handle parents with later commit times than children
             if parent_obj.commit_time >= earliest:
                 pending.append(parent_obj)
     return False


-def _all_wants_satisfied(store, haves, wants):
+def _all_wants_satisfied(store: ObjectContainer, haves, wants):
     """Check whether all the current wants are satisfied by a set of haves.
     """Check whether all the current wants are satisfied by a set of haves.
 
 
     Args:
     Args:
@@ -531,7 +536,8 @@ def _all_wants_satisfied(store, haves, wants):
     """
     """
     haves = set(haves)
     haves = set(haves)
     if haves:
     if haves:
-        earliest = min([store[h].commit_time for h in haves])
+        have_objs = [store[h] for h in haves]
+        earliest = min([h.commit_time for h in have_objs if isinstance(h, Commit)])
     else:
         earliest = 0
     for want in wants:
@@ -555,20 +561,20 @@ class _ProtocolGraphWalker:
     any calls to next() or ack() are made.
     """

-    def __init__(self, handler, object_store, get_peeled, get_symrefs):
+    def __init__(self, handler, object_store: ObjectContainer, get_peeled, get_symrefs):
         self.handler = handler
-        self.store = object_store
+        self.store: ObjectContainer = object_store
         self.get_peeled = get_peeled
         self.get_symrefs = get_symrefs
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
-        self._wants = []
-        self.shallow = set()
-        self.client_shallow = set()
-        self.unshallow = set()
+        self._wants: List[bytes] = []
+        self.shallow: Set[bytes] = set()
+        self.client_shallow: Set[bytes] = set()
+        self.unshallow: Set[bytes] = set()
         self._cached = False
-        self._cache = []
+        self._cache: List[bytes] = []
         self._cache_index = 0
         self._impl = None

@@ -1104,7 +1110,7 @@ class UploadArchiveHandler(Handler):
         prefix = b""
         prefix = b""
         format = "tar"
         format = "tar"
         i = 0
         i = 0
-        store = self.repo.object_store
+        store: ObjectContainer = self.repo.object_store
         while i < len(arguments):
             argument = arguments[i]
             if argument == b"--prefix":

+ 2 - 1
dulwich/submodule.py

@@ -22,6 +22,7 @@
 """
 """
 
 
 from typing import Iterator, Tuple
 from typing import Iterator, Tuple
+from .object_store import iter_tree_contents
 from .objects import S_ISGITLINK


@@ -35,6 +36,6 @@ def iter_cached_submodules(store, root_tree_id: bytes) -> Iterator[Tuple[str, by
     Returns:
       Iterator over (path, sha) tuples
     """
-    for entry in store.iter_tree_contents(root_tree_id):
+    for entry in iter_tree_contents(store, root_tree_id):
         if S_ISGITLINK(entry.mode):
             yield entry.path, entry.sha

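A minimal sketch of iter_cached_submodules over a tree containing a gitlink entry (mode 0o160000); the submodule commit id is invented and need not be present in the store:

    from dulwich.index import commit_tree
    from dulwich.object_store import MemoryObjectStore
    from dulwich.submodule import iter_cached_submodules

    store = MemoryObjectStore()
    submodule_commit = b"a" * 40  # hypothetical commit id recorded for the submodule
    tree_id = commit_tree(store, [(b"vendor/lib", submodule_commit, 0o160000)])
    for path, sha in iter_cached_submodules(store, tree_id):
        print(path, sha)
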
+ 5 - 3
dulwich/tests/test_object_store.py

@@ -51,6 +51,8 @@ from dulwich.object_store import (
     OverlayObjectStore,
     ObjectStoreGraphWalker,
     commit_tree_changes,
+    iter_tree_contents,
+    peel_sha,
     read_packs_file,
     tree_lookup_path,
 )
@@ -219,7 +221,7 @@ class ObjectStoreTests:
         tree_id = commit_tree(self.store, blobs)
         self.assertEqual(
             [TreeEntry(p, m, h) for (p, h, m) in blobs],
-            list(self.store.iter_tree_contents(tree_id)),
+            list(iter_tree_contents(self.store, tree_id)),
         )

     def test_iter_tree_contents_include_trees(self):
@@ -247,7 +249,7 @@ class ObjectStoreTests:
             TreeEntry(b"ad/bd", 0o040000, tree_bd.id),
             TreeEntry(b"ad/bd", 0o040000, tree_bd.id),
             TreeEntry(b"ad/bd/c", 0o100755, blob_c.id),
             TreeEntry(b"ad/bd/c", 0o100755, blob_c.id),
         ]
         ]
-        actual = self.store.iter_tree_contents(tree_id, include_trees=True)
+        actual = iter_tree_contents(self.store, tree_id, include_trees=True)
         self.assertEqual(expected, list(actual))

     def make_tag(self, name, obj):
@@ -261,7 +263,7 @@ class ObjectStoreTests:
         tag2 = self.make_tag(b"2", testobject)
         tag2 = self.make_tag(b"2", testobject)
         tag3 = self.make_tag(b"3", testobject)
         tag3 = self.make_tag(b"3", testobject)
         for obj in [testobject, tag1, tag2, tag3]:
         for obj in [testobject, tag1, tag2, tag3]:
-            self.assertEqual(testobject, self.store.peel_sha(obj.id))
+            self.assertEqual(testobject, peel_sha(self.store, obj.id))

     def test_get_raw(self):
         self.store.add_object(testobject)