Jelmer Vernooij 5 months ago
parent
commit
27ebd52b00

+ 14 - 7
dulwich/bundle.py

@@ -21,10 +21,15 @@
 
 """Bundle format support."""
 
-from typing import BinaryIO, Callable, Optional
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional
 
 from .pack import PackData, write_pack_data
 
+if TYPE_CHECKING:
+    from .object_store import BaseObjectStore
+    from .repo import BaseRepo
+
 
 class Bundle:
     version: Optional[int]
@@ -58,8 +63,10 @@ class Bundle:
         return True
 
     def store_objects(
-        self, object_store, progress: Optional[Callable[[str], None]] = None
-    ):
+        self,
+        object_store: "BaseObjectStore",
+        progress: Optional[Callable[[str], None]] = None,
+    ) -> None:
         """Store all objects from this bundle into an object store.
 
         Args:
@@ -178,7 +185,7 @@ def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
 
 
 def create_bundle_from_repo(
-    repo,
+    repo: "BaseRepo",
     refs: Optional[list[bytes]] = None,
     prerequisites: Optional[list[bytes]] = None,
     version: Optional[int] = None,
@@ -266,14 +273,14 @@ def create_bundle_from_repo(
     # Store the pack objects directly, we'll write them when saving the bundle
     # For now, create a simple wrapper to hold the data
     class _BundlePackData:
-        def __init__(self, count, objects):
+        def __init__(self, count: int, objects: Iterator[Any]) -> None:
             self._count = count
             self._objects = list(objects)  # Materialize the iterator
 
-        def __len__(self):
+        def __len__(self) -> int:
             return self._count
 
-        def iter_unpacked(self):
+        def iter_unpacked(self) -> Iterator[Any]:
             return iter(self._objects)
 
     pack_data = _BundlePackData(pack_count, pack_objects)

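A minimal usage sketch for the newly annotated store_objects signature, assuming dulwich.bundle.read_bundle is available and using a placeholder bundle path; the callback simply matches the Optional[Callable[[str], None]] progress type:

    from dulwich.bundle import read_bundle
    from dulwich.repo import Repo

    def report(msg: str) -> None:
        print(msg)  # conforms to Callable[[str], None]

    repo = Repo(".")
    with open("changes.bundle", "rb") as f:  # placeholder path
        bundle = read_bundle(f)
    # Unpack the bundle's objects into the repository's object store.
    bundle.store_objects(repo.object_store, progress=report)
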
+ 23 - 14
dulwich/config.py

@@ -48,6 +48,7 @@ from typing import (
     Optional,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 
@@ -176,7 +177,11 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
     @classmethod
     def make(
-        cls, dict_in=None, default_factory=None
+        cls,
+        dict_in: Optional[
+            Union[MutableMapping[K, V], "CaseInsensitiveOrderedMultiDict[K, V]"]
+        ] = None,
+        default_factory: Optional[Callable[[], V]] = None,
     ) -> "CaseInsensitiveOrderedMultiDict[K, V]":
         if isinstance(dict_in, cls):
             return dict_in
@@ -226,22 +231,22 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
     def values(self) -> ValuesView[V]:
         return self._keyed.values()
 
-    def __setitem__(self, key, value) -> None:
+    def __setitem__(self, key: K, value: V) -> None:
         self._real.append((key, value))
         self._keyed[lower_key(key)] = value
 
-    def set(self, key, value) -> None:
+    def set(self, key: K, value: V) -> None:
         # This method replaces all existing values for the key
         lower = lower_key(key)
         self._real = [(k, v) for k, v in self._real if lower_key(k) != lower]
         self._real.append((key, value))
         self._keyed[lower] = value
 
-    def __delitem__(self, key) -> None:
-        key = lower_key(key)
-        del self._keyed[key]
+    def __delitem__(self, key: K) -> None:
+        lower_k = lower_key(key)
+        del self._keyed[lower_k]
         for i, (actual, unused_value) in reversed(list(enumerate(self._real))):
-            if lower_key(actual) == key:
+            if lower_key(actual) == lower_k:
                 del self._real[i]
 
     def __getitem__(self, item: K) -> V:
@@ -394,7 +399,7 @@ class ConfigDict(Config):
     def __init__(
         self,
         values: Union[
-            MutableMapping[Section, MutableMapping[Name, Value]], None
+            MutableMapping[Section, CaseInsensitiveOrderedMultiDict[Name, Value]], None
         ] = None,
         encoding: Union[str, None] = None,
     ) -> None:
@@ -417,7 +422,9 @@ class ConfigDict(Config):
     def __getitem__(self, key: Section) -> CaseInsensitiveOrderedMultiDict[Name, Value]:
         return self._values.__getitem__(key)
 
-    def __setitem__(self, key: Section, value: MutableMapping[Name, Value]) -> None:
+    def __setitem__(
+        self, key: Section, value: CaseInsensitiveOrderedMultiDict[Name, Value]
+    ) -> None:
         return self._values.__setitem__(key, value)
 
     def __delitem__(self, key: Section) -> None:
@@ -739,7 +746,7 @@ class ConfigFile(ConfigDict):
     def __init__(
         self,
         values: Union[
-            MutableMapping[Section, MutableMapping[Name, Value]], None
+            MutableMapping[Section, CaseInsensitiveOrderedMultiDict[Name, Value]], None
         ] = None,
         encoding: Union[str, None] = None,
     ) -> None:
@@ -924,10 +931,11 @@ class ConfigFile(ConfigDict):
         # Load and merge the included file
         try:
             # Use provided file opener or default to GitFile
+            opener: FileOpener
             if file_opener is None:
 
-                def opener(path):
-                    return GitFile(path, "rb")
+                def opener(path: Union[str, os.PathLike]) -> BinaryIO:
+                    return cast(BinaryIO, GitFile(path, "rb"))
             else:
                 opener = file_opener
 
@@ -1084,10 +1092,11 @@ class ConfigFile(ConfigDict):
         config_dir = os.path.dirname(abs_path)
 
         # Use provided file opener or default to GitFile
+        opener: FileOpener
         if file_opener is None:
 
-            def opener(p):
-                return GitFile(p, "rb")
+            def opener(p: Union[str, os.PathLike]) -> BinaryIO:
+                return cast(BinaryIO, GitFile(p, "rb"))
         else:
             opener = file_opener
 

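A hedged sketch of the now-typed CaseInsensitiveOrderedMultiDict behaviour touched by this change: __setitem__ appends an additional value for a key, while set() replaces every existing value; the bytes keys and values are illustrative only:

    from dulwich.config import CaseInsensitiveOrderedMultiDict

    d = CaseInsensitiveOrderedMultiDict.make()
    d[b"core"] = b"one"       # appends a value for the key
    d[b"CORE"] = b"two"       # same key, case-insensitively; both pairs stay in _real
    d.set(b"core", b"three")  # drops the earlier values and stores a single one
    del d[b"Core"]            # deletion is case-insensitive as well
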
+ 12 - 0
dulwich/file.py

@@ -266,3 +266,15 @@ class _GitFile:
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)
         raise AttributeError(name)
+
+    def readable(self) -> bool:
+        """Return whether the file is readable."""
+        return self._file.readable()
+
+    def writable(self) -> bool:
+        """Return whether the file is writable."""
+        return self._file.writable()
+
+    def seekable(self) -> bool:
+        """Return whether the file is seekable."""
+        return self._file.seekable()

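These three methods delegate to the wrapped file object, so the _GitFile lock-file proxy returned by GitFile for write modes passes io-style capability checks. A minimal sketch with a placeholder filename:

    from dulwich.file import GitFile

    f = GitFile("example.txt", "wb")  # write modes go through _GitFile
    try:
        assert f.writable() and f.seekable()
        f.write(b"data")
    finally:
        f.close()  # finalizes the locked write
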
+ 9 - 10
dulwich/gc.py

@@ -4,19 +4,18 @@ import collections
 import os
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, Callable, Optional
 
 from dulwich.object_store import (
     BaseObjectStore,
     DiskObjectStore,
-    PackBasedObjectStore,
 )
 from dulwich.objects import Commit, ObjectID, Tag, Tree
 from dulwich.refs import RefsContainer
 
 if TYPE_CHECKING:
     from .config import Config
-    from .repo import BaseRepo
+    from .repo import BaseRepo, Repo
 
 
 DEFAULT_GC_AUTO = 6700
@@ -39,7 +38,7 @@ def find_reachable_objects(
     object_store: BaseObjectStore,
     refs_container: RefsContainer,
     include_reflogs: bool = True,
-    progress=None,
+    progress: Optional[Callable[[str], None]] = None,
 ) -> set[bytes]:
     """Find all reachable objects in the repository.
 
@@ -112,7 +111,7 @@ def find_unreachable_objects(
     object_store: BaseObjectStore,
     refs_container: RefsContainer,
     include_reflogs: bool = True,
-    progress=None,
+    progress: Optional[Callable[[str], None]] = None,
 ) -> set[bytes]:
     """Find all unreachable objects in the repository.
 
@@ -138,11 +137,11 @@ def find_unreachable_objects(
 
 
 def prune_unreachable_objects(
-    object_store: PackBasedObjectStore,
+    object_store: DiskObjectStore,
     refs_container: RefsContainer,
     grace_period: Optional[int] = None,
     dry_run: bool = False,
-    progress=None,
+    progress: Optional[Callable[[str], None]] = None,
 ) -> tuple[set[bytes], int]:
     """Remove unreachable objects from the repository.
 
@@ -206,13 +205,13 @@ def prune_unreachable_objects(
 
 
 def garbage_collect(
-    repo,
+    repo: "Repo",
     auto: bool = False,
     aggressive: bool = False,
     prune: bool = True,
     grace_period: Optional[int] = 1209600,  # 2 weeks default
     dry_run: bool = False,
-    progress=None,
+    progress: Optional[Callable[[str], None]] = None,
 ) -> GCStats:
     """Run garbage collection on a repository.
 
@@ -368,7 +367,7 @@ def should_run_gc(repo: "BaseRepo", config: Optional["Config"] = None) -> bool:
     return False
 
 
-def maybe_auto_gc(repo: "BaseRepo", config: Optional["Config"] = None) -> bool:
+def maybe_auto_gc(repo: "Repo", config: Optional["Config"] = None) -> bool:
     """Run automatic garbage collection if needed.
 
     Args:

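A small sketch of calling garbage_collect with the annotated progress callback; the dry-run flag and repository path are illustrative:

    from dulwich.gc import garbage_collect
    from dulwich.repo import Repo

    def progress(msg: str) -> None:
        print(msg)  # matches Optional[Callable[[str], None]]

    repo = Repo(".")
    # Report what would be pruned without modifying the object store.
    stats = garbage_collect(repo, prune=True, dry_run=True, progress=progress)
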
+ 22 - 14
dulwich/greenthreads.py

@@ -23,12 +23,13 @@
 
 """Utility module for querying an ObjectStore with gevent."""
 
-from typing import Optional
+from typing import Callable, Optional
 
 import gevent
 from gevent import pool
 
 from .object_store import (
+    BaseObjectStore,
     MissingObjectFinder,
     _collect_ancestors,
     _collect_filetree_revs,
@@ -36,7 +37,13 @@ from .object_store import (
 from .objects import Commit, ObjectID, Tag
 
 
-def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
+def _split_commits_and_tags(
+    obj_store: BaseObjectStore,
+    lst: list[ObjectID],
+    *,
+    ignore_unknown: bool = False,
+    pool: pool.Pool,
+) -> tuple[set[ObjectID], set[ObjectID]]:
     """Split object id list into two list with commit SHA1s and tag SHA1s.
 
     Same implementation as object_store._split_commits_and_tags
@@ -45,7 +52,7 @@ def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
     commits = set()
     tags = set()
 
-    def find_commit_type(sha) -> None:
+    def find_commit_type(sha: ObjectID) -> None:
         try:
             o = obj_store[sha]
         except KeyError:
@@ -58,7 +65,7 @@ def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
                 tags.add(sha)
                 commits.add(o.object[1])
             else:
-                raise KeyError(f"Not a commit or a tag: {sha}")
+                raise KeyError(f"Not a commit or a tag: {sha!r}")
 
     jobs = [pool.spawn(find_commit_type, s) for s in lst]
     gevent.joinall(jobs)
@@ -74,18 +81,19 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
 
     def __init__(
         self,
-        object_store,
-        haves,
-        wants,
-        progress=None,
-        get_tagged=None,
-        concurrency=1,
-        get_parents=None,
+        object_store: BaseObjectStore,
+        haves: list[ObjectID],
+        wants: list[ObjectID],
+        progress: Optional[Callable[[str], None]] = None,
+        get_tagged: Optional[Callable[[], dict[ObjectID, ObjectID]]] = None,
+        concurrency: int = 1,
+        get_parents: Optional[Callable[[ObjectID], list[ObjectID]]] = None,
     ) -> None:
-        def collect_tree_sha(sha) -> None:
+        def collect_tree_sha(sha: ObjectID) -> None:
             self.sha_done.add(sha)
-            cmt = object_store[sha]
-            _collect_filetree_revs(object_store, cmt.tree, self.sha_done)
+            obj = object_store[sha]
+            if isinstance(obj, Commit):
+                _collect_filetree_revs(object_store, obj.tree, self.sha_done)
 
         self.object_store = object_store
         p = pool.Pool(size=concurrency)

+ 7 - 2
dulwich/lfs_server.py

@@ -231,7 +231,7 @@ class LFSRequestHandler(BaseHTTPRequestHandler):
         except KeyError:
             return False
 
-    def log_message(self, format, *args):
+    def log_message(self, format: str, *args: object) -> None:
         """Override to suppress request logging during tests."""
         if self.server.log_requests:
             super().log_message(format, *args)
@@ -240,7 +240,12 @@ class LFSRequestHandler(BaseHTTPRequestHandler):
 class LFSServer(HTTPServer):
     """Simple LFS server for testing."""
 
-    def __init__(self, server_address, lfs_store: LFSStore, log_requests: bool = False):
+    def __init__(
+        self,
+        server_address: tuple[str, int],
+        lfs_store: LFSStore,
+        log_requests: bool = False,
+    ) -> None:
         super().__init__(server_address, LFSRequestHandler)
         self.lfs_store = lfs_store
         self.log_requests = log_requests

+ 4 - 4
dulwich/merge_drivers.py

@@ -23,7 +23,7 @@
 import os
 import subprocess
 import tempfile
-from typing import Any, Optional, Protocol
+from typing import Any, Callable, Optional, Protocol
 
 from .config import Config
 
@@ -149,12 +149,12 @@ class MergeDriverRegistry:
         # Register built-in drivers
         self._register_builtin_drivers()
 
-    def _register_builtin_drivers(self):
+    def _register_builtin_drivers(self) -> None:
         """Register built-in merge drivers."""
         # The "text" driver is the default three-way merge
         # We don't register it here as it's handled by the default merge code
 
-    def register_driver(self, name: str, driver: MergeDriver):
+    def register_driver(self, name: str, driver: MergeDriver) -> None:
         """Register a merge driver instance.
 
         Args:
@@ -163,7 +163,7 @@ class MergeDriverRegistry:
         """
         self._drivers[name] = driver
 
-    def register_factory(self, name: str, factory):
+    def register_factory(self, name: str, factory: Callable[[], MergeDriver]) -> None:
         """Register a factory function for creating merge drivers.
 
         Args:

+ 17 - 5
dulwich/notes.py

@@ -28,12 +28,14 @@ from .objects import Blob, Tree
 
 if TYPE_CHECKING:
     from .config import StackedConfig
+    from .object_store import BaseObjectStore
+    from .refs import RefsContainer
 
 NOTES_REF_PREFIX = b"refs/notes/"
 DEFAULT_NOTES_REF = NOTES_REF_PREFIX + b"commits"
 
 
-def get_note_fanout_level(tree: Tree, object_store) -> int:
+def get_note_fanout_level(tree: Tree, object_store: "BaseObjectStore") -> int:
     """Determine the fanout level for a note tree.
 
     Git uses a fanout directory structure for performance with large numbers
@@ -57,6 +59,7 @@ def get_note_fanout_level(tree: Tree, object_store) -> int:
             elif stat.S_ISDIR(mode) and level < 2:  # Only recurse 2 levels deep
                 try:
                     subtree = object_store[sha]
+                    assert isinstance(subtree, Tree)
                     count += count_notes(subtree, level + 1)
                 except KeyError:
                     pass
@@ -111,7 +114,7 @@ def get_note_path(object_sha: bytes, fanout_level: int = 0) -> bytes:
 class NotesTree:
     """Represents a Git notes tree."""
 
-    def __init__(self, tree: Tree, object_store):
+    def __init__(self, tree: Tree, object_store: "BaseObjectStore") -> None:
         """Initialize a notes tree.
 
         Args:
@@ -167,6 +170,7 @@ class NotesTree:
                 try:
                     sample_mode, sample_sha = self._tree[sample_dir_name]
                     sample_tree = self._object_store[sample_sha]
+                    assert isinstance(sample_tree, Tree)
 
                     # Check if this subtree also has 2-char hex directories
                     sub_has_dirs = False
@@ -236,6 +240,7 @@ class NotesTree:
                             # Update this subtree
                             if stat.S_ISDIR(mode):
                                 subtree = self._object_store[sha]
+                                assert isinstance(subtree, Tree)
                             else:
                                 # If not a directory, we need to replace it
                                 subtree = Tree()
@@ -303,7 +308,9 @@ class NotesTree:
                 mode, sha = current_tree[component]
                 if not stat.S_ISDIR(mode):  # Not a directory
                     return None
-                current_tree = self._object_store[sha]
+                obj = self._object_store[sha]
+                assert isinstance(obj, Tree)
+                current_tree = obj
             except KeyError:
                 return None
 
@@ -378,6 +385,7 @@ class NotesTree:
                         # Update this subtree
                         if stat.S_ISDIR(mode):
                             subtree = self._object_store[sha]
+                            assert isinstance(subtree, Tree)
                         else:
                             # If not a directory, we need to replace it
                             subtree = Tree()
@@ -444,6 +452,7 @@ class NotesTree:
                     if name == components[0] and stat.S_ISDIR(mode):
                         # Update this subtree
                         subtree = self._object_store[sha]
+                        assert isinstance(subtree, Tree)
                         new_subtree = remove_from_tree(subtree, components[1:])
                         if new_subtree is not None:
                             self._object_store.add_object(new_subtree)
@@ -478,6 +487,7 @@ class NotesTree:
             for name, mode, sha in tree.items():
                 if stat.S_ISDIR(mode):  # Directory
                     subtree = self._object_store[sha]
+                    assert isinstance(subtree, Tree)
                     yield from walk_tree(subtree, prefix + name)
                 elif stat.S_ISREG(mode):  # File
                     # Reconstruct the full hex SHA from the path
@@ -487,7 +497,7 @@ class NotesTree:
         yield from walk_tree(self._tree)
 
 
-def create_notes_tree(object_store) -> Tree:
+def create_notes_tree(object_store: "BaseObjectStore") -> Tree:
     """Create an empty notes tree.
 
     Args:
@@ -504,7 +514,9 @@ def create_notes_tree(object_store) -> Tree:
 class Notes:
     """High-level interface for Git notes operations."""
 
-    def __init__(self, object_store, refs_container):
+    def __init__(
+        self, object_store: "BaseObjectStore", refs_container: "RefsContainer"
+    ) -> None:
         """Initialize Notes.
 
         Args:

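A brief sketch of the annotated constructors, using an in-memory object store as a stand-in for a real repository:

    from dulwich.notes import NotesTree, create_notes_tree
    from dulwich.object_store import MemoryObjectStore

    store = MemoryObjectStore()
    tree = create_notes_tree(store)  # empty notes Tree
    notes = NotesTree(tree, store)   # matches the typed __init__ above
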
+ 14 - 6
dulwich/objectspec.py

@@ -26,6 +26,7 @@ from typing import TYPE_CHECKING, Optional, Union
 from .objects import Commit, ShaFile, Tag, Tree
 
 if TYPE_CHECKING:
+    from .object_store import BaseObjectStore
     from .refs import Ref, RefsContainer
     from .repo import Repo
 
@@ -267,7 +268,7 @@ def parse_reftuples(
     rh_container: Union["Repo", "RefsContainer"],
     refspecs: Union[bytes, list[bytes]],
     force: bool = False,
-):
+) -> list[tuple[Optional["Ref"], Optional["Ref"], bool]]:
     """Parse a list of reftuple specs to a list of reftuples.
 
     Args:
@@ -288,7 +289,10 @@ def parse_reftuples(
     return ret
 
 
-def parse_refs(container, refspecs):
+def parse_refs(
+    container: Union["Repo", "RefsContainer"],
+    refspecs: Union[bytes, str, list[Union[bytes, str]]],
+) -> list["Ref"]:
     """Parse a list of refspecs to a list of refs.
 
     Args:
@@ -343,12 +347,14 @@ def parse_commit_range(
 class AmbiguousShortId(Exception):
     """The short id is ambiguous."""
 
-    def __init__(self, prefix, options) -> None:
+    def __init__(self, prefix: bytes, options: list[ShaFile]) -> None:
         self.prefix = prefix
         self.options = options
 
 
-def scan_for_short_id(object_store, prefix, tp):
+def scan_for_short_id(
+    object_store: "BaseObjectStore", prefix: bytes, tp: type[ShaFile]
+) -> ShaFile:
     """Scan an object store for a short id."""
     ret = []
     for object_id in object_store.iter_prefix(prefix):
@@ -374,7 +380,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "C
       ValueError: If the range can not be parsed
     """
 
-    def dereference_tag(obj):
+    def dereference_tag(obj: ShaFile) -> "Commit":
         """Follow tag references until we reach a non-tag object."""
         while isinstance(obj, Tag):
             obj_type, obj_sha = obj.object
@@ -384,7 +390,9 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "C
                 # Tag points to a missing object
                 raise KeyError(obj_sha)
         if not isinstance(obj, Commit):
-            raise ValueError(f"Expected commit, got {obj.type_name}")
+            raise ValueError(
+                f"Expected commit, got {obj.type_name.decode('ascii', 'replace')}"
+            )
         return obj
 
     # If already a Commit object, return it directly

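A usage sketch for the annotated parse_commit path; the repository path is a placeholder, and tag dereferencing happens internally as shown in the diff:

    from dulwich.objectspec import parse_commit
    from dulwich.repo import Repo

    repo = Repo(".")
    # Accepts str/bytes committishes or Commit/Tag objects; annotated tags
    # are followed until a Commit is reached.
    commit = parse_commit(repo, b"HEAD")
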
+ 1 - 0
dulwich/repo.py

@@ -1115,6 +1115,7 @@ class Repo(BaseRepo):
 
     path: str
     bare: bool
+    object_store: DiskObjectStore
 
     def __init__(
         self,