
Fix remaining mypy type errors and ruff linting issues

Jelmer Vernooij, 1 month ago
Commit aee765a866

+ 6 - 5
dulwich/annotate.py

@@ -27,8 +27,9 @@ Python's difflib.
 """
 
 import difflib
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional
 
+from dulwich.objects import Blob
 from dulwich.walk import (
     ORDER_DATE,
     Walker,
@@ -37,7 +38,7 @@ from dulwich.walk import (
 if TYPE_CHECKING:
     from dulwich.diff_tree import TreeChange, TreeEntry
     from dulwich.object_store import BaseObjectStore
-    from dulwich.objects import Blob, Commit
+    from dulwich.objects import Commit
 
 # Walk over ancestry graph breadth-first
 # When checking each revision, find lines that according to difflib.Differ()
@@ -108,7 +109,7 @@ def annotate_lines(
 
     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
-        lines_annotated = update_lines(
-            lines_annotated, (commit, entry), cast("Blob", store[entry.sha])
-        )
+        blob_obj = store[entry.sha]
+        assert isinstance(blob_obj, Blob)
+        lines_annotated = update_lines(lines_annotated, (commit, entry), blob_obj)
     return lines_annotated
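
The annotate.py hunk replaces a cast("Blob", ...) with an assert isinstance(...) check. A minimal sketch, not part of the commit, of why both satisfy mypy but only the assert verifies anything at runtime (Blob and lookup() here are hypothetical stand-ins, not dulwich's):

from typing import cast


class Blob:
    def splitlines(self) -> list[str]:
        return []


def lookup(sha: str) -> object:
    # Simulates store[entry.sha]: the static return type is deliberately broad.
    return Blob()


obj = lookup("abc123")

# cast() only changes the static type; nothing is checked at runtime.
blob_cast = cast(Blob, obj)

# assert isinstance() narrows the type for mypy *and* fails fast if the
# store ever returns something that is not a Blob.
assert isinstance(obj, Blob)
print(obj.splitlines())  # mypy now treats obj as Blob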

+ 4 - 2
dulwich/bisect.py

@@ -49,7 +49,7 @@ class BisectState:
         self,
         bad: Optional[bytes] = None,
         good: Optional[list[bytes]] = None,
-        paths: Optional[list[str]] = None,
+        paths: Optional[list[bytes]] = None,
         no_checkout: bool = False,
         term_bad: str = "bad",
         term_good: str = "good",
@@ -103,7 +103,9 @@ class BisectState:
         names_file = os.path.join(self.repo.controldir(), "BISECT_NAMES")
         with open(names_file, "w") as f:
             if paths:
-                f.write("\n".join(paths) + "\n")
+                f.write(
+                    "\n".join(path.decode("utf-8", "replace") for path in paths) + "\n"
+                )
             else:
                 f.write("\n")
 

+ 27 - 7
dulwich/bundle.py

@@ -22,9 +22,25 @@
 """Bundle format support."""
 
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional
+from typing import (
+    TYPE_CHECKING,
+    BinaryIO,
+    Callable,
+    Optional,
+    Protocol,
+    runtime_checkable,
+)
+
+from .pack import PackData, UnpackedObject, write_pack_data
+
+
+@runtime_checkable
+class PackDataLike(Protocol):
+    """Protocol for objects that behave like PackData."""
+
+    def __len__(self) -> int: ...
+    def iter_unpacked(self) -> Iterator[UnpackedObject]: ...
 
-from .pack import PackData, write_pack_data
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
@@ -39,7 +55,7 @@ class Bundle:
     capabilities: dict[str, Optional[str]]
     prerequisites: list[tuple[bytes, bytes]]
     references: dict[bytes, bytes]
-    pack_data: PackData
+    pack_data: Optional[PackDataLike]
 
     def __repr__(self) -> str:
         """Return string representation of Bundle."""
@@ -79,10 +95,12 @@ class Bundle:
         """
         from .objects import ShaFile
 
+        if self.pack_data is None:
+            raise ValueError("pack_data is not loaded")
         count = 0
         for unpacked in self.pack_data.iter_unpacked():
             # Convert the unpacked object to a proper git object
-            if unpacked.decomp_chunks:
+            if unpacked.decomp_chunks and unpacked.obj_type_num is not None:
                 git_obj = ShaFile.from_raw_chunks(
                     unpacked.obj_type_num, unpacked.decomp_chunks
                 )
@@ -187,6 +205,8 @@ def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
     for ref, obj_id in bundle.references.items():
         f.write(obj_id + b" " + ref + b"\n")
     f.write(b"\n")
+    if bundle.pack_data is None:
+        raise ValueError("bundle.pack_data is not loaded")
     write_pack_data(
         f.write,
         num_records=len(bundle.pack_data),
@@ -283,14 +303,14 @@ def create_bundle_from_repo(
     # Store the pack objects directly, we'll write them when saving the bundle
     # For now, create a simple wrapper to hold the data
     class _BundlePackData:
-        def __init__(self, count: int, objects: Iterator[Any]) -> None:
+        def __init__(self, count: int, objects: Iterator[UnpackedObject]) -> None:
             self._count = count
             self._objects = list(objects)  # Materialize the iterator
 
         def __len__(self) -> int:
             return self._count
 
-        def iter_unpacked(self) -> Iterator[Any]:
+        def iter_unpacked(self) -> Iterator[UnpackedObject]:
             return iter(self._objects)
 
     pack_data = _BundlePackData(pack_count, pack_objects)
@@ -301,6 +321,6 @@ def create_bundle_from_repo(
     bundle.capabilities = capabilities
     bundle.prerequisites = bundle_prerequisites
     bundle.references = bundle_refs
-    bundle.pack_data = pack_data  # type: ignore[assignment]
+    bundle.pack_data = pack_data
 
     return bundle
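
The new PackDataLike protocol is what lets Bundle.pack_data hold either a real PackData or the local _BundlePackData wrapper: Protocol matching is structural, not inheritance-based. A self-contained sketch of the same idea with simplified types (UnpackedObject replaced by plain int for brevity):

from collections.abc import Iterator
from typing import Protocol, runtime_checkable


@runtime_checkable
class PackDataLike(Protocol):
    """Anything that can report a length and iterate its unpacked objects."""

    def __len__(self) -> int: ...
    def iter_unpacked(self) -> Iterator[int]: ...


class InMemoryPack:
    """Satisfies PackDataLike without inheriting from it."""

    def __init__(self, objects: list[int]) -> None:
        self._objects = objects

    def __len__(self) -> int:
        return len(self._objects)

    def iter_unpacked(self) -> Iterator[int]:
        return iter(self._objects)


pack: PackDataLike = InMemoryPack([1, 2, 3])      # accepted structurally by mypy
print(len(pack), isinstance(pack, PackDataLike))  # runtime check works via @runtime_checkable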

+ 81 - 32
dulwich/cli.py

@@ -37,7 +37,7 @@ import subprocess
 import sys
 import tempfile
 from pathlib import Path
-from typing import Callable, ClassVar, Optional, Union
+from typing import BinaryIO, Callable, ClassVar, Optional, Union
 
 from dulwich import porcelain
 
@@ -45,12 +45,26 @@ from .bundle import create_bundle_from_repo, read_bundle, write_bundle
 from .client import GitProtocolError, get_transport_and_path
 from .errors import ApplyDeltaError
 from .index import Index
-from .objects import valid_hexsha
+from .objects import Commit, valid_hexsha
 from .objectspec import parse_commit_range
 from .pack import Pack, sha_to_hex
 from .repo import Repo
 
 
+def to_display_str(value: Union[bytes, str]) -> str:
+    """Convert a bytes or string value to a display string.
+
+    Args:
+        value: The value to convert (bytes or str)
+
+    Returns:
+        A string suitable for display
+    """
+    if isinstance(value, bytes):
+        return value.decode("utf-8", "replace")
+    return value
+
+
 class CommitMessageError(Exception):
     """Raised when there's an issue with the commit message."""
 
@@ -126,7 +140,7 @@ def parse_relative_time(time_str: str) -> int:
         raise
 
 
-def format_bytes(bytes: int) -> str:
+def format_bytes(bytes: float) -> str:
     """Format bytes as human-readable string.
 
     Args:
@@ -231,7 +245,7 @@ class Pager:
         Args:
             pager_cmd: Command to use for paging (default: "cat")
         """
-        self.pager_process = None
+        self.pager_process: Optional[subprocess.Popen] = None
         self.buffer = PagerBuffer(self)
         self._closed = False
         self.pager_cmd = pager_cmd
@@ -449,12 +463,12 @@ def get_pager(config=None, cmd_name: Optional[str] = None):
 
 def disable_pager() -> None:
     """Disable pager for this session."""
-    get_pager._disabled = True
+    get_pager._disabled = True  # type: ignore[attr-defined]
 
 
 def enable_pager() -> None:
     """Enable pager for this session."""
-    get_pager._disabled = False
+    get_pager._disabled = False  # type: ignore[attr-defined]
 
 
 class Command:
@@ -491,10 +505,14 @@ class cmd_archive(Command):
                 write_error=sys.stderr.write,
             )
         else:
-            # Use buffer if available (for binary output), otherwise use stdout
-            outstream = getattr(sys.stdout, "buffer", sys.stdout)
+            # Use binary buffer for archive output
+            outstream: BinaryIO = sys.stdout.buffer
+            errstream: BinaryIO = sys.stderr.buffer
             porcelain.archive(
-                ".", args.committish, outstream=outstream, errstream=sys.stderr
+                ".",
+                args.committish,
+                outstream=outstream,
+                errstream=errstream,
             )
 
 
@@ -642,10 +660,11 @@ class cmd_fetch(Command):
         def progress(msg: bytes) -> None:
             sys.stdout.buffer.write(msg)
 
-        refs = client.fetch(path, r, progress=progress)
+        result = client.fetch(path, r, progress=progress)
         print("Remote refs:")
-        for item in refs.items():
-            print("{} -> {}".format(*item))
+        for ref, sha in result.refs.items():
+            if sha is not None:
+                print(f"{ref.decode()} -> {sha.decode()}")
 
 
 class cmd_for_each_ref(Command):
@@ -676,7 +695,7 @@ class cmd_fsck(Command):
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
         for obj, msg in porcelain.fsck("."):
-            print(f"{obj}: {msg}")
+            print(f"{obj.decode() if isinstance(obj, bytes) else obj}: {msg}")
 
 
 class cmd_log(Command):
@@ -1302,9 +1321,17 @@ class cmd_reflog(Command):
 
                     for i, entry in enumerate(porcelain.reflog(repo, ref)):
                         # Format similar to git reflog
+                        from dulwich.reflog import Entry
+
+                        assert isinstance(entry, Entry)
                         short_new = entry.new_sha[:8].decode("ascii")
+                        message = (
+                            entry.message.decode("utf-8", "replace")
+                            if entry.message
+                            else ""
+                        )
                         outstream.write(
-                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {entry.message.decode('utf-8', 'replace')}\n"
+                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {message}\n"
                         )
 
 
@@ -1538,11 +1565,14 @@ class cmd_ls_remote(Command):
         if args.symref:
             # Show symrefs first, like git does
             for ref, target in sorted(result.symrefs.items()):
-                sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
+                if target:
+                    sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
 
         # Show regular refs
         for ref in sorted(result.refs):
-            sys.stdout.write(f"{result.refs[ref].decode()}\t{ref.decode()}\n")
+            sha = result.refs[ref]
+            if sha is not None:
+                sys.stdout.write(f"{sha.decode()}\t{ref.decode()}\n")
 
 
 class cmd_ls_tree(Command):
@@ -1601,12 +1631,13 @@ class cmd_pack_objects(Command):
         if not args.stdout and not args.basename:
             parser.error("basename required when not using --stdout")
 
-        object_ids = [line.strip() for line in sys.stdin.readlines()]
+        object_ids = [line.strip().encode() for line in sys.stdin.readlines()]
         deltify = args.deltify
         reuse_deltas = not args.no_reuse_deltas
 
         if args.stdout:
             packf = getattr(sys.stdout, "buffer", sys.stdout)
+            assert isinstance(packf, BinaryIO)
             idxf = None
             close = []
         else:
@@ -2022,8 +2053,17 @@ class cmd_stash_list(Command):
         """
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
-        for i, entry in porcelain.stash_list("."):
-            print("stash@{{{}}}: {}".format(i, entry.message.rstrip("\n")))
+        from .repo import Repo
+        from .stash import Stash
+
+        with Repo(".") as r:
+            stash = Stash.from_repo(r)
+            for i, entry in enumerate(stash.stashes()):
+                print(
+                    "stash@{{{}}}: {}".format(
+                        i, entry.message.decode("utf-8", "replace").rstrip("\n")
+                    )
+                )
 
 
 class cmd_stash_push(Command):
@@ -2145,6 +2185,7 @@ class cmd_bisect(SuperCommand):
                         with open(bad_ref, "rb") as f:
                             bad_sha = f.read().strip()
                         commit = r.object_store[bad_sha]
+                        assert isinstance(commit, Commit)
                         message = commit.message.decode(
                             "utf-8", errors="replace"
                         ).split("\n")[0]
@@ -2173,7 +2214,7 @@ class cmd_bisect(SuperCommand):
                 print(log, end="")
 
             elif parsed_args.subcommand == "replay":
-                porcelain.bisect_replay(log_file=parsed_args.logfile)
+                porcelain.bisect_replay(".", log_file=parsed_args.logfile)
                 print(f"Replayed bisect log from {parsed_args.logfile}")
 
             elif parsed_args.subcommand == "help":
@@ -2270,6 +2311,7 @@ class cmd_merge(Command):
             elif args.no_commit:
                 print("Automatic merge successful; not committing as requested.")
             else:
+                assert merge_commit_id is not None
                 print(
                     f"Merge successful. Created merge commit {merge_commit_id.decode()}"
                 )
@@ -3129,12 +3171,14 @@ class cmd_lfs(Command):
             tracked = porcelain.lfs_untrack(patterns=args.patterns)
             print("Remaining tracked patterns:")
             for pattern in tracked:
-                print(f"  {pattern}")
+                print(f"  {to_display_str(pattern)}")
 
         elif args.subcommand == "ls-files":
             files = porcelain.lfs_ls_files(ref=args.ref)
             for path, oid, size in files:
-                print(f"{oid[:12]} * {path} ({format_bytes(size)})")
+                print(
+                    f"{to_display_str(oid[:12])} * {to_display_str(path)} ({format_bytes(size)})"
+                )
 
         elif args.subcommand == "migrate":
             count = porcelain.lfs_migrate(
@@ -3145,13 +3189,13 @@ class cmd_lfs(Command):
         elif args.subcommand == "pointer":
             if args.paths is not None:
                 results = porcelain.lfs_pointer_check(paths=args.paths or None)
-                for path, pointer in results.items():
+                for file_path, pointer in results.items():
                     if pointer:
                         print(
-                            f"{path}: LFS pointer (oid: {pointer.oid[:12]}, size: {format_bytes(pointer.size)})"
+                            f"{to_display_str(file_path)}: LFS pointer (oid: {to_display_str(pointer.oid[:12])}, size: {format_bytes(pointer.size)})"
                         )
                     else:
-                        print(f"{path}: Not an LFS pointer")
+                        print(f"{to_display_str(file_path)}: Not an LFS pointer")
 
         elif args.subcommand == "clean":
             pointer = porcelain.lfs_clean(path=args.path)
@@ -3188,13 +3232,13 @@ class cmd_lfs(Command):
 
             if status["missing"]:
                 print("\nMissing LFS objects:")
-                for path in status["missing"]:
-                    print(f"  {path}")
+                for file_path in status["missing"]:
+                    print(f"  {to_display_str(file_path)}")
 
             if status["not_staged"]:
                 print("\nModified LFS files not staged:")
-                for path in status["not_staged"]:
-                    print(f"  {path}")
+                for file_path in status["not_staged"]:
+                    print(f"  {to_display_str(file_path)}")
 
             if not any(status.values()):
                 print("No LFS files found.")
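
The LFS subcommands above route every printed value through the new to_display_str() helper so bytes never reach an f-string raw. A minimal, self-contained sketch of the behaviour it guards against (the pattern strings are made up):

def to_display_str(value):
    """Return a printable str for either bytes or str input."""
    if isinstance(value, bytes):
        return value.decode("utf-8", "replace")
    return value


print(f"  {b'*.bin'}")                  # "  b'*.bin'"  -- what raw bytes interpolation prints
print(f"  {to_display_str(b'*.bin')}")  # "  *.bin"
print(f"  {to_display_str('*.bin')}")   # "  *.bin"     -- str passes through unchanged
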
@@ -3273,14 +3317,19 @@ class cmd_format_patch(Command):
         args = parser.parse_args(args)
 
         # Parse committish using the new function
-        committish = None
+        committish: Optional[Union[bytes, tuple[bytes, bytes]]] = None
         if args.committish:
             with Repo(".") as r:
                 range_result = parse_commit_range(r, args.committish)
                 if range_result:
-                    committish = range_result
+                    # Convert Commit objects to their SHAs
+                    committish = (range_result[0].id, range_result[1].id)
                 else:
-                    committish = args.committish
+                    committish = (
+                        args.committish.encode()
+                        if isinstance(args.committish, str)
+                        else args.committish
+                    )
 
         filenames = porcelain.format_patch(
             ".",

+ 128 - 78
dulwich/client.py

@@ -53,7 +53,6 @@ from io import BufferedReader, BytesIO
 from typing import (
     IO,
     TYPE_CHECKING,
-    Any,
     Callable,
     ClassVar,
     Optional,
@@ -70,11 +69,11 @@ import dulwich
 
 from .config import Config, apply_instead_of, get_xdg_config_home_path
 from .errors import GitProtocolError, NotGitRepository, SendPackError
+from .object_store import GraphWalker
 from .pack import (
     PACK_SPOOL_FILE_MAX_SIZE,
     PackChunkGenerator,
     PackData,
-    UnpackedObject,
     write_pack_from_container,
 )
 from .protocol import (
@@ -117,7 +116,6 @@ from .protocol import (
     capability_agent,
     extract_capabilities,
     extract_capability_names,
-    filter_ref_prefix,
     parse_capability,
     pkt_line,
     pkt_seq,
@@ -130,6 +128,7 @@ from .refs import (
     _set_default_branch,
     _set_head,
     _set_origin_head,
+    filter_ref_prefix,
     read_info_refs,
     split_peeled_refs,
 )
@@ -150,7 +149,7 @@ logger = logging.getLogger(__name__)
 class InvalidWants(Exception):
     """Invalid wants."""
 
-    def __init__(self, wants: Any) -> None:
+    def __init__(self, wants: set[bytes]) -> None:
         """Initialize InvalidWants exception.
 
         Args:
@@ -164,7 +163,7 @@ class InvalidWants(Exception):
 class HTTPUnauthorized(Exception):
     """Raised when authentication fails."""
 
-    def __init__(self, www_authenticate: Any, url: str) -> None:
+    def __init__(self, www_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPUnauthorized exception.
 
         Args:
@@ -179,7 +178,7 @@ class HTTPUnauthorized(Exception):
 class HTTPProxyUnauthorized(Exception):
     """Raised when proxy authentication fails."""
 
-    def __init__(self, proxy_authenticate: Any, url: str) -> None:
+    def __init__(self, proxy_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPProxyUnauthorized exception.
 
         Args:
@@ -196,17 +195,23 @@ def _fileno_can_read(fileno: int) -> bool:
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
 
-def _win32_peek_avail(handle: Any) -> int:
+def _win32_peek_avail(handle: int) -> int:
     """Wrapper around PeekNamedPipe to check how many bytes are available."""
-    from ctypes import byref, windll, wintypes
+    from ctypes import (  # type: ignore[attr-defined]
+        byref,
+        windll,  # type: ignore[attr-defined]
+        wintypes,
+    )
 
     c_avail = wintypes.DWORD()
     c_message = wintypes.DWORD()
-    success = windll.kernel32.PeekNamedPipe(
+    success = windll.kernel32.PeekNamedPipe(  # type: ignore[attr-defined]
         handle, None, 0, None, byref(c_avail), byref(c_message)
     )
     if not success:
-        raise OSError(wintypes.GetLastError())
+        from ctypes import GetLastError  # type: ignore[attr-defined]
+
+        raise OSError(GetLastError())
     return c_avail.value
 
 
@@ -231,10 +236,10 @@ class ReportStatusParser:
     def __init__(self) -> None:
         """Initialize ReportStatusParser."""
         self._done = False
-        self._pack_status = None
+        self._pack_status: Optional[bytes] = None
         self._ref_statuses: list[bytes] = []
 
-    def check(self) -> Any:
+    def check(self) -> Iterator[tuple[bytes, Optional[str]]]:
         """Check if there were any errors and, if so, raise exceptions.
 
         Raises:
@@ -277,7 +282,7 @@ class ReportStatusParser:
             self._ref_statuses.append(ref_status)
 
 
-def negotiate_protocol_version(proto: Any) -> int:
+def negotiate_protocol_version(proto: Protocol) -> int:
     pkt = proto.read_pkt_line()
     if pkt is not None and pkt.strip() == b"version 2":
         return 2
@@ -285,7 +290,7 @@ def negotiate_protocol_version(proto: Any) -> int:
     return 0
 
 
-def read_server_capabilities(pkt_seq: Any) -> set:
+def read_server_capabilities(pkt_seq: Iterable[bytes]) -> set[bytes]:
     server_capabilities = []
     for pkt in pkt_seq:
         server_capabilities.append(pkt)
@@ -293,21 +298,15 @@ def read_server_capabilities(pkt_seq: Any) -> set:
 
 
 def read_pkt_refs_v2(
-    pkt_seq: Any,
-) -> tuple[dict[bytes, bytes], dict[bytes, bytes], dict[bytes, bytes]]:
-    """Read packet references in protocol v2 format.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (refs dict, symrefs dict, peeled dict)
-    """
-    refs = {}
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], dict[bytes, bytes], dict[bytes, bytes]]:
+    refs: dict[bytes, Optional[bytes]] = {}
     symrefs = {}
     peeled = {}
     # Receive refs from server
     for pkt in pkt_seq:
         parts = pkt.rstrip(b"\n").split(b" ")
-        sha = parts[0]
+        sha: Optional[bytes] = parts[0]
         if sha == b"unborn":
             sha = None
         ref = parts[1]
@@ -323,9 +322,11 @@ def read_pkt_refs_v2(
     return refs, symrefs, peeled
 
 
-def read_pkt_refs_v1(pkt_seq: Any) -> tuple[dict[bytes, bytes], set[bytes]]:
+def read_pkt_refs_v1(
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], set[bytes]]:
     server_capabilities = None
-    refs = {}
+    refs: dict[bytes, Optional[bytes]] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         (sha, ref) = pkt.rstrip(b"\n").split(None, 1)
@@ -346,6 +347,8 @@ def read_pkt_refs_v1(pkt_seq: Any) -> tuple[dict[bytes, bytes], set[bytes]]:
 class _DeprecatedDictProxy:
     """Base class for result objects that provide deprecated dict-like interface."""
 
+    refs: dict[bytes, Optional[bytes]]  # To be overridden by subclasses
+
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
         "clear",
         "copy",
@@ -376,7 +379,7 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return name in self.refs
 
-    def __getitem__(self, name: bytes) -> bytes:
+    def __getitem__(self, name: bytes) -> Optional[bytes]:
         self._warn_deprecated()
         return self.refs[name]
 
@@ -384,11 +387,11 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return len(self.refs)
 
-    def __iter__(self) -> Any:
+    def __iter__(self) -> Iterator[bytes]:
         self._warn_deprecated()
         return iter(self.refs)
 
-    def __getattribute__(self, name: str) -> Any:
+    def __getattribute__(self, name: str) -> object:
         # Avoid infinite recursion by checking against class variable directly
         if name != "_FORWARDED_ATTRS" and name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
@@ -407,8 +410,16 @@ class FetchPackResult(_DeprecatedDictProxy):
       agent: User agent string
     """
 
+    symrefs: dict[bytes, bytes]
+    agent: Optional[bytes]
+
     def __init__(
-        self, refs: dict, symrefs: dict, agent: Optional[bytes], new_shallow: Optional[Any] = None, new_unshallow: Optional[Any] = None
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        symrefs: dict[bytes, bytes],
+        agent: Optional[bytes],
+        new_shallow: Optional[set[bytes]] = None,
+        new_unshallow: Optional[set[bytes]] = None,
     ) -> None:
         """Initialize FetchPackResult.
 
@@ -425,10 +436,12 @@ class FetchPackResult(_DeprecatedDictProxy):
         self.new_shallow = new_shallow
         self.new_unshallow = new_unshallow
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, FetchPackResult):
+            return False
         return (
             self.refs == other.refs
             and self.symrefs == other.symrefs
@@ -448,7 +461,11 @@ class LsRemoteResult(_DeprecatedDictProxy):
       symrefs: Dictionary with remote symrefs
     """
 
-    def __init__(self, refs: dict, symrefs: dict) -> None:
+    symrefs: dict[bytes, bytes]
+
+    def __init__(
+        self, refs: dict[bytes, Optional[bytes]], symrefs: dict[bytes, bytes]
+    ) -> None:
         """Initialize LsRemoteResult.
 
         Args:
@@ -468,10 +485,12 @@ class LsRemoteResult(_DeprecatedDictProxy):
             stacklevel=3,
         )
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, LsRemoteResult):
+            return False
         return self.refs == other.refs and self.symrefs == other.symrefs
 
     def __repr__(self) -> str:
@@ -489,7 +508,12 @@ class SendPackResult(_DeprecatedDictProxy):
         failed to update), or None if it was updated successfully
     """
 
-    def __init__(self, refs: dict, agent: Optional[bytes] = None, ref_status: Optional[dict] = None) -> None:
+    def __init__(
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        agent: Optional[bytes] = None,
+        ref_status: Optional[dict[bytes, Optional[str]]] = None,
+    ) -> None:
         """Initialize SendPackResult.
 
         Args:
@@ -501,10 +525,12 @@ class SendPackResult(_DeprecatedDictProxy):
         self.agent = agent
         self.ref_status = ref_status
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, SendPackResult):
+            return False
         return self.refs == other.refs and self.agent == other.agent
 
     def __repr__(self) -> str:
@@ -512,7 +538,7 @@ class SendPackResult(_DeprecatedDictProxy):
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"
 
 
-def _read_shallow_updates(pkt_seq: Any) -> tuple[set, set]:
+def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[bytes]]:
     new_shallow = set()
     new_unshallow = set()
     for pkt in pkt_seq:
@@ -521,27 +547,29 @@ def _read_shallow_updates(pkt_seq: Any) -> tuple[set, set]:
         try:
             cmd, sha = pkt.split(b" ", 1)
         except ValueError:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
         if cmd == COMMAND_SHALLOW:
             new_shallow.add(sha.strip())
         elif cmd == COMMAND_UNSHALLOW:
             new_unshallow.add(sha.strip())
         else:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
     return (new_shallow, new_unshallow)
 
 
 class _v1ReceivePackHeader:
     def __init__(self, capabilities: list, old_refs: dict, new_refs: dict) -> None:
-        self.want: list[bytes] = []
-        self.have: list[bytes] = []
+        self.want: set[bytes] = set()
+        self.have: set[bytes] = set()
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False
 
-    def __iter__(self) -> Any:
+    def __iter__(self) -> Iterator[Optional[bytes]]:
         return self._it
 
-    def _handle_receive_pack_head(self, capabilities: list, old_refs: dict, new_refs: dict) -> Any:
+    def _handle_receive_pack_head(
+        self, capabilities: list, old_refs: dict, new_refs: dict
+    ) -> Iterator[Optional[bytes]]:
         """Handle the head of a 'git-receive-pack' request.
 
         Args:
@@ -552,7 +580,7 @@ class _v1ReceivePackHeader:
         Returns:
           (have, want) tuple
         """
-        self.have = [x for x in old_refs.values() if not x == ZERO_SHA]
+        self.have = {x for x in old_refs.values() if not x == ZERO_SHA}
 
         for refname in new_refs:
             if not isinstance(refname, bytes):
@@ -586,7 +614,7 @@ class _v1ReceivePackHeader:
                     )
                     self.sent_capabilities = True
             if new_sha1 not in self.have and new_sha1 != ZERO_SHA:
-                self.want.append(new_sha1)
+                self.want.add(new_sha1)
         yield None
 
 
@@ -603,25 +631,28 @@ def _read_side_band64k_data(pkt_seq: Iterable[bytes]) -> Iterator[tuple[int, byt
         yield channel, pkt[1:]
 
 
-def find_capability(capabilities: list, key: bytes, value: Optional[bytes]) -> Optional[bytes]:
+def find_capability(
+    capabilities: list, key: bytes, value: Optional[bytes]
+) -> Optional[bytes]:
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k != key:
             continue
-        if value and value not in v.split(b" "):
+        if value and v and value not in v.split(b" "):
             continue
         return capability
+    return None
 
 
 def _handle_upload_pack_head(
-    proto: Any,
+    proto: Protocol,
     capabilities: list,
-    graph_walker: Any,
+    graph_walker: GraphWalker,
     wants: list,
-    can_read: Callable,
+    can_read: Optional[Callable],
     depth: Optional[int],
     protocol_version: Optional[int],
-) -> None:
+) -> tuple[Optional[set[bytes]], Optional[set[bytes]]]:
     """Handle the head of a 'git-upload-pack' request.
 
     Args:
@@ -634,6 +665,8 @@ def _handle_upload_pack_head(
       depth: Depth for request
       protocol_version: Neogiated Git protocol version.
     """
+    new_shallow: Optional[set[bytes]]
+    new_unshallow: Optional[set[bytes]]
     assert isinstance(wants, list) and isinstance(wants[0], bytes)
     wantcmd = COMMAND_WANT + b" " + wants[0]
     if protocol_version is None:
@@ -644,7 +677,9 @@ def _handle_upload_pack_head(
     proto.write_pkt_line(wantcmd)
     for want in wants[1:]:
         proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n")
-    if depth not in (0, None) or graph_walker.shallow:
+    if depth not in (0, None) or (
+        hasattr(graph_walker, "shallow") and graph_walker.shallow
+    ):
         if protocol_version == 2:
             if not find_capability(capabilities, CAPABILITY_FETCH, CAPABILITY_SHALLOW):
                 raise GitProtocolError(
@@ -654,8 +689,9 @@ def _handle_upload_pack_head(
             raise GitProtocolError(
                 "server does not support shallow capability required for depth"
             )
-        for sha in graph_walker.shallow:
-            proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
+        if hasattr(graph_walker, "shallow"):
+            for sha in graph_walker.shallow:
+                proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
         if depth is not None:
             proto.write_pkt_line(
                 COMMAND_DEEPEN + b" " + str(depth).encode("ascii") + b"\n"
@@ -668,6 +704,7 @@ def _handle_upload_pack_head(
         proto.write_pkt_line(COMMAND_HAVE + b" " + have + b"\n")
         if can_read is not None and can_read():
             pkt = proto.read_pkt_line()
+            assert pkt is not None
             parts = pkt.rstrip(b"\n").split(b" ")
             if parts[0] == b"ACK":
                 graph_walker.ack(parts[1])
@@ -677,7 +714,7 @@ def _handle_upload_pack_head(
                     break
                 else:
                     raise AssertionError(
-                        f"{parts[2]} not in ('continue', 'ready', 'common)"
+                        f"{parts[2]!r} not in ('continue', 'ready', 'common)"
                     )
         have = next(graph_walker)
     proto.write_pkt_line(COMMAND_DONE + b"\n")
@@ -688,7 +725,8 @@ def _handle_upload_pack_head(
         if can_read is not None:
             (new_shallow, new_unshallow) = _read_shallow_updates(proto.read_pkt_seq())
         else:
-            new_shallow = new_unshallow = None
+            new_shallow = None
+            new_unshallow = None
     else:
         new_shallow = new_unshallow = set()
 
@@ -767,6 +805,7 @@ def _extract_symrefs_and_agent(capabilities):
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k == CAPABILITY_SYMREF:
+            assert v is not None
             (src, dst) = v.split(b":", 1)
             symrefs[src] = dst
         if k == CAPABILITY_AGENT:
@@ -842,9 +881,7 @@ class GitClient:
         self,
         path: str,
         update_refs,
-        generate_pack_data: Callable[
-            [set[bytes], set[bytes], bool], tuple[int, Iterator[UnpackedObject]]
-        ],
+        generate_pack_data,
         progress=None,
     ) -> SendPackResult:
         """Upload a pack to a remote repository.
@@ -935,8 +972,11 @@ class GitClient:
             origin_sha = result.refs.get(b"HEAD")
             if origin is None or (origin_sha and not origin_head):
                 # set detached HEAD
-                target.refs[b"HEAD"] = origin_sha
-                head = origin_sha
+                if origin_sha is not None:
+                    target.refs[b"HEAD"] = origin_sha
+                    head = origin_sha
+                else:
+                    head = None
             else:
                 _set_origin_head(target.refs, origin.encode("utf-8"), origin_head)
                 head_ref = _set_default_branch(
@@ -1166,10 +1206,11 @@ class GitClient:
             if self.protocol_version == 2 and k == CAPABILITY_FETCH:
                 fetch_capa = CAPABILITY_FETCH
                 fetch_features = []
-                v = v.strip().split(b" ")
-                if b"shallow" in v:
+                assert v is not None
+                v_list = v.strip().split(b" ")
+                if b"shallow" in v_list:
                     fetch_features.append(CAPABILITY_SHALLOW)
-                if b"filter" in v:
+                if b"filter" in v_list:
                     fetch_features.append(CAPABILITY_FILTER)
                 for i in range(len(fetch_features)):
                     if i == 0:
@@ -1320,10 +1361,10 @@ class TraditionalGitClient(GitClient):
                 for ref, sha in orig_new_refs.items():
                     if sha == ZERO_SHA:
                         if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
+                            assert report_status_parser is not None
                             report_status_parser._ref_statuses.append(
                                 b"ng " + ref + b" remote does not support deleting refs"
                             )
-                            report_status_parser._ref_status_ok = False
                         del new_refs[ref]
 
             if new_refs is None:
@@ -1730,7 +1771,7 @@ class TCPGitClient(TraditionalGitClient):
         proto.send_cmd(
             b"git-" + cmd, path, b"host=" + self._host.encode("ascii") + version_str
         )
-        return proto, lambda: _fileno_can_read(s), None
+        return proto, lambda: _fileno_can_read(s.fileno()), None
 
 
 class SubprocessWrapper:
@@ -1960,7 +2001,7 @@ class LocalGitClient(GitClient):
                 *generate_pack_data(have, want, ofs_delta=True)
             )
 
-            ref_status = {}
+            ref_status: dict[bytes, Optional[str]] = {}
 
             for refname, new_sha1 in new_refs.items():
                 old_sha1 = old_refs.get(refname, ZERO_SHA)
@@ -2199,7 +2240,7 @@ class BundleClient(GitClient):
 
             while line.startswith(b"-"):
                 (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-                prerequisites.append((obj_id, comment.decode("utf-8")))
+                prerequisites.append((obj_id, comment))
                 line = f.readline()
 
             while line != b"\n":
@@ -2940,7 +2981,11 @@ class AbstractHttpGitClient(GitClient):
         protocol_version: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
     ) -> tuple[
-        dict[Ref, ObjectID], set[bytes], str, dict[Ref, Ref], dict[Ref, ObjectID]
+        dict[Ref, Optional[ObjectID]],
+        set[bytes],
+        str,
+        dict[Ref, Ref],
+        dict[Ref, ObjectID],
     ]:
         if (
             protocol_version is not None
@@ -3003,10 +3048,10 @@ class AbstractHttpGitClient(GitClient):
                     resp, read = self._smart_request(
                         service.decode("ascii"), base_url, body
                     )
-                    proto = Protocol(read, None)
+                    proto = Protocol(read, lambda data: None)
                     return server_capabilities, resp, read, proto
 
-                proto = Protocol(read, None)  # type: ignore
+                proto = Protocol(read, lambda data: None)
                 server_protocol_version = negotiate_protocol_version(proto)
                 if server_protocol_version not in GIT_PROTOCOL_VERSIONS:
                     raise ValueError(
@@ -3071,7 +3116,12 @@ class AbstractHttpGitClient(GitClient):
                     if not chunk:
                         break
                     data += chunk
-                (refs, peeled) = split_peeled_refs(read_info_refs(BytesIO(data)))
+                from typing import Optional, cast
+
+                info_refs = read_info_refs(BytesIO(data))
+                (refs, peeled) = split_peeled_refs(
+                    cast(dict[bytes, Optional[bytes]], info_refs)
+                )
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
                 return refs, set(), base_url, {}, peeled
@@ -3159,7 +3209,7 @@ class AbstractHttpGitClient(GitClient):
 
         resp, read = self._smart_request("git-receive-pack", url, data=body_generator())
         try:
-            resp_proto = Protocol(read, None)
+            resp_proto = Protocol(read, lambda data: None)
             ref_status = self._handle_receive_pack_tail(
                 resp_proto, negotiated_capabilities, progress
             )
@@ -3404,7 +3454,7 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
     def _http_request(self, url, headers=None, data=None, raise_for_status=True):
         import urllib3.exceptions
 
-        req_headers = self.pool_manager.headers.copy()
+        req_headers = dict(self.pool_manager.headers)
         if headers is not None:
             req_headers.update(headers)
         req_headers["Pragma"] = "no-cache"
@@ -3418,10 +3468,10 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
                 request_kwargs["timeout"] = self._timeout
 
             if data is None:
-                resp = self.pool_manager.request("GET", url, **request_kwargs)
+                resp = self.pool_manager.request("GET", url, **request_kwargs)  # type: ignore[arg-type]
             else:
                 request_kwargs["body"] = data
-                resp = self.pool_manager.request("POST", url, **request_kwargs)
+                resp = self.pool_manager.request("POST", url, **request_kwargs)  # type: ignore[arg-type]
         except urllib3.exceptions.HTTPError as e:
             raise GitProtocolError(str(e)) from e
 
@@ -3435,15 +3485,15 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
             if resp.status != 200:
                 raise GitProtocolError(f"unexpected http resp {resp.status} for {url}")
 
-        resp.content_type = resp.headers.get("Content-Type")
+        resp.content_type = resp.headers.get("Content-Type")  # type: ignore[attr-defined]
         # Check if geturl() is available (urllib3 version >= 1.23)
         try:
             resp_url = resp.geturl()
         except AttributeError:
             # get_redirect_location() is available for urllib3 >= 1.1
-            resp.redirect_location = resp.get_redirect_location()
+            resp.redirect_location = resp.get_redirect_location()  # type: ignore[attr-defined]
         else:
-            resp.redirect_location = resp_url if resp_url != url else ""
+            resp.redirect_location = resp_url if resp_url != url else ""  # type: ignore[attr-defined]
         return resp, _wrap_urllib3_exceptions(resp.read)
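
Several result classes in client.py now type __eq__ as taking object and add an isinstance guard before comparing attributes, the signature mypy expects for object.__eq__ overrides. A simplified stand-in (not dulwich's FetchPackResult) showing the pattern:

class FetchResult:
    """Simplified stand-in for a client result object."""

    def __init__(self, refs: dict[bytes, bytes]) -> None:
        self.refs = refs

    def __eq__(self, other: object) -> bool:
        # The old dict-style comparison is still honoured.
        if isinstance(other, dict):
            return self.refs == other
        # Narrowing with isinstance keeps mypy happy about attribute access below.
        if not isinstance(other, FetchResult):
            return False
        return self.refs == other.refs


print(FetchResult({b"HEAD": b"abc"}) == {b"HEAD": b"abc"})               # True
print(FetchResult({b"HEAD": b"abc"}) == FetchResult({b"HEAD": b"abc"}))  # True
print(FetchResult({}) == 42)                                             # False, no AttributeError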
 
 

+ 9 - 3
dulwich/cloud/gcs.py

@@ -52,7 +52,7 @@ class GcsObjectStore(BucketBasedObjectStore):
         """Return string representation of GcsObjectStore."""
         return f"{type(self).__name__}({self.bucket!r}, subpath={self.subpath!r})"
 
-    def _remove_pack(self, name: str) -> None:
+    def _remove_pack_by_name(self, name: str) -> None:
         self.bucket.delete_blobs(
             [posixpath.join(self.subpath, name) + "." + ext for ext in ["pack", "idx"]]
         )
@@ -68,10 +68,14 @@ class GcsObjectStore(BucketBasedObjectStore):
 
     def _load_pack_data(self, name: str) -> PackData:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".pack"))
+        from typing import cast
+
+        from ..file import _GitFile
+
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         b.download_to_file(f)
         f.seek(0)
-        return PackData(name + ".pack", f)
+        return PackData(name + ".pack", cast(_GitFile, f))
 
     def _load_pack_index(self, name: str) -> PackIndex:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".idx"))
@@ -85,7 +89,9 @@ class GcsObjectStore(BucketBasedObjectStore):
             lambda: self._load_pack_data(name), lambda: self._load_pack_index(name)
         )
 
-    def _upload_pack(self, basename: str, pack_file: BinaryIO, index_file: BinaryIO) -> None:
+    def _upload_pack(
+        self, basename: str, pack_file: BinaryIO, index_file: BinaryIO
+    ) -> None:
         idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + ".idx"))
         datablob = self.bucket.blob(posixpath.join(self.subpath, basename + ".pack"))
         idxblob.upload_from_file(index_file)
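
gcs.py (and swift.py further down) wraps temporary file objects in typing.cast(_GitFile, f) because PackData's annotation asks for dulwich's _GitFile even though any seekable binary file works at runtime. A generic sketch of that situation; LockedFile and load_pack are made-up names standing in for _GitFile and PackData:

import tempfile
from typing import cast


class LockedFile:
    """Stand-in for a narrowly annotated file wrapper (like dulwich's _GitFile)."""

    def read(self, size: int = -1) -> bytes:
        return b""


def load_pack(f: LockedFile) -> bytes:
    """Stand-in for an API whose annotation demands the wrapper type."""
    return f.read()


spooled = tempfile.SpooledTemporaryFile()
spooled.write(b"PACK")
spooled.seek(0)

# cast() bridges the annotation mismatch; at runtime the spooled file only
# needs to provide read(), which it does.
print(load_pack(cast(LockedFile, spooled)))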

+ 1 - 3
dulwich/commit_graph.py

@@ -567,9 +567,7 @@ def write_commit_graph(
 
     graph_path = os.path.join(info_dir, b"commit-graph")
     with GitFile(graph_path, "wb") as f:
-        from typing import BinaryIO, cast
-
-        graph.write_to_file(cast(BinaryIO, f))
+        graph.write_to_file(f)
 
 
 def get_reachable_commits(

+ 39 - 12
dulwich/config.py

@@ -41,14 +41,11 @@ from contextlib import suppress
 from pathlib import Path
 from typing import (
     IO,
-    Any,
-    BinaryIO,
     Callable,
     Generic,
     Optional,
     TypeVar,
     Union,
-    cast,
     overload,
 )
 
@@ -60,7 +57,7 @@ ConfigValue = Union[str, bytes, bool, int]
 logger = logging.getLogger(__name__)
 
 # Type for file opener callback
-FileOpener = Callable[[Union[str, os.PathLike]], BinaryIO]
+FileOpener = Callable[[Union[str, os.PathLike]], IO[bytes]]
 
 # Type for includeIf condition matcher
 # Takes the condition value (e.g., "main" for onbranch:main) and returns bool
@@ -194,7 +191,7 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
           default_factory: Optional factory function for default values
         """
         self._real: list[tuple[K, V]] = []
-        self._keyed: dict[Any, V] = {}
+        self._keyed: dict[ConfigKey, V] = {}
         self._default_factory = default_factory
 
     @classmethod
@@ -239,7 +236,31 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
     def keys(self) -> KeysView[K]:
         """Return a view of the dictionary's keys."""
-        return self._keyed.keys()  # type: ignore[return-value]
+        # Return a view of the original keys (not lowercased)
+        # We need to deduplicate since _real can have duplicates
+        seen = set()
+        unique_keys = []
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                unique_keys.append(k)
+        from collections.abc import KeysView as ABCKeysView
+
+        class UniqueKeysView(ABCKeysView[K]):
+            def __init__(self, keys: list[K]):
+                self._keys = keys
+
+            def __contains__(self, key: object) -> bool:
+                return key in self._keys
+
+            def __iter__(self):
+                return iter(self._keys)
+
+            def __len__(self) -> int:
+                return len(self._keys)
+
+        return UniqueKeysView(unique_keys)
 
     def items(self) -> ItemsView[K, V]:
         """Return a view of the dictionary's (key, value) pairs in insertion order."""
@@ -267,7 +288,13 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
     def __iter__(self) -> Iterator[K]:
         """Iterate over the dictionary's keys."""
-        return iter(self._keyed)
+        # Return iterator over original keys (not lowercased), deduplicated
+        seen = set()
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                yield k
 
     def values(self) -> ValuesView[V]:
         """Return a view of the dictionary's values."""
@@ -898,7 +925,7 @@ class ConfigFile(ConfigDict):
     @classmethod
     def from_file(
         cls,
-        f: BinaryIO,
+        f: IO[bytes],
         *,
         config_dir: Optional[str] = None,
         included_paths: Optional[set[str]] = None,
@@ -1075,8 +1102,8 @@ class ConfigFile(ConfigDict):
             opener: FileOpener
             if file_opener is None:
 
-                def opener(path: Union[str, os.PathLike]) -> BinaryIO:
-                    return cast(BinaryIO, GitFile(path, "rb"))
+                def opener(path: Union[str, os.PathLike]) -> IO[bytes]:
+                    return GitFile(path, "rb")
             else:
                 opener = file_opener
 
@@ -1236,8 +1263,8 @@ class ConfigFile(ConfigDict):
         opener: FileOpener
         if file_opener is None:
 
-            def opener(p: Union[str, os.PathLike]) -> BinaryIO:
-                return cast(BinaryIO, GitFile(p, "rb"))
+            def opener(p: Union[str, os.PathLike]) -> IO[bytes]:
+                return GitFile(p, "rb")
         else:
             opener = file_opener
 

+ 12 - 6
dulwich/contrib/swift.py

@@ -43,6 +43,7 @@ from typing import BinaryIO, Callable, Optional, Union, cast
 
 from geventhttpclient import HTTPClient
 
+from ..file import _GitFile
 from ..greenthreads import GreenThreadsMissingObjectFinder
 from ..lru_cache import LRUSizeCache
 from ..object_store import INFODIR, PACKDIR, ObjectContainer, PackBasedObjectStore
@@ -823,7 +824,11 @@ class SwiftObjectStore(PackBasedObjectStore):
               The created SwiftPack or None if empty
             """
             f.seek(0)
-            pack = PackData(file=f, filename="")
+            from typing import cast
+
+            from ..file import _GitFile
+
+            pack = PackData(file=cast(_GitFile, f), filename="")
             entries = pack.sorted_entries()
             if entries:
                 basename = posixpath.join(
@@ -875,7 +880,8 @@ class SwiftObjectStore(PackBasedObjectStore):
         fd, path = tempfile.mkstemp(prefix="tmp_pack_")
         f = os.fdopen(fd, "w+b")
         try:
-            indexer = PackIndexer(f, resolve_ext_ref=None)
+            pack_data = PackData(file=cast(_GitFile, f), filename=path)
+            indexer = PackIndexer(pack_data, resolve_ext_ref=None)
             copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
@@ -928,7 +934,7 @@ class SwiftObjectStore(PackBasedObjectStore):
 
         # Write pack info.
         f.seek(0)
-        pack_data = PackData(filename="", file=f)
+        pack_data = PackData(filename="", file=cast(_GitFile, f))
         index_file.seek(0)
         pack_index = load_pack_index_file("", index_file)
         serialized_pack_info = pack_info_create(pack_data, pack_index)
@@ -1030,17 +1036,17 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         del self._refs[name]
         return True
 
-    def allkeys(self) -> Iterator[bytes]:
+    def allkeys(self) -> set[bytes]:
         """Get all reference names.
 
         Returns:
-          Iterator of reference names as bytes
+          Set of reference names as bytes
         """
         try:
             self._refs[b"HEAD"] = self._refs[b"refs/heads/master"]
         except KeyError:
             pass
-        return iter(self._refs.keys())
+        return set(self._refs.keys())
 
 
 class SwiftRepo(BaseRepo):

+ 12 - 6
dulwich/diff.py

@@ -47,11 +47,11 @@ Example usage:
 import logging
 import os
 import stat
-from typing import BinaryIO, Optional, cast
+from typing import BinaryIO, Optional
 
 from .index import ConflictedIndexEntry, commit_index
 from .object_store import iter_tree_contents
-from .objects import S_ISGITLINK, Blob
+from .objects import S_ISGITLINK, Blob, Commit
 from .patch import write_blob_diff, write_object_diff
 from .repo import Repo
 
@@ -90,12 +90,16 @@ def diff_index_to_tree(
     if commit_sha is None:
         try:
             commit_sha = repo.refs[b"HEAD"]
-            old_tree = repo[commit_sha].tree
+            old_commit = repo[commit_sha]
+            assert isinstance(old_commit, Commit)
+            old_tree = old_commit.tree
         except KeyError:
             # No HEAD means no commits yet
             old_tree = None
     else:
-        old_tree = repo[commit_sha].tree
+        old_commit = repo[commit_sha]
+        assert isinstance(old_commit, Commit)
+        old_tree = old_commit.tree
 
     # Get tree from index
     index = repo.open_index()
@@ -125,7 +129,9 @@ def diff_working_tree_to_tree(
         commit_sha: SHA of commit to compare against
         paths: Optional list of paths to filter (as bytes)
     """
-    tree = repo[commit_sha].tree
+    commit = repo[commit_sha]
+    assert isinstance(commit, Commit)
+    tree = commit.tree
     normalizer = repo.get_blob_normalizer()
     filter_callback = normalizer.checkin_normalize
 
@@ -382,7 +388,7 @@ def diff_working_tree_to_index(
         old_obj = repo.object_store[old_sha]
         # Type check and cast to Blob
         if isinstance(old_obj, Blob):
-            old_blob = cast(Blob, old_obj)
+            old_blob = old_obj
         else:
             old_blob = None
 

+ 4 - 4
dulwich/dumb.py

@@ -24,7 +24,7 @@
 import os
 import tempfile
 import zlib
-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
 from io import BytesIO
 from typing import Any, Callable, Optional
 from urllib.parse import urljoin
@@ -338,9 +338,9 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
     def add_objects(
         self,
-        objects: Iterator[ShaFile],
-        progress: Optional[Callable[[int], None]] = None,
-    ) -> None:
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable[[str], None]] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store."""
         raise NotImplementedError("Cannot add objects to dumb HTTP repository")
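
The add_objects() override in dumb.py is reshaped to match the base object-store signature, which is what mypy's override (Liskov) check enforces even for a method that only raises. A toy illustration under assumed, simplified types (BaseStore and ReadOnlyStore are hypothetical):

from collections.abc import Sequence
from typing import Optional


class BaseStore:
    def add_objects(self, objects: Sequence[tuple[bytes, Optional[str]]]) -> None:
        raise NotImplementedError


class ReadOnlyStore(BaseStore):
    # The override repeats the base parameter types; narrowing one of them
    # (e.g. to Iterator[bytes]) is what triggers mypy's override error.
    def add_objects(self, objects: Sequence[tuple[bytes, Optional[str]]]) -> None:
        raise NotImplementedError("objects cannot be added to this store")


store: BaseStore = ReadOnlyStore()  # instantiating is fine; calling add_objects raises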
 

+ 53 - 9
dulwich/file.py

@@ -24,10 +24,15 @@
 import os
 import sys
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from types import TracebackType
 from typing import IO, Any, ClassVar, Literal, Optional, Union, overload
 
+if sys.version_info >= (3, 12):
+    from collections.abc import Buffer
+else:
+    Buffer = Union[bytes, bytearray, memoryview]
+
 
 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
     """Ensure a directory exists, creating if necessary."""
@@ -136,7 +141,7 @@ class FileLocked(Exception):
         super().__init__(filename, lockfilename)
 
 
-class _GitFile:
+class _GitFile(IO[bytes]):
     """File that follows the git locking protocol for writes.
 
     All writes to a file foo will be written into foo.lock in the same
@@ -148,7 +153,6 @@ class _GitFile:
     """
 
     PROXY_PROPERTIES: ClassVar[set[str]] = {
-        "closed",
         "encoding",
         "errors",
         "mode",
@@ -158,15 +162,19 @@ class _GitFile:
     }
     PROXY_METHODS: ClassVar[set[str]] = {
         "__iter__",
+        "__next__",
         "flush",
         "fileno",
         "isatty",
         "read",
+        "readable",
         "readline",
         "readlines",
         "seek",
+        "seekable",
         "tell",
         "truncate",
+        "writable",
         "write",
         "writelines",
     }
@@ -195,9 +203,6 @@ class _GitFile:
         self._file = os.fdopen(fd, mode, bufsize)
         self._closed = False
 
-        for method in self.PROXY_METHODS:
-            setattr(self, method, getattr(self._file, method))
-
     def __iter__(self) -> Iterator[bytes]:
         """Iterate over lines in the file."""
         return iter(self._file)
@@ -267,20 +272,59 @@ class _GitFile:
         else:
             self.close()
 
+    @property
+    def closed(self) -> bool:
+        """Return whether the file is closed."""
+        return self._closed
+
     def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)
         raise AttributeError(name)
 
+    # Implement IO[bytes] methods by delegating to the underlying file
+    def read(self, size: int = -1) -> bytes:
+        return self._file.read(size)
+
+    def write(self, data: Buffer, /) -> int:
+        return self._file.write(data)
+
+    def readline(self, size: int = -1) -> bytes:
+        return self._file.readline(size)
+
+    def readlines(self, hint: int = -1) -> list[bytes]:
+        return self._file.readlines(hint)
+
+    def writelines(self, lines: Iterable[Buffer], /) -> None:
+        return self._file.writelines(lines)
+
+    def seek(self, offset: int, whence: int = 0) -> int:
+        return self._file.seek(offset, whence)
+
+    def tell(self) -> int:
+        return self._file.tell()
+
+    def flush(self) -> None:
+        return self._file.flush()
+
+    def truncate(self, size: Optional[int] = None) -> int:
+        return self._file.truncate(size)
+
+    def fileno(self) -> int:
+        return self._file.fileno()
+
+    def isatty(self) -> bool:
+        return self._file.isatty()
+
     def readable(self) -> bool:
-        """Return whether the file is readable."""
         return self._file.readable()
 
     def writable(self) -> bool:
-        """Return whether the file is writable."""
         return self._file.writable()
 
     def seekable(self) -> bool:
-        """Return whether the file is seekable."""
         return self._file.seekable()
+
+    def __next__(self) -> bytes:
+        return next(iter(self._file))
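
_GitFile now subclasses IO[bytes] and spells out each delegated method instead of copying bound methods onto the instance with setattr() in __init__, so mypy can see real signatures. A stripped-down sketch of the two delegation styles using toy wrappers (not the real class):

import io
from typing import IO


class ExplicitWrapper:
    """Forwards calls through real methods, so mypy sees concrete signatures."""

    def __init__(self, inner: IO[bytes]) -> None:
        self._inner = inner

    def read(self, size: int = -1) -> bytes:
        return self._inner.read(size)

    def write(self, data: bytes) -> int:
        return self._inner.write(data)

    def seek(self, offset: int, whence: int = 0) -> int:
        return self._inner.seek(offset, whence)


class DynamicWrapper:
    """Copies bound methods onto the instance; mypy cannot see any of them."""

    def __init__(self, inner: IO[bytes]) -> None:
        for name in ("read", "write", "seek"):
            setattr(self, name, getattr(inner, name))


w = ExplicitWrapper(io.BytesIO())
w.write(b"hello")
w.seek(0)
print(w.read())  # b'hello'

d = DynamicWrapper(io.BytesIO(b"hi"))
print(d.read())  # works at runtime; mypy reports the attribute as missing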

+ 15 - 7
dulwich/filters.py

@@ -30,7 +30,7 @@ from .objects import Blob
 
 if TYPE_CHECKING:
     from .config import StackedConfig
-    from .repo import Repo
+    from .repo import BaseRepo
 
 
 class FilterError(Exception):
@@ -128,7 +128,9 @@ class FilterRegistry:
     """Registry for filter drivers."""
 
     def __init__(
-        self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None
+        self,
+        config: Optional["StackedConfig"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
         """Initialize FilterRegistry.
 
@@ -211,8 +213,12 @@ class FilterRegistry:
         required = self.config.get_boolean(("filter", name), "required", False)
 
         if clean_cmd or smudge_cmd:
-            # Get repository working directory
-            repo_path = self.repo.path if self.repo else None
+            # Get repository working directory (only for Repo, not BaseRepo)
+            from .repo import Repo
+
+            repo_path = (
+                self.repo.path if self.repo and isinstance(self.repo, Repo) else None
+            )
             return ProcessFilterDriver(clean_cmd, smudge_cmd, required, repo_path)
 
         return None
@@ -221,8 +227,10 @@ class FilterRegistry:
         """Create LFS filter driver."""
         from .lfs import LFSFilterDriver, LFSStore
 
-        # If we have a repo, use its LFS store
-        if registry.repo is not None:
+        # If we have a Repo (not just BaseRepo), use its LFS store
+        from .repo import Repo
+
+        if registry.repo is not None and isinstance(registry.repo, Repo):
             lfs_store = LFSStore.from_repo(registry.repo, create=True)
         else:
             # Fall back to creating a temporary LFS store
@@ -389,7 +397,7 @@ class FilterBlobNormalizer:
         config_stack: Optional["StackedConfig"],
         gitattributes: GitAttributes,
         filter_registry: Optional[FilterRegistry] = None,
-        repo: Optional["Repo"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
         """Initialize FilterBlobNormalizer.
 

+ 24 - 22
dulwich/index.py

@@ -32,13 +32,13 @@ from collections.abc import Generator, Iterable, Iterator
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,
     Optional,
     Union,
-    cast,
 )
 
 if TYPE_CHECKING:
@@ -588,7 +588,7 @@ def read_cache_time(f: BinaryIO) -> tuple[int, int]:
     return struct.unpack(">LL", f.read(8))
 
 
-def write_cache_time(f: BinaryIO, t: Union[int, float, tuple[int, int]]) -> None:
+def write_cache_time(f: IO[bytes], t: Union[int, float, tuple[int, int]]) -> None:
     """Write a cache time.
 
     Args:
@@ -664,7 +664,7 @@ def read_cache_entry(
 
 
 def write_cache_entry(
-    f: BinaryIO, entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
+    f: IO[bytes], entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
 ) -> None:
     """Write an index entry to a file.
 
@@ -745,7 +745,7 @@ def read_index_header(f: BinaryIO) -> tuple[int, int]:
     return version, num_entries
 
 
-def write_index_extension(f: BinaryIO, extension: IndexExtension) -> None:
+def write_index_extension(f: IO[bytes], extension: IndexExtension) -> None:
     """Write an index extension.
 
     Args:
@@ -868,7 +868,7 @@ def read_index_dict(
 
 
 def write_index(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: list[SerializedIndexEntry],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -909,7 +909,7 @@ def write_index(
 
 
 def write_index_dict(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: dict[bytes, Union[IndexEntry, ConflictedIndexEntry]],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -1007,8 +1007,6 @@ class Index:
 
     def write(self) -> None:
         """Write current contents of index to disk."""
-        from typing import BinaryIO, cast
-
         f = GitFile(self._filename, "wb")
         try:
             # Filter out extensions with no meaningful data
@@ -1022,7 +1020,7 @@ class Index:
             if self._skip_hash:
                 # When skipHash is enabled, write the index without computing SHA1
                 write_index_dict(
-                    cast(BinaryIO, f),
+                    f,
                     self._byname,
                     version=self._version,
                     extensions=meaningful_extensions,
@@ -1031,9 +1029,9 @@ class Index:
                 f.write(b"\x00" * 20)
                 f.close()
             else:
-                sha1_writer = SHA1Writer(cast(BinaryIO, f))
+                sha1_writer = SHA1Writer(f)
                 write_index_dict(
-                    cast(BinaryIO, sha1_writer),
+                    sha1_writer,
                     self._byname,
                     version=self._version,
                     extensions=meaningful_extensions,
@@ -1050,9 +1048,7 @@ class Index:
         f = GitFile(self._filename, "rb")
         try:
             sha1_reader = SHA1Reader(f)
-            entries, version, extensions = read_index_dict_with_version(
-                cast(BinaryIO, sha1_reader)
-            )
+            entries, version, extensions = read_index_dict_with_version(sha1_reader)
             self._version = version
             self._extensions = extensions
             self.update(entries)
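
The cast(BinaryIO, ...) wrappers above become unnecessary because the write/read helpers now take the structural IO[bytes] type, which GitFile, SHA1Writer and SHA1Reader already satisfy. A small sketch of the idea with a plain function (not dulwich's actual writer):

    import io
    from typing import IO


    def write_signature(f: IO[bytes], num_entries: int) -> None:
        # Anything exposing the binary IO methods is accepted; no cast needed.
        f.write(b"DIRC")
        f.write(num_entries.to_bytes(4, "big"))


    buf = io.BytesIO()
    write_signature(buf, 3)
    assert buf.getvalue() == b"DIRC\x00\x00\x00\x03"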
@@ -1411,7 +1407,9 @@ def build_file_from_blob(
     *,
     honor_filemode: bool = True,
     tree_encoding: str = "utf-8",
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
 ) -> os.stat_result:
     """Build a file or symlink on disk based on a Git object.
 
@@ -1596,7 +1594,9 @@ def build_index_from_tree(
     tree_id: bytes,
     honor_filemode: bool = True,
     validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -1956,7 +1956,9 @@ def _transition_to_file(
     entry: IndexEntry,
     index: Index,
     honor_filemode: bool,
-    symlink_fn: Optional[Callable[[bytes, bytes], None]],
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ],
     blob_normalizer: Optional["BlobNormalizer"],
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -2208,7 +2210,9 @@ def update_working_tree(
     change_iterator: Iterator["TreeChange"],
     honor_filemode: bool = True,
     validate_path_element: Optional[Callable[[bytes], bool]] = None,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     force_remove_untracked: bool = False,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
@@ -2685,10 +2689,8 @@ class locked_index:
             self._file.abort()
             return
         try:
-            from typing import BinaryIO, cast
-
-            f = SHA1Writer(cast(BinaryIO, self._file))
-            write_index_dict(cast(BinaryIO, f), self._index._byname)
+            f = SHA1Writer(self._file)
+            write_index_dict(f, self._index._byname)
         except BaseException:
             self._file.abort()
         else:

+ 110 - 53
dulwich/object_store.py

@@ -33,11 +33,11 @@ from collections.abc import Iterable, Iterator, Sequence
 from contextlib import suppress
 from io import BytesIO
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Optional,
     Protocol,
     Union,
-    cast,
 )
 
 from .errors import NotTreeError
@@ -82,6 +82,18 @@ from .pack import (
 from .protocol import DEPTH_INFINITE
 from .refs import PEELED_TAG_SUFFIX, Ref
 
+if TYPE_CHECKING:
+    from .commit_graph import CommitGraph
+    from .diff_tree import RenameDetector
+
+
+class GraphWalker(Protocol):
+    """Protocol for graph walker objects."""
+
+    def __next__(self) -> Optional[bytes]: ...
+    def ack(self, sha: bytes) -> None: ...
+
+
 INFODIR = "info"
 PACKDIR = "pack"
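
The new GraphWalker protocol only asks for __next__ and ack, so any object providing those two methods type-checks where a graph walker is expected (for example in find_common_revisions below). A toy implementation, purely for illustration:

    from typing import Optional


    class ListGraphWalker:
        """Yields candidate commit SHAs from a list and records acknowledgements."""

        def __init__(self, shas: list[bytes]) -> None:
            self._shas = iter(shas)
            self.acked: list[bytes] = []

        def __next__(self) -> Optional[bytes]:
            # Returning None signals exhaustion, matching the protocol above.
            return next(self._shas, None)

        def ack(self, sha: bytes) -> None:
            self.acked.append(sha)


    walker = ListGraphWalker([b"a" * 40])
    assert next(walker) == b"a" * 40
    walker.ack(b"a" * 40)
    assert next(walker) is None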
 
@@ -95,7 +107,9 @@ PACK_MODE = 0o444 if sys.platform != "win32" else 0o644
 DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks
 
 
-def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
+def find_shallow(
+    store: ObjectContainer, heads: Iterable[bytes], depth: int
+) -> tuple[set[bytes], set[bytes]]:
     """Find shallow commits according to a given depth.
 
     Args:
@@ -107,7 +121,7 @@ def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
     """
-    parents = {}
+    parents: dict[bytes, list[bytes]] = {}
     commit_graph = store.get_commit_graph()
 
     def get_parents(sha: bytes) -> list[bytes]:
@@ -121,7 +135,9 @@ def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
                     parents[sha] = result
                     return result
             # Fall back to loading the object
-            result = store[sha].parents
+            commit = store[sha]
+            assert isinstance(commit, Commit)
+            result = commit.parents
             parents[sha] = result
         return result
 
@@ -150,7 +166,7 @@ def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
 
 
 def get_depth(
-    store: 'BaseObjectStore',
+    store: ObjectContainer,
     head: bytes,
     get_parents: Callable = lambda commit: commit.parents,
     max_depth: Optional[int] = None,
@@ -233,7 +249,7 @@ class BaseObjectStore:
         return self.contains_loose(sha1)
 
     @property
-    def packs(self) -> Any:
+    def packs(self) -> list[Pack]:
         """Iterable of pack objects."""
         raise NotImplementedError
 
@@ -251,15 +267,19 @@ class BaseObjectStore:
         type_num, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
 
-    def __iter__(self) -> Any:
+    def __iter__(self) -> Iterator[bytes]:
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)
 
-    def add_object(self, obj: Any) -> None:
+    def add_object(self, obj: ShaFile) -> None:
         """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)
 
-    def add_objects(self, objects: Any, progress: Optional[Callable] = None) -> None:
+    def add_objects(
+        self,
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
 
         Args:
@@ -275,9 +295,15 @@ class BaseObjectStore:
         want_unchanged: bool = False,
         include_trees: bool = False,
         change_type_same: bool = False,
-        rename_detector: Optional[Any] = None,
-        paths: Optional[Any] = None,
-    ) -> Any:
+        rename_detector: Optional["RenameDetector"] = None,
+        paths: Optional[list[bytes]] = None,
+    ) -> Iterator[
+        tuple[
+            tuple[Optional[bytes], Optional[bytes]],
+            tuple[Optional[int], Optional[int]],
+            tuple[Optional[bytes], Optional[bytes]],
+        ]
+    ]:
         """Find the differences between the contents of two trees.
 
         Args:
@@ -310,7 +336,9 @@ class BaseObjectStore:
                 (change.old.sha, change.new.sha),
             )
 
-    def iter_tree_contents(self, tree_id: bytes, include_trees: bool = False) -> Any:
+    def iter_tree_contents(
+        self, tree_id: bytes, include_trees: bool = False
+    ) -> Iterator[tuple[bytes, int, bytes]]:
         """Iterate the contents of a tree and all subtrees.
 
         Iteration is depth-first pre-order, as in e.g. os.walk.
@@ -352,13 +380,13 @@ class BaseObjectStore:
 
     def find_missing_objects(
         self,
-        haves: Any,
-        wants: Any,
-        shallow: Optional[Any] = None,
+        haves: Iterable[bytes],
+        wants: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
         progress: Optional[Callable] = None,
         get_tagged: Optional[Callable] = None,
         get_parents: Callable = lambda commit: commit.parents,
-    ) -> Any:
+    ) -> Iterator[tuple[bytes, Optional[bytes]]]:
         """Find the missing objects required for a set of revisions.
 
         Args:
@@ -385,7 +413,7 @@ class BaseObjectStore:
         )
         return iter(finder)
 
-    def find_common_revisions(self, graphwalker: Any) -> list[bytes]:
+    def find_common_revisions(self, graphwalker: GraphWalker) -> list[bytes]:
         """Find which revisions this store has in common using graphwalker.
 
         Args:
@@ -402,7 +430,12 @@ class BaseObjectStore:
         return haves
 
     def generate_pack_data(
-        self, have: Any, want: Any, shallow: Optional[Any] = None, progress: Optional[Callable] = None, ofs_delta: bool = True
+        self,
+        have: Iterable[bytes],
+        want: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
+        progress: Optional[Callable] = None,
+        ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
         """Generate pack data objects for a set of wants/haves.
 
@@ -439,7 +472,7 @@ class BaseObjectStore:
             DeprecationWarning,
             stacklevel=2,
         )
-        return peel_sha(self, sha)[1]
+        return peel_sha(self, sha)[1].id
 
     def _get_depth(
         self,
@@ -486,7 +519,7 @@ class BaseObjectStore:
             if sha.startswith(prefix):
                 yield sha
 
-    def get_commit_graph(self) -> Optional[Any]:
+    def get_commit_graph(self) -> Optional["CommitGraph"]:
         """Get the commit graph for this object store.
 
         Returns:
@@ -494,7 +527,9 @@ class BaseObjectStore:
         """
         return None
 
-    def write_commit_graph(self, refs: Optional[Any] = None, reachable: bool = True) -> None:
+    def write_commit_graph(
+        self, refs: Optional[list[bytes]] = None, reachable: bool = True
+    ) -> None:
         """Write a commit graph file for this object store.
 
         Args:
@@ -571,8 +606,11 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         raise NotImplementedError(self.add_pack)
 
     def add_pack_data(
-        self, count: int, unpacked_objects: Iterator[UnpackedObject], progress: Optional[Callable] = None
-    ) -> None:
+        self,
+        count: int,
+        unpacked_objects: Iterator[UnpackedObject],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add pack data to this object store.
 
         Args:
@@ -582,7 +620,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         if count == 0:
             # Don't bother writing an empty pack file
-            return
+            return None
         f, commit, abort = self.add_pack()
         try:
             write_pack_data(
@@ -627,7 +665,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 return True
         return False
 
-    def _add_cached_pack(self, base_name: str, pack: Any) -> None:
+    def _add_cached_pack(self, base_name: str, pack: Pack) -> None:
         """Add a newly appeared pack to the cache by path."""
         prev_pack = self._pack_cache.get(base_name)
         if prev_pack is not pack:
@@ -653,7 +691,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         remote_has = missing_objects.get_remote_has()
         object_ids = list(missing_objects)
         return len(object_ids), generate_unpacked_objects(
-            cast(PackedObjectContainer, self),
+            self,
             object_ids,
             progress=progress,
             ofs_delta=ofs_delta,
@@ -667,8 +705,8 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             (name, pack) = pack_cache.popitem()
             pack.close()
 
-    def _iter_cached_packs(self) -> Any:
-        return self._pack_cache.values()
+    def _iter_cached_packs(self) -> Iterator[Pack]:
+        return iter(self._pack_cache.values())
 
     def _update_pack_cache(self) -> list[Pack]:
         raise NotImplementedError(self._update_pack_cache)
@@ -681,7 +719,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self._clear_cached_packs()
 
     @property
-    def packs(self) -> Any:
+    def packs(self) -> list[Pack]:
         """List with pack objects."""
         return list(self._iter_cached_packs()) + list(self._update_pack_cache())
 
@@ -699,12 +737,12 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 count += 1
         return count
 
-    def _iter_alternate_objects(self) -> Any:
+    def _iter_alternate_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all the objects in alternate stores."""
         for alternate in self.alternates:
             yield from alternate
 
-    def _iter_loose_objects(self) -> Any:
+    def _iter_loose_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all loose objects."""
         raise NotImplementedError(self._iter_loose_objects)
 
@@ -719,7 +757,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         raise NotImplementedError(self.delete_loose_object)
 
-    def _remove_pack(self, name: str) -> None:
+    def _remove_pack(self, pack: "Pack") -> None:
         raise NotImplementedError(self._remove_pack)
 
     def pack_loose_objects(self) -> int:
@@ -727,15 +765,17 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
 
         Returns: Number of objects packed
         """
-        objects = set()
+        objects: list[tuple[ShaFile, None]] = []
         for sha in self._iter_loose_objects():
-            objects.add((self._get_loose_object(sha), None))
-        self.add_objects(list(objects))
+            obj = self._get_loose_object(sha)
+            if obj is not None:
+                objects.append((obj, None))
+        self.add_objects(objects)
         for obj, path in objects:
             self.delete_loose_object(obj.id)
         return len(objects)
 
-    def repack(self, exclude: Optional[set] = None) -> None:
+    def repack(self, exclude: Optional[set] = None) -> int:
         """Repack the packs in this repository.
 
         Note that this implementation is fairly naive and currently keeps all
@@ -751,11 +791,13 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         excluded_loose_objects = set()
         for sha in self._iter_loose_objects():
             if sha not in exclude:
-                loose_objects.add(self._get_loose_object(sha))
+                obj = self._get_loose_object(sha)
+                if obj is not None:
+                    loose_objects.add(obj)
             else:
                 excluded_loose_objects.add(sha)
 
-        objects = {(obj, None) for obj in loose_objects}
+        objects: set[tuple[ShaFile, None]] = {(obj, None) for obj in loose_objects}
         old_packs = {p.name(): p for p in self.packs}
         for name, pack in old_packs.items():
             objects.update(
@@ -767,12 +809,14 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             # The name of the consolidated pack might match the name of a
             # pre-existing pack. Take care not to remove the newly created
             # consolidated pack.
-            consolidated = self.add_objects(objects)
-            old_packs.pop(consolidated.name(), None)
+            consolidated = self.add_objects(list(objects))
+            if consolidated is not None:
+                old_packs.pop(consolidated.name(), None)
 
         # Delete loose objects that were packed
         for obj in loose_objects:
-            self.delete_loose_object(obj.id)
+            if obj is not None:
+                self.delete_loose_object(obj.id)
         # Delete excluded loose objects
         for sha in excluded_loose_objects:
             self.delete_loose_object(sha)
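
add_objects and add_pack_data now advertise an Optional["Pack"] return, which is what lets repack() above drop the newly consolidated pack from its deletion set only when one was actually written. A hedged usage sketch (MemoryObjectStore keeps loose objects only, so None is the expected result here):

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob

    store = MemoryObjectStore()
    pack = store.add_objects([(Blob.from_string(b"hello"), None)])
    if pack is not None:
        # A pack-based store would report the consolidated pack here.
        print("wrote", pack.name())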
@@ -928,9 +972,9 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 yield o
                 todo.remove(o.id)
         for oid in todo:
-            o = self._get_loose_object(oid)
-            if o is not None:
-                yield o
+            loose_obj: Optional[ShaFile] = self._get_loose_object(oid)
+            if loose_obj is not None:
+                yield loose_obj
             elif not allow_missing:
                 raise KeyError(oid)
 
@@ -978,7 +1022,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self,
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         progress: Optional[Callable[[str], None]] = None,
-    ) -> None:
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
 
         Args:
@@ -997,6 +1041,8 @@ class DiskObjectStore(PackBasedObjectStore):
 
     path: Union[str, os.PathLike]
     pack_dir: Union[str, os.PathLike]
+    _alternates: Optional[list["DiskObjectStore"]]
+    _commit_graph: Optional["CommitGraph"]
 
     def __init__(
         self,
@@ -1229,7 +1275,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
     def _get_shafile_path(self, sha):
         # Check from object dir
-        return hex_to_filename(self.path, sha)
+        return hex_to_filename(os.fspath(self.path), sha)
 
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
@@ -1330,9 +1376,9 @@ class DiskObjectStore(PackBasedObjectStore):
         os.remove(pack.index.path)
 
     def _get_pack_basepath(self, entries):
-        suffix = iter_sha1(entry[0] for entry in entries)
+        suffix_bytes = iter_sha1(entry[0] for entry in entries)
         # TODO: Handle self.pack_dir being bytes
-        suffix = suffix.decode("ascii")
+        suffix = suffix_bytes.decode("ascii")
         return os.path.join(self.pack_dir, "pack-" + suffix)
 
     def _complete_pack(self, f, path, num_objects, indexer, progress=None):
@@ -1451,6 +1497,7 @@ class DiskObjectStore(PackBasedObjectStore):
         def commit():
             if f.tell() > 0:
                 f.seek(0)
+
                 with PackData(path, f) as pd:
                     indexer = PackIndexer.for_pack_data(
                         pd, resolve_ext_ref=self.get_raw
@@ -1783,6 +1830,7 @@ class MemoryObjectStore(BaseObjectStore):
             size = f.tell()
             if size > 0:
                 f.seek(0)
+
                 p = PackData.from_file(f, size)
                 for obj in PackInflater.for_pack_data(p, self.get_raw):
                     self.add_object(obj)
@@ -2119,7 +2167,7 @@ class ObjectStoreGraphWalker:
     heads: set[ObjectID]
     """Revisions without descendants in the local repo."""
 
-    get_parents: Callable[[ObjectID], ObjectID]
+    get_parents: Callable[[ObjectID], list[ObjectID]]
     """Function to retrieve parents in the local repo."""
 
     shallow: set[ObjectID]
@@ -2215,7 +2263,7 @@ def commit_tree_changes(object_store, tree, changes):
     """
     # TODO(jelmer): Save up the objects and add them using .add_objects
     # rather than with individual calls to .add_object.
-    nested_changes = {}
+    nested_changes: dict[bytes, list[tuple[bytes, Optional[int], Optional[bytes]]]] = {}
     for path, new_mode, new_sha in changes:
         try:
             (dirname, subpath) = path.split(b"/", 1)
@@ -2450,8 +2498,16 @@ class BucketBasedObjectStore(PackBasedObjectStore):
         """
         # Doesn't exist..
 
-    def _remove_pack(self, name) -> None:
-        raise NotImplementedError(self._remove_pack)
+    def pack_loose_objects(self) -> int:
+        """Pack loose objects. Returns number of objects packed.
+
+        BucketBasedObjectStore doesn't support loose objects, so this is a no-op.
+        """
+        return 0
+
+    def _remove_pack_by_name(self, name: str) -> None:
+        """Remove a pack by name. Subclasses should implement this."""
+        raise NotImplementedError(self._remove_pack_by_name)
 
     def _iter_pack_names(self) -> Iterator[str]:
         raise NotImplementedError(self._iter_pack_names)
@@ -2496,6 +2552,7 @@ class BucketBasedObjectStore(PackBasedObjectStore):
                 return None
 
             pf.seek(0)
+
             p = PackData(pf.name, pf)
             entries = p.sorted_entries()
             basename = iter_sha1(entry[0] for entry in entries).decode("ascii")

+ 2 - 1
dulwich/objects.py

@@ -430,7 +430,8 @@ class ShaFile:
             self._sha = None
             self._chunked_text = self._serialize()
             self._needs_serialization = False
-        return self._chunked_text  # type: ignore
+        assert self._chunked_text is not None
+        return self._chunked_text
 
     def as_raw_string(self) -> bytes:
         """Return raw string with serialization of the object.

+ 30 - 15
dulwich/objectspec.py

@@ -24,6 +24,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 
 from .objects import Commit, ShaFile, Tag, Tree
+from .repo import BaseRepo
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
@@ -40,9 +41,9 @@ def to_bytes(text: Union[str, bytes]) -> bytes:
     Returns:
       Bytes representation of text
     """
-    if getattr(text, "encode", None) is not None:
-        text = text.encode("ascii")  # type: ignore
-    return text  # type: ignore
+    if isinstance(text, str):
+        return text.encode("ascii")
+    return text
 
 
 def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
@@ -136,7 +137,9 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
                         raise ValueError(
                             f"Commit {commit.id.decode('ascii', 'replace')} has no parents"
                         )
-                    commit = repo[commit.parents[0]]
+                    parent_obj = repo[commit.parents[0]]
+                    assert isinstance(parent_obj, Commit)
+                    commit = parent_obj
                 obj = commit
             else:  # sep == b"^"
                 # Get N-th parent (or commit itself if N=0)
@@ -157,11 +160,13 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
     return _resolve_object(repo, objectish)
 
 
-def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "Tree":
+def parse_tree(
+    repo: "BaseRepo", treeish: Union[bytes, str, Tree, Commit, Tag]
+) -> "Tree":
     """Parse a string referring to a tree.
 
     Args:
-      repo: A `Repo` object
+      repo: A repository object
       treeish: A string referring to a tree, or a Tree, Commit, or Tag object
     Returns: A Tree object
     Raises:
@@ -173,7 +178,9 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
 
     # If it's a Commit, return its tree
     if isinstance(treeish, Commit):
-        return repo[treeish.tree]
+        tree = repo[treeish.tree]
+        assert isinstance(tree, Tree)
+        return tree
 
     # For Tag objects or strings, use the existing logic
     if isinstance(treeish, Tag):
@@ -181,7 +188,7 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
     else:
         treeish = to_bytes(treeish)
     try:
-        treeish = parse_ref(repo, treeish)
+        treeish = parse_ref(repo.refs, treeish)
     except KeyError:  # treeish is commit sha
         pass
     try:
@@ -190,15 +197,21 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
         # Try parsing as commit (handles short hashes)
         try:
             commit = parse_commit(repo, treeish)
-            return repo[commit.tree]
+            assert isinstance(commit, Commit)
+            tree = repo[commit.tree]
+            assert isinstance(tree, Tree)
+            return tree
         except KeyError:
             raise KeyError(treeish)
-    if o.type_name == b"commit":
-        return repo[o.tree]
-    elif o.type_name == b"tag":
+    if isinstance(o, Commit):
+        tree = repo[o.tree]
+        assert isinstance(tree, Tree)
+        return tree
+    elif isinstance(o, Tag):
         # Tag handling - dereference and recurse
         obj_type, obj_sha = o.object
         return parse_tree(repo, obj_sha)
+    assert isinstance(o, Tree)
     return o
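
Both the objects.py change and the objectspec.py hunks above swap `# type: ignore` and type_name string checks for assert-based narrowing, which mypy can follow and which fails loudly on unexpected object types. The bare pattern looks like this (helper name is illustrative):

    from dulwich.objects import Blob, Commit, ShaFile


    def require_commit(obj: ShaFile) -> Commit:
        # isinstance() narrows ShaFile to Commit for mypy; anything else
        # raises instead of being silently cast.
        assert isinstance(obj, Commit), f"expected a commit, got {obj.type_name!r}"
        return obj


    try:
        require_commit(Blob.from_string(b"not a commit"))
    except AssertionError:
        pass  # a Blob is correctly rejected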
 
 
@@ -383,11 +396,13 @@ def scan_for_short_id(
     raise AmbiguousShortId(prefix, ret)
 
 
-def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "Commit":
+def parse_commit(
+    repo: "BaseRepo", committish: Union[str, bytes, Commit, Tag]
+) -> "Commit":
     """Parse a string referring to a single commit.
 
     Args:
-      repo: A` Repo` object
+      repo: A repository object
       committish: A string referring to a single commit, or a Commit or Tag object.
     Returns: A Commit object
     Raises:
@@ -426,7 +441,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "C
     else:
         return dereference_tag(obj)
     try:
-        obj = repo[parse_ref(repo, committish)]
+        obj = repo[parse_ref(repo.refs, committish)]
     except KeyError:
         pass
     else:

+ 137 - 46
dulwich/pack.py

@@ -43,6 +43,7 @@ try:
 except ModuleNotFoundError:
     from difflib import SequenceMatcher
 
+import hashlib
 import os
 import struct
 import sys
@@ -53,6 +54,7 @@ from hashlib import sha1
 from itertools import chain
 from os import SEEK_CUR, SEEK_END
 from struct import unpack_from
+from types import TracebackType
 from typing import (
     IO,
     TYPE_CHECKING,
@@ -64,6 +66,7 @@ from typing import (
     Protocol,
     TypeVar,
     Union,
+    cast,
 )
 
 try:
@@ -129,12 +132,13 @@ class ObjectContainer(Protocol):
         self,
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         progress: Optional[Callable[[str], None]] = None,
-    ) -> None:
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
 
         Args:
           objects: Iterable over a list of (object, path) tuples
           progress: Progress callback for object insertion
+        Returns: Optional Pack object of the objects written.
         """
 
     def __contains__(self, sha1: bytes) -> bool:
@@ -331,6 +335,7 @@ class UnpackedObject:
     def sha(self) -> bytes:
         """Return the binary SHA of this object."""
         if self._sha is None:
+            assert self.obj_type_num is not None and self.obj_chunks is not None
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
         return self._sha
 
@@ -643,6 +648,18 @@ class PackIndex:
         """Yield all the SHA1's of the objects in the index, sorted."""
         raise NotImplementedError(self._itersha)
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+        """Iterate over all SHA1s with the given prefix.
+
+        Args:
+            prefix: Binary prefix to match
+        Returns: Iterator of matching SHA1s
+        """
+        # Default implementation for PackIndex classes that don't override
+        for sha, _, _ in self.iterentries():
+            if sha.startswith(prefix):
+                yield sha
+
     def close(self) -> None:
         """Close any open files."""
 
@@ -729,11 +746,12 @@ class FilePackIndex(PackIndex):
     """
 
     _fan_out_table: list[int]
+    _file: Union[IO[bytes], _GitFile]
 
     def __init__(
         self,
         filename: Union[str, os.PathLike],
-        file: Optional[BinaryIO] = None,
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
         contents: Optional[Union[bytes, "mmap.mmap"]] = None,
         size: Optional[int] = None,
     ) -> None:
@@ -924,7 +942,11 @@ class PackIndex1(FilePackIndex):
     """Version 1 Pack Index file."""
 
     def __init__(
-        self, filename: Union[str, os.PathLike], file: Optional[BinaryIO] = None, contents: Optional[bytes] = None, size: Optional[int] = None
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
+        contents: Optional[bytes] = None,
+        size: Optional[int] = None,
     ) -> None:
         """Initialize a version 1 pack index.
 
@@ -959,7 +981,11 @@ class PackIndex2(FilePackIndex):
     """Version 2 Pack Index file."""
 
     def __init__(
-        self, filename: Union[str, os.PathLike], file: Optional[BinaryIO] = None, contents: Optional[bytes] = None, size: Optional[int] = None
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
+        contents: Optional[bytes] = None,
+        size: Optional[int] = None,
     ) -> None:
         """Initialize a version 2 pack index.
 
@@ -1013,7 +1039,11 @@ class PackIndex3(FilePackIndex):
     """
 
     def __init__(
-        self, filename: Union[str, os.PathLike], file: Optional[BinaryIO] = None, contents: Optional[bytes] = None, size: Optional[int] = None
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
+        contents: Optional[bytes] = None,
+        size: Optional[int] = None,
     ) -> None:
         """Initialize a version 3 pack index.
 
@@ -1201,7 +1231,12 @@ class PackStreamReader:
     appropriate.
     """
 
-    def __init__(self, read_all: Callable[[int], bytes], read_some: Optional[Callable[[int], bytes]] = None, zlib_bufsize: int = _ZLIB_BUFSIZE) -> None:
+    def __init__(
+        self,
+        read_all: Callable[[int], bytes],
+        read_some: Optional[Callable[[int], bytes]] = None,
+        zlib_bufsize: int = _ZLIB_BUFSIZE,
+    ) -> None:
         self.read_all = read_all
         if read_some is None:
             self.read_some = read_all
@@ -1211,7 +1246,7 @@ class PackStreamReader:
         self._offset = 0
         self._rbuf = BytesIO()
         # trailer is a deque to avoid memory allocation on small reads
-        self._trailer: deque[bytes] = deque()
+        self._trailer: deque[int] = deque()
         self._zlib_bufsize = zlib_bufsize
 
     def _read(self, read: Callable[[int], bytes], size: int) -> bytes:
@@ -1343,7 +1378,13 @@ class PackStreamCopier(PackStreamReader):
     appropriate and written out to the given file-like object.
     """
 
-    def __init__(self, read_all: Callable, read_some: Callable, outfile: Any, delta_iter: Optional[Any] = None) -> None:
+    def __init__(
+        self,
+        read_all: Callable,
+        read_some: Callable,
+        outfile: IO[bytes],
+        delta_iter: Optional["DeltaChainIterator"] = None,
+    ) -> None:
         """Initialize the copier.
 
         Args:
@@ -1393,7 +1434,9 @@ def obj_sha(type: int, chunks: Union[bytes, Iterable[bytes]]) -> bytes:
     return sha.digest()
 
 
-def compute_file_sha(f: IO[bytes], start_ofs: int = 0, end_ofs: int = 0, buffer_size: int = 1 << 16) -> "sha1":
+def compute_file_sha(
+    f: IO[bytes], start_ofs: int = 0, end_ofs: int = 0, buffer_size: int = 1 << 16
+) -> "hashlib._Hash":
     """Hash a portion of a file into a new SHA.
 
     Args:
@@ -1450,7 +1493,7 @@ class PackData:
     def __init__(
         self,
         filename: Union[str, os.PathLike],
-        file: Optional[Any] = None,
+        file: Optional[IO[bytes]] = None,
         size: Optional[int] = None,
         *,
         delta_window_size: Optional[int] = None,
@@ -1477,6 +1520,7 @@ class PackData:
         self.depth = depth
         self.threads = threads
         self.big_file_threshold = big_file_threshold
+        self._file: IO[bytes]
 
         if file is None:
             self._file = GitFile(self._filename, "rb")
@@ -1500,7 +1544,7 @@ class PackData:
         return os.path.basename(self._filename)
 
     @property
-    def path(self) -> str:
+    def path(self) -> Union[str, os.PathLike]:
         """Get the full path of the pack file.
 
         Returns:
@@ -1509,7 +1553,7 @@ class PackData:
         return self._filename
 
     @classmethod
-    def from_file(cls, file: Any, size: Optional[int] = None) -> 'PackData':
+    def from_file(cls, file: IO[bytes], size: Optional[int] = None) -> "PackData":
         """Create a PackData object from an open file.
 
         Args:
@@ -1522,7 +1566,7 @@ class PackData:
         return cls(str(file), file=file, size=size)
 
     @classmethod
-    def from_path(cls, path: Union[str, os.PathLike]) -> 'PackData':
+    def from_path(cls, path: Union[str, os.PathLike]) -> "PackData":
         """Create a PackData object from a file path.
 
         Args:
@@ -1537,15 +1581,20 @@ class PackData:
         """Close the underlying pack file."""
         self._file.close()
 
-    def __enter__(self) -> 'PackData':
+    def __enter__(self) -> "PackData":
         """Enter context manager."""
         return self
 
-    def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None:
+    def __exit__(
+        self,
+        exc_type: Optional[type],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         """Exit context manager."""
         self.close()
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, PackData):
             return self.get_stored_checksum() == other.get_stored_checksum()
         return False
@@ -1568,9 +1617,9 @@ class PackData:
 
         Returns: 20-byte binary SHA1 digest
         """
-        return compute_file_sha(self._file, end_ofs=-20).digest()
+        return compute_file_sha(cast(IO[bytes], self._file), end_ofs=-20).digest()
 
-    def iter_unpacked(self, *, include_comp: bool = False) -> Any:
+    def iter_unpacked(self, *, include_comp: bool = False) -> Iterator[UnpackedObject]:
         self._file.seek(self._header_size)
 
         if self._num_objects is None:
@@ -1608,7 +1657,7 @@ class PackData:
         self,
         progress: Optional[ProgressFn] = None,
         resolve_ext_ref: Optional[ResolveExtRefFn] = None,
-    ) -> Any:
+    ) -> list[tuple[bytes, int, int]]:
         """Return entries in this pack, sorted by SHA.
 
         Args:
@@ -1621,7 +1670,12 @@ class PackData:
             self.iterentries(progress=progress, resolve_ext_ref=resolve_ext_ref)
         )
 
-    def create_index_v1(self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None) -> bytes:
+    def create_index_v1(
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        resolve_ext_ref: Optional[Callable] = None,
+    ) -> bytes:
         """Create a version 1 file for this data file.
 
         Args:
@@ -1636,7 +1690,12 @@ class PackData:
         with GitFile(filename, "wb") as f:
             return write_pack_index_v1(f, entries, self.calculate_checksum())
 
-    def create_index_v2(self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None) -> bytes:
+    def create_index_v2(
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        resolve_ext_ref: Optional[Callable] = None,
+    ) -> bytes:
         """Create a version 2 index file for this data file.
 
         Args:
@@ -1652,7 +1711,11 @@ class PackData:
             return write_pack_index_v2(f, entries, self.calculate_checksum())
 
     def create_index_v3(
-        self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None, hash_algorithm: int = 1
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        resolve_ext_ref: Optional[Callable] = None,
+        hash_algorithm: int = 1,
     ) -> bytes:
         """Create a version 3 index file for this data file.
 
@@ -1672,7 +1735,12 @@ class PackData:
             )
 
     def create_index(
-        self, filename: str, progress: Optional[Callable] = None, version: int = 2, resolve_ext_ref: Optional[Callable] = None, hash_algorithm: int = 1
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        version: int = 2,
+        resolve_ext_ref: Optional[Callable] = None,
+        hash_algorithm: int = 1,
     ) -> bytes:
         """Create an  index file for this data file.
 
@@ -1766,14 +1834,13 @@ class DeltaChainIterator(Generic[T]):
     _compute_crc32 = False
     _include_comp = False
 
-    def __init__(self, file_obj, *, resolve_ext_ref=None) -> None:
+    def __init__(self, file_obj: Any, *, resolve_ext_ref: Optional[Callable] = None) -> None:
         """Initialize DeltaChainIterator.
 
         Args:
             file_obj: File object to read pack data from
             resolve_ext_ref: Optional function to resolve external references
         """
-    def __init__(self, file_obj: Any, *, resolve_ext_ref: Optional[Callable] = None) -> None:
         self._file = file_obj
         self._resolve_ext_ref = resolve_ext_ref
         self._pending_ofs: dict[int, list[int]] = defaultdict(list)
@@ -1782,7 +1849,9 @@ class DeltaChainIterator(Generic[T]):
         self._ext_refs: list[bytes] = []
 
     @classmethod
-    def for_pack_data(cls, pack_data: PackData, resolve_ext_ref: Optional[Callable] = None) -> 'DeltaChainIterator':
+    def for_pack_data(
+        cls, pack_data: PackData, resolve_ext_ref: Optional[Callable] = None
+    ) -> "DeltaChainIterator":
         """Create a DeltaChainIterator from pack data.
 
         Args:
@@ -1878,7 +1947,7 @@ class DeltaChainIterator(Generic[T]):
         """
         self._file = pack_data._file
 
-    def _walk_all_chains(self) -> Any:
+    def _walk_all_chains(self) -> Iterator[T]:
         for offset, type_num in self._full_ofs:
             yield from self._follow_chain(offset, type_num, None)
         yield from self._walk_ref_chains()
@@ -1888,7 +1957,7 @@ class DeltaChainIterator(Generic[T]):
         if self._pending_ref:
             raise UnresolvedDeltas([sha_to_hex(s) for s in self._pending_ref])
 
-    def _walk_ref_chains(self) -> Any:
+    def _walk_ref_chains(self) -> Iterator[T]:
         if not self._resolve_ext_ref:
             self._ensure_no_pending()
             return
@@ -1910,12 +1979,13 @@ class DeltaChainIterator(Generic[T]):
 
         self._ensure_no_pending()
 
-    def _result(self, unpacked: UnpackedObject) -> Any:
+    def _result(self, unpacked: UnpackedObject) -> T:
         raise NotImplementedError
 
     def _resolve_object(
         self, offset: int, obj_type_num: int, base_chunks: Optional[list[bytes]]
     ) -> UnpackedObject:
+        assert self._file is not None
         self._file.seek(offset)
         unpacked, _ = unpack_object(
             self._file.read,
@@ -1931,7 +2001,9 @@ class DeltaChainIterator(Generic[T]):
             unpacked.obj_chunks = apply_delta(base_chunks, unpacked.decomp_chunks)
         return unpacked
 
-    def _follow_chain(self, offset: int, obj_type_num: int, base_chunks: list[bytes]) -> Iterator[T]:
+    def _follow_chain(
+        self, offset: int, obj_type_num: int, base_chunks: Optional[list[bytes]]
+    ) -> Iterator[T]:
         # Unlike PackData.get_object_at, there is no need to cache offsets as
         # this approach by design inflates each object exactly once.
         todo = [(offset, obj_type_num, base_chunks)]
@@ -2004,19 +2076,19 @@ class PackInflater(DeltaChainIterator[ShaFile]):
         Returns:
             ShaFile object from the unpacked data
         """
+    def _result(self, unpacked: UnpackedObject) -> ShaFile:
         return unpacked.sha_file()
 
 
 class SHA1Reader(BinaryIO):
     """Wrapper for file-like object that remembers the SHA1 of its data."""
 
-    def __init__(self, f) -> None:
+    def __init__(self, f: IO[bytes]) -> None:
         """Initialize SHA1Reader.
 
         Args:
             f: File-like object to wrap
         """
-    def __init__(self, f: BinaryIO) -> None:
         self.f = f
         self.sha1 = sha1(b"")
 
@@ -2126,19 +2198,24 @@ class SHA1Reader(BinaryIO):
         """
         raise UnsupportedOperation("write")
 
-    def writelines(self, lines: Any) -> None:
+    def writelines(self, lines: Iterable[bytes], /) -> None:  # type: ignore[override]
         raise UnsupportedOperation("writelines")
 
-    def write(self, data: bytes) -> int:
+    def write(self, data: bytes, /) -> int:  # type: ignore[override]
         raise UnsupportedOperation("write")
 
-    def __enter__(self) -> 'SHA1Reader':
+    def __enter__(self) -> "SHA1Reader":
         return self
 
-    def __exit__(self, type: Optional[type], value: Optional[BaseException], traceback: Optional[Any]) -> None:
+    def __exit__(
+        self,
+        type: Optional[type],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         self.close()
 
-    def __iter__(self) -> 'SHA1Reader':
+    def __iter__(self) -> "SHA1Reader":
         return self
 
     def __next__(self) -> bytes:
@@ -2181,10 +2258,10 @@ class SHA1Writer(BinaryIO):
         Args:
             f: File-like object to wrap
         """
-    def __init__(self, f: BinaryIO) -> None:
         self.f = f
         self.length = 0
         self.sha1 = sha1(b"")
+        self.digest: Optional[bytes] = None
 
     def write(self, data) -> int:
         """Write data and update SHA1.
@@ -2195,7 +2272,6 @@ class SHA1Writer(BinaryIO):
         Returns:
             Number of bytes written
         """
-    def write(self, data: bytes) -> int:
         self.sha1.update(data)
         self.f.write(data)
         self.length += len(data)
@@ -2220,8 +2296,9 @@ class SHA1Writer(BinaryIO):
             The SHA1 digest bytes
         """
         sha = self.write_sha()
+    def close(self) -> None:
+        self.digest = self.write_sha()
         self.f.close()
-        return sha
 
     def offset(self) -> int:
         """Get the total number of bytes written.
@@ -2281,13 +2358,12 @@ class SHA1Writer(BinaryIO):
         """
         raise UnsupportedOperation("readlines")
 
-    def writelines(self, lines) -> None:
+    def writelines(self, lines: Iterable[bytes]) -> None:
         """Write multiple lines to the file.
 
         Args:
             lines: Iterable of lines to write
         """
-    def writelines(self, lines: Any) -> None:
         for line in lines:
             self.write(line)
 
@@ -2309,6 +2385,18 @@ class SHA1Writer(BinaryIO):
 
     def __iter__(self) -> 'SHA1Writer':
         """Return iterator."""
+    def __enter__(self) -> "SHA1Writer":
+        return self
+
+    def __exit__(
+        self,
+        type: Optional[type],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        self.close()
+
+    def __iter__(self) -> "SHA1Writer":
         return self
 
     def __next__(self) -> bytes:
@@ -2498,7 +2586,7 @@ def find_reusable_deltas(
 
 
 def deltify_pack_objects(
-    objects: Union[Iterator[bytes], Iterator[tuple[ShaFile, Optional[bytes]]]],
+    objects: Union[Iterator[ShaFile], Iterator[tuple[ShaFile, Optional[bytes]]]],
     *,
     window_size: Optional[int] = None,
     progress=None,
@@ -2927,7 +3015,7 @@ def write_pack_index_v1(f: BinaryIO, entries: list[tuple[bytes, int, Optional[in
     Returns: The SHA of the written index file
     """
     f = SHA1Writer(f)
-    fan_out_table = defaultdict(lambda: 0)
+    fan_out_table: dict[int, int] = defaultdict(lambda: 0)
     for name, _offset, _entry_checksum in entries:
         fan_out_table[ord(name[:1])] += 1
     # Fan-out table
@@ -3489,6 +3577,7 @@ class Pack:
             ofs: (sha, crc32) for (sha, ofs, crc32) in self.index.iterentries()
         }
         for unpacked in self.data.iter_unpacked(include_comp=include_comp):
+            assert unpacked.offset is not None
             (sha, crc32) = ofs_to_entries[unpacked.offset]
             unpacked._sha = sha
             unpacked.crc32 = crc32
@@ -3589,8 +3678,10 @@ class Pack:
             object count
         Returns: Iterator of tuples with (sha, offset, crc32)
         """
-        return self.data.sorted_entries(
-            progress=progress, resolve_ext_ref=self.resolve_ext_ref
+        return iter(
+            self.data.sorted_entries(
+                progress=progress, resolve_ext_ref=self.resolve_ext_ref
+            )
         )
 
     def get_unpacked_object(

+ 8 - 9
dulwich/patch.py

@@ -30,6 +30,7 @@ import time
 from collections.abc import Generator
 from difflib import SequenceMatcher
 from typing import (
+    IO,
     TYPE_CHECKING,
     BinaryIO,
     Optional,
@@ -48,7 +49,7 @@ FIRST_FEW_BYTES = 8000
 
 
 def write_commit_patch(
-    f: BinaryIO,
+    f: IO[bytes],
     commit: "Commit",
     contents: Union[str, bytes],
     progress: tuple[int, int],
@@ -231,7 +232,7 @@ def patch_filename(p: Optional[bytes], root: bytes) -> bytes:
 
 
 def write_object_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     old_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
     new_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
@@ -264,19 +265,17 @@ def write_object_diff(
         Returns:
             Blob object
         """
-        from typing import cast
-
         if hexsha is None:
-            return cast(Blob, Blob.from_string(b""))
+            return Blob.from_string(b"")
         elif mode is not None and S_ISGITLINK(mode):
-            return cast(Blob, Blob.from_string(b"Subproject commit " + hexsha + b"\n"))
+            return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
         else:
             obj = store[hexsha]
             if isinstance(obj, Blob):
                 return obj
             else:
                 # Fallback for non-blob objects
-                return cast(Blob, Blob.from_string(obj.as_raw_string()))
+                return Blob.from_string(obj.as_raw_string())
 
     def lines(content: "Blob") -> list[bytes]:
         """Split blob content into lines.
@@ -356,7 +355,7 @@ def gen_diff_header(
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
 def write_blob_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     old_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
     new_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
 ) -> None:
@@ -403,7 +402,7 @@ def write_blob_diff(
 
 
 def write_tree_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     old_tree: Optional[bytes],
     new_tree: Optional[bytes],

+ 331 - 117
dulwich/porcelain.py

File diff too large to display


+ 0 - 12
dulwich/protocol.py

@@ -258,18 +258,6 @@ def pkt_seq(*seq: Optional[bytes]) -> bytes:
     return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)
 
 
-def filter_ref_prefix(
-    refs: dict[bytes, bytes], prefixes: Iterable[bytes]
-) -> dict[bytes, bytes]:
-    """Filter refs to only include those with a given prefix.
-
-    Args:
-      refs: A list of refs.
-      prefixes: The prefixes to filter by.
-    """
-    return {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
-
-
 class Protocol:
     """Class for interacting with a remote git process over the wire.
 

+ 13 - 3
dulwich/rebase.py

@@ -32,7 +32,7 @@ from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
 from dulwich.objects import Commit
 from dulwich.objectspec import parse_commit
-from dulwich.repo import Repo
+from dulwich.repo import BaseRepo, Repo
 
 
 class RebaseError(Exception):
@@ -529,7 +529,7 @@ class DiskRebaseStateManager:
 class MemoryRebaseStateManager:
     """Manages rebase state in memory for MemoryRepo."""
 
-    def __init__(self, repo: Repo) -> None:
+    def __init__(self, repo: BaseRepo) -> None:
         """Initialize MemoryRebaseStateManager.
 
         Args:
@@ -642,6 +642,8 @@ class Rebaser:
         if branch is None:
             # Use current HEAD
             head_ref, head_sha = self.repo.refs.follow(b"HEAD")
+            if head_sha is None:
+                raise ValueError("HEAD does not point to a valid commit")
             branch_commit = self.repo[head_sha]
         else:
             # Parse the branch reference
@@ -664,6 +666,7 @@ class Rebaser:
         commits = []
         current = branch_commit
         while current.id != merge_base:
+            assert isinstance(current, Commit)
             commits.append(current)
             if not current.parents:
                 break
@@ -691,6 +694,9 @@ class Rebaser:
         parent = self.repo[commit.parents[0]]
         onto_commit = self.repo[onto]
 
+        assert isinstance(parent, Commit)
+        assert isinstance(onto_commit, Commit)
+
         # Perform three-way merge
         merged_tree, conflicts = three_way_merge(
             self.object_store, parent, onto_commit, commit
@@ -798,7 +804,9 @@ class Rebaser:
 
         if new_sha:
             # Success - add to done list
-            self._done.append(self.repo[new_sha])
+            new_commit = self.repo[new_sha]
+            assert isinstance(new_commit, Commit)
+            self._done.append(new_commit)
             self._save_rebase_state()
 
             # Continue with next commit if any
@@ -822,6 +830,8 @@ class Rebaser:
             raise RebaseError("No rebase in progress")
 
         # Restore original HEAD
+        if self._original_head is None:
+            raise RebaseError("No original HEAD to restore")
         self.repo.refs[b"HEAD"] = self._original_head
 
         # Clean up rebase state
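
The new guard in Rebaser reflects that RefsContainer.follow (typed in the refs.py hunks below) may return None for the resolved SHA, e.g. on an unborn branch. A short sketch of the calling pattern (assumes the current directory is a git repository):

    from dulwich.repo import Repo

    repo = Repo(".")
    refnames, head_sha = repo.refs.follow(b"HEAD")
    if head_sha is None:
        raise ValueError("HEAD does not point to a valid commit")
    head_commit = repo[head_sha]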

+ 146 - 44
dulwich/refs.py

@@ -25,9 +25,19 @@
 import os
 import types
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    BinaryIO,
+    Callable,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 if TYPE_CHECKING:
     from .file import _GitFile
@@ -136,7 +146,23 @@ def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
 class RefsContainer:
     """A container for refs."""
 
-    def __init__(self, logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None) -> None:
+    def __init__(
+        self,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
         self._logger = logger
 
     def _log(
@@ -246,14 +272,14 @@ class RefsContainer:
         for ref in to_delete:
             self.remove_if_equals(b"/".join((base, ref)), None, message=message)
 
-    def allkeys(self) -> Iterator[Ref]:
+    def allkeys(self) -> set[Ref]:
         """All refs present in this container."""
         raise NotImplementedError(self.allkeys)
 
     def __iter__(self) -> Iterator[Ref]:
         return iter(self.allkeys())
 
-    def keys(self, base: Optional[bytes] = None) -> Union[Iterator[Ref], set[bytes]]:
+    def keys(self, base=None):
         """Refs present in this container.
 
         Args:
@@ -339,16 +365,16 @@ class RefsContainer:
         """
         raise NotImplementedError(self.read_loose_ref)
 
-    def follow(self, name: bytes) -> tuple[list[bytes], bytes]:
+    def follow(self, name: bytes) -> tuple[list[bytes], Optional[bytes]]:
         """Follow a reference name.
 
         Returns: a tuple of (refnames, sha), wheres refnames are the names of
             references in the chain
         """
-        contents = SYMREF + name
+        contents: Optional[bytes] = SYMREF + name
         depth = 0
         refnames = []
-        while contents.startswith(SYMREF):
+        while contents and contents.startswith(SYMREF):
             refname = contents[len(SYMREF) :]
             refnames.append(refname)
             contents = self.read_ref(refname)
@@ -404,7 +430,13 @@ class RefsContainer:
         raise NotImplementedError(self.set_if_equals)
 
     def add_if_new(
-        self, name: bytes, ref: bytes, committer: Optional[bytes] = None, timestamp: Optional[int] = None, timezone: Optional[int] = None, message: Optional[bytes] = None
+        self,
+        name: bytes,
+        ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Add a new reference only if it does not already exist.
 
@@ -486,7 +518,9 @@ class RefsContainer:
         ret = {}
         for src in self.allkeys():
             try:
-                dst = parse_symref_value(self.read_ref(src))
+                ref_value = self.read_ref(src)
+                assert ref_value is not None
+                dst = parse_symref_value(ref_value)
             except ValueError:
                 pass
             else:
@@ -509,14 +543,31 @@ class DictRefsContainer(RefsContainer):
     threadsafe.
     """
 
-    def __init__(self, refs: dict[bytes, bytes], logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None) -> None:
+    def __init__(
+        self,
+        refs: dict[bytes, bytes],
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
         super().__init__(logger=logger)
         self._refs = refs
         self._peeled: dict[bytes, ObjectID] = {}
         self._watchers: set[Any] = set()
 
-    def allkeys(self) -> Iterator[bytes]:
-        return self._refs.keys()
+    def allkeys(self) -> set[bytes]:
+        return set(self._refs.keys())
 
     def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         return self._refs.get(name, None)
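
allkeys() on these containers now returns a concrete set rather than a dict key view, so results can be iterated repeatedly and combined with set operations. For example, with toy SHAs and the DictRefsContainer from the hunk above:

    from dulwich.refs import DictRefsContainer

    refs = DictRefsContainer(
        {b"HEAD": b"a" * 40, b"refs/heads/main": b"a" * 40}
    )
    names = refs.allkeys()
    assert b"refs/heads/main" in names
    assert names - {b"HEAD"} == {b"refs/heads/main"}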
@@ -707,14 +758,14 @@ class DictRefsContainer(RefsContainer):
 class InfoRefsContainer(RefsContainer):
     """Refs container that reads refs from a info/refs file."""
 
-    def __init__(self, f: Any) -> None:
-        self._refs = {}
-        self._peeled = {}
+    def __init__(self, f: BinaryIO) -> None:
+        self._refs: dict[bytes, bytes] = {}
+        self._peeled: dict[bytes, bytes] = {}
         refs = read_info_refs(f)
         (self._refs, self._peeled) = split_peeled_refs(refs)
 
-    def allkeys(self) -> Iterator[bytes]:
-        return self._refs.keys()
+    def allkeys(self) -> set[bytes]:
+        return set(self._refs.keys())
 
     def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         return self._refs.get(name, None)
@@ -736,7 +787,20 @@ class DiskRefsContainer(RefsContainer):
         self,
         path: Union[str, bytes, os.PathLike],
         worktree_path: Optional[Union[str, bytes, os.PathLike]] = None,
-        logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
     ) -> None:
         """Initialize DiskRefsContainer."""
         super().__init__(logger=logger)
@@ -746,8 +810,8 @@ class DiskRefsContainer(RefsContainer):
             self.worktree_path = self.path
         else:
             self.worktree_path = os.fsencode(os.fspath(worktree_path))
-        self._packed_refs = None
-        self._peeled_refs = None
+        self._packed_refs: Optional[dict[bytes, bytes]] = None
+        self._peeled_refs: Optional[dict[bytes, bytes]] = None
 
     def __repr__(self) -> str:
         """Return string representation of DiskRefsContainer."""
@@ -772,7 +836,7 @@ class DiskRefsContainer(RefsContainer):
                 subkeys.add(key[len(base) :].strip(b"/"))
         return subkeys
 
-    def allkeys(self) -> Iterator[bytes]:
+    def allkeys(self) -> set[bytes]:
         allkeys = set()
         if os.path.exists(self.refpath(HEADREF)):
             allkeys.add(HEADREF)
@@ -878,7 +942,11 @@ class DiskRefsContainer(RefsContainer):
             to a tag, but no cached information is available, None is returned.
         """
         self.get_packed_refs()
-        if self._peeled_refs is None or name not in self._packed_refs:
+        if (
+            self._peeled_refs is None
+            or self._packed_refs is None
+            or name not in self._packed_refs
+        ):
             # No cache: no peeled refs were read, or this ref is loose
             return None
         if name in self._peeled_refs:
@@ -927,13 +995,14 @@ class DiskRefsContainer(RefsContainer):
             self._packed_refs = None
             self.get_packed_refs()
 
-            if name not in self._packed_refs:
+            if self._packed_refs is None or name not in self._packed_refs:
                 f.abort()
                 return
 
             del self._packed_refs[name]
-            with suppress(KeyError):
-                del self._peeled_refs[name]
+            if self._peeled_refs is not None:
+                with suppress(KeyError):
+                    del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
         except BaseException:
@@ -1240,7 +1309,7 @@ def _split_ref_line(line: bytes) -> tuple[bytes, bytes]:
     return (sha, name)
 
 
-def read_packed_refs(f: Any) -> Iterator[tuple[bytes, bytes]]:
+def read_packed_refs(f: IO[bytes]) -> Iterator[tuple[bytes, bytes]]:
     """Read a packed refs file.
 
     Args:
@@ -1256,7 +1325,9 @@ def read_packed_refs(f: Any) -> Iterator[tuple[bytes, bytes]]:
         yield _split_ref_line(line)
 
 
-def read_packed_refs_with_peeled(f: Any) -> Iterator[tuple[bytes, bytes, Optional[bytes]]]:
+def read_packed_refs_with_peeled(
+    f: IO[bytes],
+) -> Iterator[tuple[bytes, bytes, Optional[bytes]]]:
     """Read a packed refs file including peeled refs.
 
     Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
@@ -1288,7 +1359,11 @@ def read_packed_refs_with_peeled(f: Any) -> Iterator[tuple[bytes, bytes, Optiona
         yield (sha, name, None)
 
 
-def write_packed_refs(f: Any, packed_refs: dict[bytes, bytes], peeled_refs: Optional[dict[bytes, bytes]] = None) -> None:
+def write_packed_refs(
+    f: IO[bytes],
+    packed_refs: dict[bytes, bytes],
+    peeled_refs: Optional[dict[bytes, bytes]] = None,
+) -> None:
     """Write a packed refs file.
 
     Args:
@@ -1306,7 +1381,7 @@ def write_packed_refs(f: Any, packed_refs: dict[bytes, bytes], peeled_refs: Opti
             f.write(b"^" + peeled_refs[refname] + b"\n")
 
 
-def read_info_refs(f: Any) -> dict[bytes, bytes]:
+def read_info_refs(f: BinaryIO) -> dict[bytes, bytes]:
     """Read info/refs file.
 
     Args:
@@ -1322,7 +1397,9 @@ def read_info_refs(f: Any) -> dict[bytes, bytes]:
     return ret
 
 
-def write_info_refs(refs: dict[bytes, bytes], store: ObjectContainer) -> Iterator[bytes]:
+def write_info_refs(
+    refs: dict[bytes, bytes], store: ObjectContainer
+) -> Iterator[bytes]:
     """Generate info refs."""
     # TODO: Avoid recursive import :(
     from .object_store import peel_sha
@@ -1346,26 +1423,33 @@ def is_local_branch(x: bytes) -> bool:
     return x.startswith(LOCAL_BRANCH_PREFIX)
 
 
-def strip_peeled_refs(refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
+T = TypeVar("T", dict[bytes, bytes], dict[bytes, Optional[bytes]])
+
+
+def strip_peeled_refs(refs: T) -> T:
     """Remove all peeled refs."""
     return {
         ref: sha for (ref, sha) in refs.items() if not ref.endswith(PEELED_TAG_SUFFIX)
     }
 
 
-def split_peeled_refs(refs: dict[bytes, bytes]) -> tuple[dict[bytes, bytes], dict[bytes, bytes]]:
+def split_peeled_refs(refs: T) -> tuple[T, dict[bytes, bytes]]:
     """Split peeled refs from regular refs."""
-    peeled = {}
-    regular = {}
+    peeled: dict[bytes, bytes] = {}
+    regular = {k: v for k, v in refs.items() if not k.endswith(PEELED_TAG_SUFFIX)}
+
     for ref, sha in refs.items():
         if ref.endswith(PEELED_TAG_SUFFIX):
-            peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
-        else:
-            regular[ref] = sha
+            # Only add to peeled dict if sha is not None
+            if sha is not None:
+                peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
+
     return regular, peeled
 
 
-def _set_origin_head(refs: RefsContainer, origin: bytes, origin_head: Optional[bytes]) -> None:
+def _set_origin_head(
+    refs: RefsContainer, origin: bytes, origin_head: Optional[bytes]
+) -> None:
     # set refs/remotes/origin/HEAD
     origin_base = b"refs/remotes/" + origin + b"/"
     if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
@@ -1409,7 +1493,9 @@ def _set_default_branch(
     return head_ref
 
 
-def _set_head(refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]) -> Optional[bytes]:
+def _set_head(
+    refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]
+) -> Optional[bytes]:
     if head_ref.startswith(LOCAL_TAG_PREFIX):
         # detach HEAD at specified tag
         head = refs[head_ref]
@@ -1432,7 +1518,7 @@ def _set_head(refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]
 def _import_remote_refs(
     refs_container: RefsContainer,
     remote_name: str,
-    refs: dict[str, str],
+    refs: dict[bytes, Optional[bytes]],
     message: Optional[bytes] = None,
     prune: bool = False,
     prune_tags: bool = False,
@@ -1441,7 +1527,7 @@ def _import_remote_refs(
     branches = {
         n[len(LOCAL_BRANCH_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_BRANCH_PREFIX)
+        if n.startswith(LOCAL_BRANCH_PREFIX) and v is not None
     }
     refs_container.import_refs(
         b"refs/remotes/" + remote_name.encode(),
@@ -1452,14 +1538,18 @@ def _import_remote_refs(
     tags = {
         n[len(LOCAL_TAG_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_TAG_PREFIX) and not n.endswith(PEELED_TAG_SUFFIX)
+        if n.startswith(LOCAL_TAG_PREFIX)
+        and not n.endswith(PEELED_TAG_SUFFIX)
+        and v is not None
     }
     refs_container.import_refs(
         LOCAL_TAG_PREFIX, tags, message=message, prune=prune_tags
     )
 
 
-def serialize_refs(store: ObjectContainer, refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
+def serialize_refs(
+    store: ObjectContainer, refs: dict[bytes, bytes]
+) -> dict[bytes, bytes]:
     """Serialize refs with peeled refs.
 
     Args:
@@ -1556,6 +1646,7 @@ class locked_ref:
         if not self._file:
             raise RuntimeError("locked_ref not in context")
 
+        assert self._realname is not None
         current_ref = self._refs_container.read_loose_ref(self._realname)
         if current_ref is None:
             current_ref = self._refs_container.get_packed_refs().get(
@@ -1622,3 +1713,14 @@ class locked_ref:
             self._refs_container._remove_packed_ref(self._realname)
 
         self._deleted = True
+
+
+def filter_ref_prefix(refs: T, prefixes: Iterable[bytes]) -> T:
+    """Filter refs to only include those with a given prefix.
+
+    Args:
+      refs: A dictionary of refs.
+      prefixes: The prefixes to filter by.
+    """
+    filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
+    return cast(T, filtered)
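
A minimal usage sketch of the new filter_ref_prefix helper above (the refs
dict and prefixes here are made up for illustration; only the helper itself
comes from the change):

    from dulwich.refs import filter_ref_prefix

    refs = {
        b"refs/heads/main": b"1" * 40,
        b"refs/tags/v1.0": b"2" * 40,
        b"refs/remotes/origin/main": b"3" * 40,
    }
    # Keep only local branches and tags; remote-tracking refs are dropped.
    local = filter_ref_prefix(refs, [b"refs/heads/", b"refs/tags/"])
    assert set(local) == {b"refs/heads/main", b"refs/tags/v1.0"}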

+ 1 - 1
dulwich/reftable.py

@@ -700,7 +700,7 @@ class ReftableReader:
         # Read magic bytes
         magic = self.f.read(4)
         if magic != REFTABLE_MAGIC:
-            raise ValueError(f"Invalid reftable magic: {magic}")
+            raise ValueError(f"Invalid reftable magic: {magic!r}")
 
         # Read version + block size (4 bytes total, big-endian network order)
         # Format: uint8(version) + uint24(block_size)
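
As a side note on the !r fix above, a small illustrative sketch (the value is
made up) of why the explicit conversion matters when interpolating bytes:

    magic = b"\x00BAD"
    # Without !r the f-string falls back to str(bytes), which is ambiguous and
    # raises a BytesWarning under `python -b`; !r embeds the unambiguous repr.
    message = f"Invalid reftable magic: {magic!r}"
    assert message == "Invalid reftable magic: b'\\x00BAD'"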

+ 62 - 38
dulwich/repo.py

@@ -34,7 +34,7 @@ import stat
 import sys
 import time
 import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
 from io import BytesIO
 from typing import (
     TYPE_CHECKING,
@@ -42,6 +42,7 @@ from typing import (
     BinaryIO,
     Callable,
     Optional,
+    TypeVar,
     Union,
 )
 
@@ -52,8 +53,11 @@ if TYPE_CHECKING:
     from .attrs import GitAttributes
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .index import Index
+    from .line_ending import BlobNormalizer
     from .notes import Notes
-    from .object_store import BaseObjectStore
+    from .object_store import BaseObjectStore, GraphWalker, UnpackedObject
+    from .rebase import RebaseStateManager
+    from .walk import Walker
     from .worktree import WorkTree
 
 from . import replace_me
@@ -116,6 +120,8 @@ from .refs import (
 
 CONTROLDIR = ".git"
 OBJECTDIR = "objects"
+
+T = TypeVar("T", bound="ShaFile")
 REFSDIR = "refs"
 REFSDIR_TAGS = "tags"
 REFSDIR_HEADS = "heads"
@@ -248,11 +254,11 @@ def check_user_identity(identity: bytes) -> None:
     try:
         fst, snd = identity.split(b" <", 1)
     except ValueError as exc:
-        raise InvalidUserIdentity(identity) from exc
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace")) from exc
     if b">" not in snd:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
     if b"\0" in identity or b"\n" in identity:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
 
 
 def parse_graftpoints(
@@ -333,7 +339,12 @@ def _set_filesystem_hidden(path: str) -> None:
 
 
 class ParentsProvider:
-    def __init__(self, store: "BaseObjectStore", grafts: dict = {}, shallows: list = []) -> None:
+    def __init__(
+        self,
+        store: "BaseObjectStore",
+        grafts: dict = {},
+        shallows: Iterable[bytes] = [],
+    ) -> None:
         self.store = store
         self.grafts = grafts
         self.shallows = set(shallows)
@@ -341,7 +352,9 @@ class ParentsProvider:
         # Get commit graph once at initialization for performance
         self.commit_graph = store.get_commit_graph()
 
-    def get_parents(self, commit_id: bytes, commit: Optional[Any] = None) -> list[bytes]:
+    def get_parents(
+        self, commit_id: bytes, commit: Optional[Commit] = None
+    ) -> list[bytes]:
         try:
             return self.grafts[commit_id]
         except KeyError:
@@ -357,7 +370,9 @@ class ParentsProvider:
 
         # Fallback to reading the commit object
         if commit is None:
-            commit = self.store[commit_id]
+            obj = self.store[commit_id]
+            assert isinstance(obj, Commit)
+            commit = obj
         return commit.parents
 
 
@@ -472,7 +487,11 @@ class BaseRepo:
         raise NotImplementedError(self.open_index)
 
     def fetch(
-        self, target: "BaseRepo", determine_wants: Optional[Callable] = None, progress: Optional[Callable] = None, depth: Optional[int] = None
+        self,
+        target: "BaseRepo",
+        determine_wants: Optional[Callable] = None,
+        progress: Optional[Callable] = None,
+        depth: Optional[int] = None,
     ) -> dict:
         """Fetch objects into another repository.
 
@@ -498,7 +517,7 @@ class BaseRepo:
     def fetch_pack_data(
         self,
         determine_wants: Callable,
-        graph_walker: Any,
+        graph_walker: "GraphWalker",
         progress: Optional[Callable],
         *,
         get_tagged: Optional[Callable] = None,
@@ -533,7 +552,7 @@ class BaseRepo:
     def find_missing_objects(
         self,
         determine_wants: Callable,
-        graph_walker: Any,
+        graph_walker: "GraphWalker",
         progress: Optional[Callable],
         *,
         get_tagged: Optional[Callable] = None,
@@ -563,16 +582,17 @@ class BaseRepo:
         current_shallow = set(getattr(graph_walker, "shallow", set()))
 
         if depth not in (None, 0):
+            assert depth is not None
             shallow, not_shallow = find_shallow(self.object_store, wants, depth)
             # Only update if graph_walker has shallow attribute
             if hasattr(graph_walker, "shallow"):
                 graph_walker.shallow.update(shallow - not_shallow)
                 new_shallow = graph_walker.shallow - current_shallow
-                unshallow = graph_walker.unshallow = not_shallow & current_shallow
+                unshallow = graph_walker.unshallow = not_shallow & current_shallow  # type: ignore[attr-defined]
                 if hasattr(graph_walker, "update_shallow"):
                     graph_walker.update_shallow(new_shallow, unshallow)
         else:
-            unshallow = getattr(graph_walker, "unshallow", frozenset())
+            unshallow = getattr(graph_walker, "unshallow", set())
 
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
@@ -596,7 +616,7 @@ class BaseRepo:
                 def __len__(self) -> int:
                     return 0
 
-                def __iter__(self) -> Any:
+                def __iter__(self) -> Iterator[tuple[bytes, Optional[bytes]]]:
                     yield from []
 
             return DummyMissingObjectFinder()  # type: ignore
@@ -615,7 +635,7 @@ class BaseRepo:
 
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
 
-        def get_parents(commit: Any) -> list[bytes]:
+        def get_parents(commit: Commit) -> list[bytes]:
             """Get parents for a commit using the parents provider.
 
             Args:
@@ -638,11 +658,11 @@ class BaseRepo:
 
     def generate_pack_data(
         self,
-        have: list[ObjectID],
-        want: list[ObjectID],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         progress: Optional[Callable[[str], None]] = None,
         ofs_delta: Optional[bool] = None,
-    ) -> Any:
+    ) -> tuple[int, Iterator["UnpackedObject"]]:
         """Generate pack data objects for a set of wants/haves.
 
         Args:
@@ -697,18 +717,18 @@ class BaseRepo:
         # TODO: move this method to WorkTree
         return self.refs[b"HEAD"]
 
-    def _get_object(self, sha: bytes, cls: Any) -> Any:
+    def _get_object(self, sha: bytes, cls: type[T]) -> T:
         assert len(sha) in (20, 40)
         ret = self.get_object(sha)
         if not isinstance(ret, cls):
             if cls is Commit:
-                raise NotCommitError(ret)
+                raise NotCommitError(ret.id)
             elif cls is Blob:
-                raise NotBlobError(ret)
+                raise NotBlobError(ret.id)
             elif cls is Tree:
-                raise NotTreeError(ret)
+                raise NotTreeError(ret.id)
             elif cls is Tag:
-                raise NotTagError(ret)
+                raise NotTagError(ret.id)
             else:
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
         return ret
@@ -776,14 +796,14 @@ class BaseRepo:
         """
         raise NotImplementedError(self.set_description)
 
-    def get_rebase_state_manager(self) -> Any:
+    def get_rebase_state_manager(self) -> "RebaseStateManager":
         """Get the appropriate rebase state manager for this repository.
 
         Returns: RebaseStateManager instance
         """
         raise NotImplementedError(self.get_rebase_state_manager)
 
-    def get_blob_normalizer(self) -> Any:
+    def get_blob_normalizer(self) -> "BlobNormalizer":
         """Return a BlobNormalizer object for checkin/checkout operations.
 
         Returns: BlobNormalizer instance
@@ -831,7 +851,9 @@ class BaseRepo:
         with f:
             return {line.strip() for line in f}
 
-    def update_shallow(self, new_shallow: Any, new_unshallow: Any) -> None:
+    def update_shallow(
+        self, new_shallow: Optional[set[bytes]], new_unshallow: Optional[set[bytes]]
+    ) -> None:
         """Update the list of shallow objects.
 
         Args:
@@ -873,7 +895,7 @@ class BaseRepo:
 
         return Notes(self.object_store, self.refs)
 
-    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs) -> Any:
+    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs) -> "Walker":
         """Obtain a walker for this repository.
 
         Args:
@@ -910,7 +932,7 @@ class BaseRepo:
 
         return Walker(self.object_store, include, **kwargs)
 
-    def __getitem__(self, name: Union[ObjectID, Ref]) -> Any:
+    def __getitem__(self, name: Union[ObjectID, Ref]) -> "ShaFile":
         """Retrieve a Git object by SHA1 or ref.
 
         Args:
@@ -1002,7 +1024,7 @@ class BaseRepo:
         for sha in to_remove:
             del self._graftpoints[sha]
 
-    def _read_heads(self, name: str) -> Any:
+    def _read_heads(self, name: str) -> list[bytes]:
         f = self.get_named_file(name)
         if f is None:
             return []
@@ -1028,17 +1050,17 @@ class BaseRepo:
         message: Optional[bytes] = None,
         committer: Optional[bytes] = None,
         author: Optional[bytes] = None,
-        commit_timestamp: Optional[Any] = None,
-        commit_timezone: Optional[Any] = None,
-        author_timestamp: Optional[Any] = None,
-        author_timezone: Optional[Any] = None,
+        commit_timestamp: Optional[float] = None,
+        commit_timezone: Optional[int] = None,
+        author_timestamp: Optional[float] = None,
+        author_timezone: Optional[int] = None,
         tree: Optional[ObjectID] = None,
         encoding: Optional[bytes] = None,
         ref: Optional[Ref] = b"HEAD",
         merge_heads: Optional[list[ObjectID]] = None,
         no_verify: bool = False,
         sign: bool = False,
-    ) -> Any:
+    ) -> bytes:
         """Create a new commit.
 
         If not specified, committer and author default to
@@ -1097,9 +1119,9 @@ def read_gitfile(f: BinaryIO) -> str:
     Returns: A path
     """
     cs = f.read()
-    if not cs.startswith("gitdir: "):
+    if not cs.startswith(b"gitdir: "):
         raise ValueError("Expected file to start with 'gitdir: '")
-    return cs[len("gitdir: ") :].rstrip("\n")
+    return cs[len(b"gitdir: ") :].rstrip(b"\n").decode("utf-8")
 
 
 class UnsupportedVersion(Exception):
@@ -1183,7 +1205,7 @@ class Repo(BaseRepo):
         self.bare = bare
         if bare is False:
             if os.path.isfile(hidden_path):
-                with open(hidden_path) as f:
+                with open(hidden_path, "rb") as f:
                     path = read_gitfile(f)
                 self._controldir = os.path.join(root, path)
             else:
@@ -2018,6 +2040,7 @@ class Repo(BaseRepo):
                 if isinstance(head, Tag):
                     _cls, obj = head.object
                     head = self.get_object(obj)
+                assert isinstance(head, Commit)
                 tree = head.tree
             except KeyError:
                 # No HEAD, no attributes from tree
@@ -2026,6 +2049,7 @@ class Repo(BaseRepo):
         if tree is not None:
             try:
                 tree_obj = self[tree]
+                assert isinstance(tree_obj, Tree)
                 if b".gitattributes" in tree_obj:
                     _, attrs_sha = tree_obj[b".gitattributes"]
                     attrs_blob = self[attrs_sha]
@@ -2114,7 +2138,7 @@ class MemoryRepo(BaseRepo):
 
         self._reflog: list[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
-        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore[arg-type]
         self._named_files: dict[str, bytes] = {}
         self.bare = True
         self._config = ConfigFile()
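
The T = TypeVar("T", bound="ShaFile") addition lets _get_object return the
requested type instead of Any. A standalone sketch of the same pattern (the
classes below are stand-ins, not dulwich's):

    from typing import TypeVar

    class ShaFile: ...
    class Commit(ShaFile): ...
    class Blob(ShaFile): ...

    T = TypeVar("T", bound=ShaFile)

    def get_typed(obj: ShaFile, cls: type[T]) -> T:
        # isinstance() both validates at runtime and narrows the type for mypy,
        # so callers get back Commit/Blob/... rather than a bare ShaFile.
        if not isinstance(obj, cls):
            raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
        return obj

    commit = get_typed(Commit(), Commit)  # inferred as Commit, not ShaFile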

+ 15 - 3
dulwich/server.py

@@ -52,7 +52,7 @@ import time
 import zlib
 from collections.abc import Iterable, Iterator
 from functools import partial
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional
 from typing import Protocol as TypingProtocol
 
 if TYPE_CHECKING:
@@ -551,10 +551,12 @@ def _split_proto_line(line, allowed):
             COMMAND_SHALLOW,
             COMMAND_UNSHALLOW,
         ):
+            assert fields[1] is not None
             if not valid_hexsha(fields[1]):
                 raise GitProtocolError("Invalid sha")
             return tuple(fields)
         elif command == COMMAND_DEEPEN:
+            assert fields[1] is not None
             return command, int(fields[1])
     raise GitProtocolError(f"Received invalid line from client: {line!r}")
 
@@ -633,6 +635,9 @@ class AckGraphWalkerImpl:
         """
         raise NotImplementedError
 
+    def handle_done(self, done_required, done_received):
+        raise NotImplementedError
+
 
 class _ProtocolGraphWalker:
     """A graph walker that knows the git protocol.
@@ -784,6 +789,7 @@ class _ProtocolGraphWalker:
         """
         if len(have_ref) != 40:
             raise ValueError(f"invalid sha {have_ref!r}")
+        assert self._impl is not None
         return self._impl.ack(have_ref)
 
     def reset(self) -> None:
@@ -799,7 +805,8 @@ class _ProtocolGraphWalker:
         if not self._cached:
             if not self._impl and self.stateless_rpc:
                 return None
-            return next(self._impl)
+            assert self._impl is not None
+            return next(self._impl)  # type: ignore[call-overload]
         self._cache_index += 1
         if self._cache_index > len(self._cache):
             return None
@@ -884,6 +891,7 @@ class _ProtocolGraphWalker:
         Returns: True if done handling succeeded
         """
         # Delegate this to the implementation.
+        assert self._impl is not None
         return self._impl.handle_done(done_required, done_received)
 
     def set_wants(self, wants) -> None:
@@ -1400,7 +1408,11 @@ class UploadArchiveHandler(Handler):
                 format = arguments[i].decode("ascii")
             else:
                 commit_sha = self.repo.refs[argument]
-                tree = cast(Tree, store[cast(Commit, store[commit_sha]).tree])
+                commit_obj = store[commit_sha]
+                assert isinstance(commit_obj, Commit)
+                tree_obj = store[commit_obj.tree]
+                assert isinstance(tree_obj, Tree)
+                tree = tree_obj
             i += 1
         self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(None)
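
Several server.py and walk.py hunks replace typing.cast with runtime assert
checks. A hedged sketch of the two narrowing patterns involved (names here
are illustrative):

    from typing import Optional

    def first_field(fields: list[Optional[bytes]]) -> bytes:
        value = fields[0]
        # Narrow Optional[bytes] to bytes: mypy honours the assert, and unlike
        # cast() the assumption is also checked at runtime.
        assert value is not None
        return value

    def as_int(obj: object) -> int:
        # Narrow an arbitrary object to a concrete type, the same way
        # store[sha] lookups are narrowed to Commit or Blob in the hunks above.
        assert isinstance(obj, int)
        return obj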

+ 11 - 5
dulwich/stash.py

@@ -23,7 +23,7 @@
 
 import os
 import sys
-from typing import TYPE_CHECKING, Optional, TypedDict
+from typing import TYPE_CHECKING, Optional, TypedDict, Union
 
 from .diff_tree import tree_changes
 from .file import GitFile
@@ -162,10 +162,16 @@ class Stash:
             symlink_fn = symlink
         else:
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                mode = "w" + ("b" if isinstance(source, bytes) else "")
-                with open(target, mode) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: Optional[int] = None,
+            ) -> None:
+                mode = "w" + ("b" if isinstance(src, bytes) else "")
+                with open(dst, mode) as f:
+                    f.write(src)
 
         # Get blob normalizer for line ending conversion
         blob_normalizer = self._repo.get_blob_normalizer()

+ 1 - 1
dulwich/tests/test_object_store.py

@@ -444,7 +444,7 @@ class FindShallowTests(TestCase):
     def make_linear_commits(self, n, message=b""):
         """Create a linear chain of commits."""
         commits = []
-        parents = []
+        parents: list[bytes] = []
         for _ in range(n):
             commits.append(self.make_commit(parents=parents, message=message))
             parents = [commits[-1].id]

+ 19 - 4
dulwich/tests/utils.py

@@ -165,7 +165,9 @@ def make_tag(target: ShaFile, **attrs: Any) -> Tag:
     return make_object(Tag, **all_attrs)
 
 
-def functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[[Any], None]:
+def functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given function."""
 
     def do_test(self: Any) -> None:
@@ -174,7 +176,9 @@ def functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[
     return do_test
 
 
-def ext_functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[[Any], None]:
+def ext_functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given extension function.
 
     This is intended to generate test methods that test both a pure-Python
@@ -204,7 +208,11 @@ def ext_functest_builder(method: Callable[[Any, Any], None], func: Any) -> Calla
     return do_test
 
 
-def build_pack(f: BinaryIO, objects_spec: list[tuple[int, Any]], store: Optional[BaseObjectStore] = None) -> list[tuple[int, int, bytes, bytes, int]]:
+def build_pack(
+    f: BinaryIO,
+    objects_spec: list[tuple[int, Any]],
+    store: Optional[BaseObjectStore] = None,
+) -> list[tuple[int, int, bytes, bytes, int]]:
     """Write test pack data from a concise spec.
 
     Args:
@@ -282,7 +290,14 @@ def build_pack(f: BinaryIO, objects_spec: list[tuple[int, Any]], store: Optional
     return expected
 
 
-def build_commit_graph(object_store: BaseObjectStore, commit_spec: list[list[int]], trees: Optional[dict[int, list[Union[tuple[bytes, ShaFile], tuple[bytes, ShaFile, int]]]]] = None, attrs: Optional[dict[int, dict[str, Any]]] = None) -> list[Commit]:
+def build_commit_graph(
+    object_store: BaseObjectStore,
+    commit_spec: list[list[int]],
+    trees: Optional[
+        dict[int, list[Union[tuple[bytes, ShaFile], tuple[bytes, ShaFile, int]]]]
+    ] = None,
+    attrs: Optional[dict[int, dict[str, Any]]] = None,
+) -> list[Commit]:
     """Build a commit graph from a concise specification.
 
     Sample usage:

+ 12 - 5
dulwich/walk.py

@@ -87,7 +87,9 @@ class WalkEntry:
                 parent = None
             elif len(self._get_parents(commit)) == 1:
                 changes_func = tree_changes
-                parent = cast(Commit, self._store[self._get_parents(commit)[0]]).tree
+                parent_commit = self._store[self._get_parents(commit)[0]]
+                assert isinstance(parent_commit, Commit)
+                parent = parent_commit.tree
                 if path_prefix:
                     mode, subtree_sha = parent.lookup_path(
                         self._store.__getitem__,
@@ -96,9 +98,12 @@ class WalkEntry:
                     parent = self._store[subtree_sha]
             else:
                 # For merge commits, we need to handle multiple parents differently
-                parent = [
-                    cast(Commit, self._store[p]).tree for p in self._get_parents(commit)
-                ]
+                parent_trees = []
+                for p in self._get_parents(commit):
+                    parent_commit = self._store[p]
+                    assert isinstance(parent_commit, Commit)
+                    parent_trees.append(parent_commit.tree)
+                parent = parent_trees
                 # Use a lambda to adapt the signature
                 changes_func = cast(
                     Any,
@@ -200,7 +205,9 @@ class _CommitTimeQueue:
                     # some caching (which DiskObjectStore currently does not).
                     # We could either add caching in this class or pass around
                     # parsed queue entry objects instead of commits.
-                    todo.append(cast(Commit, self._store[parent]))
+                    parent_commit = self._store[parent]
+                    assert isinstance(parent_commit, Commit)
+                    todo.append(parent_commit)
                 excluded.add(parent)
 
     def next(self) -> Optional[WalkEntry]:

+ 86 - 23
dulwich/web.py

@@ -26,9 +26,10 @@ import os
 import re
 import sys
 import time
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from io import BytesIO
-from typing import Any, BinaryIO, Callable, ClassVar, Optional, cast
+from types import TracebackType
+from typing import Any, BinaryIO, Callable, ClassVar, Optional, Union, cast
 from urllib.parse import parse_qs
 from wsgiref.simple_server import (
     ServerHandler,
@@ -36,6 +37,7 @@ from wsgiref.simple_server import (
     WSGIServer,
     make_server,
 )
+from wsgiref.types import StartResponse, WSGIApplication, WSGIEnvironment
 
 from dulwich import log_utils
 
@@ -45,6 +47,7 @@ from .server import (
     DEFAULT_HANDLERS,
     Backend,
     DictBackend,
+    Handler,
     generate_info_refs,
     generate_objects_info_packs,
 )
@@ -292,13 +295,21 @@ def get_info_refs(
         yield req.not_found(str(e))
         return
     if service and not req.dumb:
+        if req.handlers is None:
+            yield req.forbidden("No handlers configured")
+            return
         handler_cls = req.handlers.get(service.encode("ascii"), None)
         if handler_cls is None:
             yield req.forbidden("Unsupported service")
             return
         req.nocache()
         write = req.respond(HTTP_OK, f"application/x-{service}-advertisement")
-        proto = ReceivableProtocol(BytesIO().read, write)
+
+        def write_fn(data: bytes) -> Optional[int]:
+            result = write(data)
+            return len(data) if result is not None else None
+
+        proto = ReceivableProtocol(BytesIO().read, write_fn)
         handler = handler_cls(
             backend,
             [url_prefix(mat)],
@@ -425,6 +436,9 @@ def handle_service_request(
     """
     service = mat.group().lstrip("/")
     logger.info("Handling service request for %s", service)
+    if req.handlers is None:
+        yield req.forbidden("No handlers configured")
+        return
     handler_cls = req.handlers.get(service.encode("ascii"), None)
     if handler_cls is None:
         yield req.forbidden("Unsupported service")
@@ -436,11 +450,16 @@ def handle_service_request(
         return
     req.nocache()
     write = req.respond(HTTP_OK, f"application/x-{service}-result")
+
+    def write_fn(data: bytes) -> Optional[int]:
+        result = write(data)
+        return len(data) if result is not None else None
+
     if req.environ.get("HTTP_TRANSFER_ENCODING") == "chunked":
         read = ChunkReader(req.environ["wsgi.input"]).read
     else:
         read = req.environ["wsgi.input"].read
-    proto = ReceivableProtocol(read, write)
+    proto = ReceivableProtocol(read, write_fn)
     # TODO(jelmer): Find a way to pass in repo, rather than having handler_cls
     # reopen.
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
@@ -455,7 +474,11 @@ class HTTPGitRequest:
     """
 
     def __init__(
-        self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any], dumb: bool = False, handlers: Optional[dict[bytes, Callable]] = None
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
     ) -> None:
         """Initialize HTTPGitRequest.
 
@@ -481,7 +504,7 @@ class HTTPGitRequest:
         status: str = HTTP_OK,
         content_type: Optional[str] = None,
         headers: Optional[list[tuple[str, str]]] = None,
-    ) -> Any:
+    ) -> Callable[[bytes], object]:
         """Begin a response with the given status and other headers."""
         if headers:
             self._headers.extend(headers)
@@ -556,7 +579,11 @@ class HTTPGitApplication:
     }
 
     def __init__(
-        self, backend: Backend, dumb: bool = False, handlers: Optional[dict[bytes, Callable]] = None, fallback_app: Optional[Callable[[dict[str, Any], Callable], Any]] = None
+        self,
+        backend: Backend,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
+        fallback_app: Optional[WSGIApplication] = None,
     ) -> None:
         """Initialize HTTPGitApplication.
 
@@ -568,12 +595,18 @@ class HTTPGitApplication:
         """
         self.backend = backend
         self.dumb = dumb
-        self.handlers = dict(DEFAULT_HANDLERS)
+        self.handlers: dict[bytes, Union[type[Handler], Callable[..., Any]]] = dict(
+            DEFAULT_HANDLERS
+        )
         self.fallback_app = fallback_app
         if handlers is not None:
             self.handlers.update(handlers)
 
-    def __call__(self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any]) -> list[bytes]:
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         path = environ["PATH_INFO"]
         method = environ["REQUEST_METHOD"]
         req = HTTPGitRequest(
@@ -581,6 +614,7 @@ class HTTPGitApplication:
         )
         # environ['QUERY_STRING'] has qs args
         handler = None
+        mat = None
         for smethod, spath in self.services.keys():
             if smethod != method:
                 continue
@@ -589,7 +623,7 @@ class HTTPGitApplication:
                 handler = self.services[smethod, spath]
                 break
 
-        if handler is None:
+        if handler is None or mat is None:
             if self.fallback_app is not None:
                 return self.fallback_app(environ, start_response)
             else:
@@ -601,10 +635,14 @@ class HTTPGitApplication:
 class GunzipFilter:
     """WSGI middleware that unzips gzip-encoded requests before passing on to the underlying application."""
 
-    def __init__(self, application: Any) -> None:
+    def __init__(self, application: WSGIApplication) -> None:
         self.app = application
 
-    def __call__(self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any]) -> Any:
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         import gzip
 
         if environ.get("HTTP_CONTENT_ENCODING", "") == "gzip":
@@ -620,10 +658,14 @@ class GunzipFilter:
 class LimitedInputFilter:
     """WSGI middleware that limits the input length of a request to that specified in Content-Length."""
 
-    def __init__(self, application: Any) -> None:
+    def __init__(self, application: WSGIApplication) -> None:
         self.app = application
 
-    def __call__(self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any]) -> Any:
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         # This is not necessary if this app is run from a conforming WSGI
         # server. Unfortunately, there's no way to tell that at this point.
         # TODO: git may use HTTP/1.1 chunked encoding instead of specifying
@@ -636,11 +678,18 @@ class LimitedInputFilter:
         return self.app(environ, start_response)
 
 
-def make_wsgi_chain(*args: Any, **kwargs: Any) -> Any:
+def make_wsgi_chain(
+    backend: Backend,
+    dumb: bool = False,
+    handlers: Optional[dict[bytes, Callable[..., Any]]] = None,
+    fallback_app: Optional[WSGIApplication] = None,
+) -> WSGIApplication:
     """Factory function to create an instance of HTTPGitApplication,
     correctly wrapped with needed middleware.
     """
-    app = HTTPGitApplication(*args, **kwargs)
+    app = HTTPGitApplication(
+        backend, dumb=dumb, handlers=handlers, fallback_app=fallback_app
+    )
     wrapped_app = LimitedInputFilter(GunzipFilter(app))
     return wrapped_app
 
@@ -648,32 +697,46 @@ def make_wsgi_chain(*args: Any, **kwargs: Any) -> Any:
 class ServerHandlerLogger(ServerHandler):
     """ServerHandler that uses dulwich's logger for logging exceptions."""
 
-    def log_exception(self, exc_info: Any) -> None:
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
         logger.exception(
             "Exception happened during processing of request",
             exc_info=exc_info,
         )
 
-    def log_message(self, format: str, *args: Any) -> None:
+    def log_message(self, format: str, *args: object) -> None:
         logger.info(format, *args)
 
-    def log_error(self, *args: Any) -> None:
+    def log_error(self, *args: object) -> None:
         logger.error(*args)
 
 
 class WSGIRequestHandlerLogger(WSGIRequestHandler):
     """WSGIRequestHandler that uses dulwich's logger for logging exceptions."""
 
-    def log_exception(self, exc_info: Any) -> None:
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
         logger.exception(
             "Exception happened during processing of request",
             exc_info=exc_info,
         )
 
-    def log_message(self, format: str, *args: Any) -> None:
+    def log_message(self, format: str, *args: object) -> None:
         logger.info(format, *args)
 
-    def log_error(self, *args: Any) -> None:
+    def log_error(self, *args: object) -> None:
         logger.error(*args)
 
     def handle(self) -> None:
@@ -695,7 +758,7 @@ class WSGIRequestHandlerLogger(WSGIRequestHandler):
 class WSGIServerLogger(WSGIServer):
     """WSGIServer that uses dulwich's logger for error handling."""
 
-    def handle_error(self, request: Any, client_address: Any) -> None:
+    def handle_error(self, request: object, client_address: tuple[str, int]) -> None:
         """Handle an error."""
         logger.exception(
             f"Exception happened during processing of request from {client_address!s}"

+ 40 - 15
dulwich/worktree.py

@@ -34,9 +34,10 @@ import warnings
 from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from pathlib import Path
+from typing import Any, Callable, Union
 
 from .errors import CommitError, HookError
-from .objects import Commit, ObjectID, Tag, Tree
+from .objects import Blob, Commit, ObjectID, Tag, Tree
 from .refs import SYMREF, Ref
 from .repo import (
     GITDIR,
@@ -335,7 +336,7 @@ class WorkTree:
 
         index = self._repo.open_index()
         try:
-            tree_id = self._repo[b"HEAD"].tree
+            commit = self._repo[b"HEAD"]
         except KeyError:
             # no head means no commit in the repo
             for fs_path in fs_paths:
@@ -343,6 +344,9 @@ class WorkTree:
                 del index[tree_path]
             index.write()
             return
+        else:
+            assert isinstance(commit, Commit), "HEAD must be a commit"
+            tree_id = commit.tree
 
         for fs_path in fs_paths:
             tree_path = _fs_to_tree_path(fs_path)
@@ -367,15 +371,19 @@ class WorkTree:
             except FileNotFoundError:
                 pass
 
+            blob_obj = self._repo[tree_entry[1]]
+            assert isinstance(blob_obj, Blob)
+            blob_size = len(blob_obj.data)
+
             index_entry = IndexEntry(
-                ctime=(self._repo[b"HEAD"].commit_time, 0),
-                mtime=(self._repo[b"HEAD"].commit_time, 0),
+                ctime=(commit.commit_time, 0),
+                mtime=(commit.commit_time, 0),
                 dev=st.st_dev if st else 0,
                 ino=st.st_ino if st else 0,
                 mode=tree_entry[0],
                 uid=st.st_uid if st else 0,
                 gid=st.st_gid if st else 0,
-                size=len(self._repo[tree_entry[1]].data),
+                size=blob_size,
                 sha=tree_entry[1],
                 flags=0,
                 extended_flags=0,
@@ -386,7 +394,7 @@ class WorkTree:
 
     def commit(
         self,
-        message: bytes | None = None,
+        message: Union[str, bytes, Callable[[Any, Commit], bytes], None] = None,
         committer: bytes | None = None,
         author: bytes | None = None,
         commit_timestamp: float | None = None,
@@ -541,13 +549,18 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.set_if_equals(
                     ref,
                     old_head,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             except KeyError:
@@ -555,12 +568,17 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.add_if_new(
                     ref,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             if not ok:
@@ -603,6 +621,9 @@ class WorkTree:
             if isinstance(head, Tag):
                 _cls, obj = head.object
                 head = self._repo.get_object(obj)
+            from .objects import Commit
+
+            assert isinstance(head, Commit)
             tree = head.tree
         config = self._repo.get_config()
         honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
@@ -616,11 +637,15 @@ class WorkTree:
             symlink_fn = symlink
         else:
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                with open(
-                    target, "w" + ("b" if isinstance(source, bytes) else "")
-                ) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: int | None = None,
+            ) -> None:
+                with open(dst, "w" + ("b" if isinstance(src, bytes) else "")) as f:
+                    f.write(src)
 
         blob_normalizer = self._repo.get_blob_normalizer()
         return build_index_from_tree(
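
The commit() hunks above normalize the message to bytes before building the
reflog entry. A small sketch of that normalization (illustrative names only):

    from typing import Union

    def normalize_message(message: Union[str, bytes]) -> bytes:
        # str messages are encoded as UTF-8; bytes pass through unchanged.
        return message.encode() if isinstance(message, str) else message

    assert normalize_message("initial commit") == b"initial commit"
    assert normalize_message(b"initial commit") == b"initial commit"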

+ 2 - 1
pyproject.toml

@@ -25,7 +25,7 @@ classifiers = [
 requires-python = ">=3.9"
 dependencies = [
     "urllib3>=1.25",
-    'typing_extensions >=4.0 ; python_version < "3.11"',
+    'typing_extensions >=4.0 ; python_version < "3.12"',
 ]
 dynamic = ["version"]
 license-files = ["COPYING"]
@@ -99,6 +99,7 @@ ignore = [
     "ANN205",
     "ANN206",
     "E501",  # line too long
+    "UP007",  # Use X | Y for type annotations (Python 3.10+ syntax, but we support 3.9+)
 ]
 
 [tool.ruff.lint.pydocstyle]
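
A quick, made-up illustration of what the newly ignored UP007 rule would
otherwise flag: it rewrites typing.Optional/Union annotations into PEP 604
unions, which is Python 3.10+ syntax while the project still supports 3.9:

    from typing import Optional, Union

    # UP007 would suggest `bytes | None` and `str | bytes` here; the ignore
    # keeps the 3.9-compatible spelling.
    def read_ref(name: bytes) -> Optional[bytes]: ...
    def to_path(value: Union[str, bytes]) -> str: ...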

+ 0 - 1
tests/contrib/__init__.py

@@ -23,7 +23,6 @@ import unittest
 
 
 def test_suite() -> unittest.TestSuite:
-
     names = [
         "diffstat",
         "paramiko_vendor",

+ 20 - 4
tests/test_annotate.py

@@ -110,7 +110,9 @@ class UpdateLinesTestCase(TestCase):
 
     def test_update_lines_empty_new(self) -> None:
         """Test update_lines with empty new blob."""
-        old_lines: list[tuple[tuple[Any, Any], bytes]] = [(("commit1", "entry1"), b"line1")]
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
+            (("commit1", "entry1"), b"line1")
+        ]
         new_blob = b""
         new_history_data = ("commit2", "entry2")
 
@@ -131,7 +133,9 @@ class AnnotateLinesTestCase(TestCase):
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit(self, blob_content: bytes, message: str, parent: Optional[bytes] = None) -> bytes:
+    def _make_commit(
+        self, blob_content: bytes, message: str, parent: Optional[bytes] = None
+    ) -> bytes:
         """Helper to create a commit with a single file."""
         # Create blob
         blob = Blob()
@@ -222,7 +226,13 @@ class PorcelainAnnotateTestCase(TestCase):
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit_with_file(self, filename: str, content: bytes, message: str, parent: Optional[bytes] = None) -> bytes:
+    def _make_commit_with_file(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
         """Helper to create a commit with a file."""
         # Create blob
         blob = Blob()
@@ -315,7 +325,13 @@ class IntegrationTestCase(TestCase):
 
         shutil.rmtree(self.temp_dir)
 
-    def _create_file_commit(self, filename: str, content: bytes, message: str, parent: Optional[bytes] = None) -> bytes:
+    def _create_file_commit(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
         """Helper to create a commit with file content."""
         # Write file to working directory
         filepath = os.path.join(self.temp_dir, filename)

+ 3 - 1
tests/test_archive.py

@@ -46,7 +46,9 @@ class ArchiveTests(TestCase):
         self.addCleanup(tf.close)
         self.assertEqual([], tf.getnames())
 
-    def _get_example_tar_stream(self, mtime: int, prefix: bytes = b"", format: str = "") -> BytesIO:
+    def _get_example_tar_stream(
+        self, mtime: int, prefix: bytes = b"", format: str = ""
+    ) -> BytesIO:
         store = MemoryObjectStore()
         b1 = Blob.from_string(b"somedata")
         store.add_object(b1)

+ 2 - 2
tests/test_cloud_gcs.py

@@ -39,8 +39,8 @@ class GcsObjectStoreTests(unittest.TestCase):
         self.assertIn("git", repr(self.store))
 
     def test_remove_pack(self):
-        """Test _remove_pack method."""
-        self.store._remove_pack("pack-1234")
+        """Test _remove_pack_by_name method."""
+        self.store._remove_pack_by_name("pack-1234")
         self.mock_bucket.delete_blobs.assert_called_once()
         args = self.mock_bucket.delete_blobs.call_args[0][0]
         self.assertEqual(

+ 2 - 1
tests/test_commit_graph.py

@@ -736,7 +736,8 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertEqual(parents_provider.get_parents(commit1.id), [])  # type: ignore[no-untyped-call]
         self.assertEqual(parents_provider.get_parents(commit2.id), [commit1.id])  # type: ignore[no-untyped-call]
         self.assertEqual(
-            parents_provider.get_parents(commit5.id), [commit3.id, commit4.id]  # type: ignore[no-untyped-call]
+            parents_provider.get_parents(commit5.id),
+            [commit3.id, commit4.id],  # type: ignore[no-untyped-call]
         )
 
     def test_performance_with_commit_graph(self) -> None:

+ 12 - 3
tests/test_dumb.py

@@ -32,7 +32,12 @@ from dulwich.objects import Blob, Commit, ShaFile, Tag, Tree, sha_to_hex
 
 
 class MockResponse:
-    def __init__(self, status: int = 200, content: bytes = b"", headers: Optional[dict[str, str]] = None) -> None:
+    def __init__(
+        self,
+        status: int = 200,
+        content: bytes = b"",
+        headers: Optional[dict[str, str]] = None,
+    ) -> None:
         self.status = status
         self.content = content
         self.headers = headers or {}
@@ -50,7 +55,9 @@ class DumbHTTPObjectStoreTests(TestCase):
         self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]
@@ -183,7 +190,9 @@ class DumbRemoteHTTPRepoTests(TestCase):
         self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]

+ 1 - 1
tests/test_protocol.py

@@ -36,10 +36,10 @@ from dulwich.protocol import (
     ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
-    filter_ref_prefix,
     pkt_line,
     pkt_seq,
 )
+from dulwich.refs import filter_ref_prefix
 
 from . import TestCase
 

Some files were not shown because too many files changed in this diff