
Fix remaining mypy type errors and ruff linting issues

Jelmer Vernooij 5 months ago
parent
commit
aee765a866

+ 6 - 5
dulwich/annotate.py

@@ -27,8 +27,9 @@ Python's difflib.
 """
 """
 
 
 import difflib
 import difflib
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional
 
 
+from dulwich.objects import Blob
 from dulwich.walk import (
 from dulwich.walk import (
     ORDER_DATE,
     ORDER_DATE,
     Walker,
     Walker,
@@ -37,7 +38,7 @@ from dulwich.walk import (
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from dulwich.diff_tree import TreeChange, TreeEntry
     from dulwich.diff_tree import TreeChange, TreeEntry
     from dulwich.object_store import BaseObjectStore
     from dulwich.object_store import BaseObjectStore
-    from dulwich.objects import Blob, Commit
+    from dulwich.objects import Commit
 
 
 # Walk over ancestry graph breadth-first
 # Walk over ancestry graph breadth-first
 # When checking each revision, find lines that according to difflib.Differ()
 # When checking each revision, find lines that according to difflib.Differ()
@@ -108,7 +109,7 @@ def annotate_lines(
 
 
     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
     for commit, entry in reversed(revs):
-        lines_annotated = update_lines(
-            lines_annotated, (commit, entry), cast("Blob", store[entry.sha])
-        )
+        blob_obj = store[entry.sha]
+        assert isinstance(blob_obj, Blob)
+        lines_annotated = update_lines(lines_annotated, (commit, entry), blob_obj)
     return lines_annotated
     return lines_annotated

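Note: the annotate.py change above trades cast("Blob", ...) for an assert isinstance(...) check. A cast only silences mypy, while the assertion narrows the type for the checker and also fails fast at runtime if the object store hands back something other than a blob. A minimal sketch of the pattern (Blob and ShaFile are dulwich's; first_line is a hypothetical helper):

    from dulwich.objects import Blob, ShaFile

    def first_line(obj: ShaFile) -> bytes:
        # cast(Blob, obj) would satisfy mypy even if obj were a Commit;
        # the assertion narrows the type *and* catches bad data at runtime.
        assert isinstance(obj, Blob)
        return obj.data.splitlines()[0]
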
+ 4 - 2
dulwich/bisect.py

@@ -49,7 +49,7 @@ class BisectState:
         self,
         bad: Optional[bytes] = None,
         good: Optional[list[bytes]] = None,
-        paths: Optional[list[str]] = None,
+        paths: Optional[list[bytes]] = None,
         no_checkout: bool = False,
         term_bad: str = "bad",
         term_good: str = "good",
@@ -103,7 +103,9 @@ class BisectState:
         names_file = os.path.join(self.repo.controldir(), "BISECT_NAMES")
         with open(names_file, "w") as f:
             if paths:
-                f.write("\n".join(paths) + "\n")
+                f.write(
+                    "\n".join(path.decode("utf-8", "replace") for path in paths) + "\n"
+                )
             else:
                 f.write("\n")
 

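Note: the bisect.py change types paths as bytes and decodes each entry with errors="replace" before writing BISECT_NAMES, so undecodable filename bytes degrade to U+FFFD instead of raising UnicodeDecodeError. A standalone illustration (not dulwich code):

    paths = [b"docs/readme.md", b"data/caf\xe9"]  # second entry is not valid UTF-8

    # errors="replace" substitutes U+FFFD for bad bytes instead of raising
    text = "\n".join(p.decode("utf-8", "replace") for p in paths)
    print(text)  # docs/readme.md, then data/caf followed by the replacement character
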
+ 27 - 7
dulwich/bundle.py

@@ -22,9 +22,25 @@
 """Bundle format support."""
 """Bundle format support."""
 
 
 from collections.abc import Iterator
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional
+from typing import (
+    TYPE_CHECKING,
+    BinaryIO,
+    Callable,
+    Optional,
+    Protocol,
+    runtime_checkable,
+)
+
+from .pack import PackData, UnpackedObject, write_pack_data
+
+
+@runtime_checkable
+class PackDataLike(Protocol):
+    """Protocol for objects that behave like PackData."""
+
+    def __len__(self) -> int: ...
+    def iter_unpacked(self) -> Iterator[UnpackedObject]: ...
 
 
-from .pack import PackData, write_pack_data
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
     from .object_store import BaseObjectStore
@@ -39,7 +55,7 @@ class Bundle:
     capabilities: dict[str, Optional[str]]
     capabilities: dict[str, Optional[str]]
     prerequisites: list[tuple[bytes, bytes]]
     prerequisites: list[tuple[bytes, bytes]]
     references: dict[bytes, bytes]
     references: dict[bytes, bytes]
-    pack_data: PackData
+    pack_data: Optional[PackDataLike]
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Return string representation of Bundle."""
         """Return string representation of Bundle."""
@@ -79,10 +95,12 @@ class Bundle:
         """
         """
         from .objects import ShaFile
         from .objects import ShaFile
 
 
+        if self.pack_data is None:
+            raise ValueError("pack_data is not loaded")
         count = 0
         count = 0
         for unpacked in self.pack_data.iter_unpacked():
         for unpacked in self.pack_data.iter_unpacked():
             # Convert the unpacked object to a proper git object
             # Convert the unpacked object to a proper git object
-            if unpacked.decomp_chunks:
+            if unpacked.decomp_chunks and unpacked.obj_type_num is not None:
                 git_obj = ShaFile.from_raw_chunks(
                 git_obj = ShaFile.from_raw_chunks(
                     unpacked.obj_type_num, unpacked.decomp_chunks
                     unpacked.obj_type_num, unpacked.decomp_chunks
                 )
                 )
@@ -187,6 +205,8 @@ def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
     for ref, obj_id in bundle.references.items():
     for ref, obj_id in bundle.references.items():
         f.write(obj_id + b" " + ref + b"\n")
         f.write(obj_id + b" " + ref + b"\n")
     f.write(b"\n")
     f.write(b"\n")
+    if bundle.pack_data is None:
+        raise ValueError("bundle.pack_data is not loaded")
     write_pack_data(
     write_pack_data(
         f.write,
         f.write,
         num_records=len(bundle.pack_data),
         num_records=len(bundle.pack_data),
@@ -283,14 +303,14 @@ def create_bundle_from_repo(
     # Store the pack objects directly, we'll write them when saving the bundle
     # Store the pack objects directly, we'll write them when saving the bundle
     # For now, create a simple wrapper to hold the data
     # For now, create a simple wrapper to hold the data
     class _BundlePackData:
     class _BundlePackData:
-        def __init__(self, count: int, objects: Iterator[Any]) -> None:
+        def __init__(self, count: int, objects: Iterator[UnpackedObject]) -> None:
             self._count = count
             self._count = count
             self._objects = list(objects)  # Materialize the iterator
             self._objects = list(objects)  # Materialize the iterator
 
 
         def __len__(self) -> int:
         def __len__(self) -> int:
             return self._count
             return self._count
 
 
-        def iter_unpacked(self) -> Iterator[Any]:
+        def iter_unpacked(self) -> Iterator[UnpackedObject]:
             return iter(self._objects)
             return iter(self._objects)
 
 
     pack_data = _BundlePackData(pack_count, pack_objects)
     pack_data = _BundlePackData(pack_count, pack_objects)
@@ -301,6 +321,6 @@ def create_bundle_from_repo(
     bundle.capabilities = capabilities
     bundle.capabilities = capabilities
     bundle.prerequisites = bundle_prerequisites
     bundle.prerequisites = bundle_prerequisites
     bundle.references = bundle_refs
     bundle.references = bundle_refs
-    bundle.pack_data = pack_data  # type: ignore[assignment]
+    bundle.pack_data = pack_data
 
 
     return bundle
     return bundle

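Note: the new PackDataLike protocol is what lets Bundle.pack_data hold either a real PackData or the private _BundlePackData wrapper without the old type: ignore. Because it is @runtime_checkable, isinstance() can verify the method names at runtime (names only, not signatures). A self-contained sketch of the same structural-typing pattern, using hypothetical names:

    from collections.abc import Iterator
    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class SizedItems(Protocol):
        # Anything with __len__ and iter_items qualifies; no inheritance needed.
        def __len__(self) -> int: ...
        def iter_items(self) -> Iterator[int]: ...

    class InMemory:
        def __init__(self, items: list[int]) -> None:
            self._items = items

        def __len__(self) -> int:
            return len(self._items)

        def iter_items(self) -> Iterator[int]:
            return iter(self._items)

    obj = InMemory([1, 2, 3])
    assert isinstance(obj, SizedItems)  # passes: checked structurally
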
+ 81 - 32
dulwich/cli.py

@@ -37,7 +37,7 @@ import subprocess
 import sys
 import tempfile
 from pathlib import Path
-from typing import Callable, ClassVar, Optional, Union
+from typing import BinaryIO, Callable, ClassVar, Optional, Union
 
 from dulwich import porcelain
 
@@ -45,12 +45,26 @@ from .bundle import create_bundle_from_repo, read_bundle, write_bundle
 from .client import GitProtocolError, get_transport_and_path
 from .errors import ApplyDeltaError
 from .index import Index
-from .objects import valid_hexsha
+from .objects import Commit, valid_hexsha
 from .objectspec import parse_commit_range
 from .pack import Pack, sha_to_hex
 from .repo import Repo
 
 
+def to_display_str(value: Union[bytes, str]) -> str:
+    """Convert a bytes or string value to a display string.
+
+    Args:
+        value: The value to convert (bytes or str)
+
+    Returns:
+        A string suitable for display
+    """
+    if isinstance(value, bytes):
+        return value.decode("utf-8", "replace")
+    return value
+
+
 class CommitMessageError(Exception):
     """Raised when there's an issue with the commit message."""
 
@@ -126,7 +140,7 @@ def parse_relative_time(time_str: str) -> int:
         raise
 
 
-def format_bytes(bytes: int) -> str:
+def format_bytes(bytes: float) -> str:
     """Format bytes as human-readable string.
 
     Args:
@@ -231,7 +245,7 @@ class Pager:
         Args:
             pager_cmd: Command to use for paging (default: "cat")
         """
-        self.pager_process = None
+        self.pager_process: Optional[subprocess.Popen] = None
         self.buffer = PagerBuffer(self)
         self._closed = False
         self.pager_cmd = pager_cmd
@@ -449,12 +463,12 @@ def get_pager(config=None, cmd_name: Optional[str] = None):
 
 def disable_pager() -> None:
     """Disable pager for this session."""
-    get_pager._disabled = True
+    get_pager._disabled = True  # type: ignore[attr-defined]
 
 
 def enable_pager() -> None:
     """Enable pager for this session."""
-    get_pager._disabled = False
+    get_pager._disabled = False  # type: ignore[attr-defined]
 
 
 class Command:
@@ -491,10 +505,14 @@ class cmd_archive(Command):
                 write_error=sys.stderr.write,
             )
         else:
-            # Use buffer if available (for binary output), otherwise use stdout
-            outstream = getattr(sys.stdout, "buffer", sys.stdout)
+            # Use binary buffer for archive output
+            outstream: BinaryIO = sys.stdout.buffer
+            errstream: BinaryIO = sys.stderr.buffer
             porcelain.archive(
-                ".", args.committish, outstream=outstream, errstream=sys.stderr
+                ".",
+                args.committish,
+                outstream=outstream,
+                errstream=errstream,
             )
 
 
@@ -642,10 +660,11 @@ class cmd_fetch(Command):
         def progress(msg: bytes) -> None:
             sys.stdout.buffer.write(msg)
 
-        refs = client.fetch(path, r, progress=progress)
+        result = client.fetch(path, r, progress=progress)
         print("Remote refs:")
-        for item in refs.items():
-            print("{} -> {}".format(*item))
+        for ref, sha in result.refs.items():
+            if sha is not None:
+                print(f"{ref.decode()} -> {sha.decode()}")
 
 
 class cmd_for_each_ref(Command):
@@ -676,7 +695,7 @@ class cmd_fsck(Command):
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
         for obj, msg in porcelain.fsck("."):
-            print(f"{obj}: {msg}")
+            print(f"{obj.decode() if isinstance(obj, bytes) else obj}: {msg}")
 
 
 class cmd_log(Command):
@@ -1302,9 +1321,17 @@ class cmd_reflog(Command):
 
                     for i, entry in enumerate(porcelain.reflog(repo, ref)):
                         # Format similar to git reflog
+                        from dulwich.reflog import Entry
+
+                        assert isinstance(entry, Entry)
                         short_new = entry.new_sha[:8].decode("ascii")
+                        message = (
+                            entry.message.decode("utf-8", "replace")
+                            if entry.message
+                            else ""
+                        )
                         outstream.write(
-                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {entry.message.decode('utf-8', 'replace')}\n"
+                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {message}\n"
                        )
 
 
@@ -1538,11 +1565,14 @@ class cmd_ls_remote(Command):
         if args.symref:
             # Show symrefs first, like git does
             for ref, target in sorted(result.symrefs.items()):
-                sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
+                if target:
+                    sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
 
         # Show regular refs
         for ref in sorted(result.refs):
-            sys.stdout.write(f"{result.refs[ref].decode()}\t{ref.decode()}\n")
+            sha = result.refs[ref]
+            if sha is not None:
+                sys.stdout.write(f"{sha.decode()}\t{ref.decode()}\n")
 
 
 class cmd_ls_tree(Command):
@@ -1601,12 +1631,13 @@ class cmd_pack_objects(Command):
         if not args.stdout and not args.basename:
             parser.error("basename required when not using --stdout")
 
-        object_ids = [line.strip() for line in sys.stdin.readlines()]
+        object_ids = [line.strip().encode() for line in sys.stdin.readlines()]
         deltify = args.deltify
         reuse_deltas = not args.no_reuse_deltas
 
         if args.stdout:
             packf = getattr(sys.stdout, "buffer", sys.stdout)
+            assert isinstance(packf, BinaryIO)
             idxf = None
             close = []
         else:
@@ -2022,8 +2053,17 @@ class cmd_stash_list(Command):
         """
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
-        for i, entry in porcelain.stash_list("."):
-            print("stash@{{{}}}: {}".format(i, entry.message.rstrip("\n")))
+        from .repo import Repo
+        from .stash import Stash
+
+        with Repo(".") as r:
+            stash = Stash.from_repo(r)
+            for i, entry in enumerate(stash.stashes()):
+                print(
+                    "stash@{{{}}}: {}".format(
+                        i, entry.message.decode("utf-8", "replace").rstrip("\n")
+                    )
+                )
 
 
 class cmd_stash_push(Command):
@@ -2145,6 +2185,7 @@ class cmd_bisect(SuperCommand):
                         with open(bad_ref, "rb") as f:
                             bad_sha = f.read().strip()
                         commit = r.object_store[bad_sha]
+                        assert isinstance(commit, Commit)
                         message = commit.message.decode(
                             "utf-8", errors="replace"
                         ).split("\n")[0]
@@ -2173,7 +2214,7 @@ class cmd_bisect(SuperCommand):
                 print(log, end="")
 
             elif parsed_args.subcommand == "replay":
-                porcelain.bisect_replay(log_file=parsed_args.logfile)
+                porcelain.bisect_replay(".", log_file=parsed_args.logfile)
                 print(f"Replayed bisect log from {parsed_args.logfile}")
 
             elif parsed_args.subcommand == "help":
@@ -2270,6 +2311,7 @@ class cmd_merge(Command):
             elif args.no_commit:
                 print("Automatic merge successful; not committing as requested.")
             else:
+                assert merge_commit_id is not None
                 print(
                     f"Merge successful. Created merge commit {merge_commit_id.decode()}"
                 )
@@ -3129,12 +3171,14 @@ class cmd_lfs(Command):
             tracked = porcelain.lfs_untrack(patterns=args.patterns)
             print("Remaining tracked patterns:")
             for pattern in tracked:
-                print(f"  {pattern}")
+                print(f"  {to_display_str(pattern)}")
 
         elif args.subcommand == "ls-files":
             files = porcelain.lfs_ls_files(ref=args.ref)
             for path, oid, size in files:
-                print(f"{oid[:12]} * {path} ({format_bytes(size)})")
+                print(
+                    f"{to_display_str(oid[:12])} * {to_display_str(path)} ({format_bytes(size)})"
+                )
 
         elif args.subcommand == "migrate":
             count = porcelain.lfs_migrate(
@@ -3145,13 +3189,13 @@ class cmd_lfs(Command):
         elif args.subcommand == "pointer":
             if args.paths is not None:
                 results = porcelain.lfs_pointer_check(paths=args.paths or None)
-                for path, pointer in results.items():
+                for file_path, pointer in results.items():
                     if pointer:
                         print(
-                            f"{path}: LFS pointer (oid: {pointer.oid[:12]}, size: {format_bytes(pointer.size)})"
+                            f"{to_display_str(file_path)}: LFS pointer (oid: {to_display_str(pointer.oid[:12])}, size: {format_bytes(pointer.size)})"
                        )
                     else:
-                        print(f"{path}: Not an LFS pointer")
+                        print(f"{to_display_str(file_path)}: Not an LFS pointer")
 
         elif args.subcommand == "clean":
             pointer = porcelain.lfs_clean(path=args.path)
@@ -3188,13 +3232,13 @@ class cmd_lfs(Command):
 
             if status["missing"]:
                 print("\nMissing LFS objects:")
-                for path in status["missing"]:
-                    print(f"  {path}")
+                for file_path in status["missing"]:
+                    print(f"  {to_display_str(file_path)}")
 
             if status["not_staged"]:
                 print("\nModified LFS files not staged:")
-                for path in status["not_staged"]:
-                    print(f"  {path}")
+                for file_path in status["not_staged"]:
+                    print(f"  {to_display_str(file_path)}")
 
             if not any(status.values()):
                 print("No LFS files found.")
@@ -3273,14 +3317,19 @@ class cmd_format_patch(Command):
         args = parser.parse_args(args)
 
         # Parse committish using the new function
-        committish = None
+        committish: Optional[Union[bytes, tuple[bytes, bytes]]] = None
         if args.committish:
             with Repo(".") as r:
                 range_result = parse_commit_range(r, args.committish)
                 if range_result:
-                    committish = range_result
+                    # Convert Commit objects to their SHAs
+                    committish = (range_result[0].id, range_result[1].id)
                 else:
-                    committish = args.committish
+                    committish = (
+                        args.committish.encode()
+                        if isinstance(args.committish, str)
+                        else args.committish
+                    )
 
         filenames = porcelain.format_patch(
             ".",

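Note: much of the cli.py churn funnels bytes values through the new to_display_str helper before they reach an f-string; previously the b'...' repr leaked into CLI output. The helper below is the one defined in the diff above, and the two print calls illustrating the difference are added here:

    def to_display_str(value):
        if isinstance(value, bytes):
            return value.decode("utf-8", "replace")
        return value

    path = b"docs/readme.md"
    print(f"  {path}")                  # prints:   b'docs/readme.md'
    print(f"  {to_display_str(path)}")  # prints:   docs/readme.md
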
+ 128 - 78
dulwich/client.py

@@ -53,7 +53,6 @@ from io import BufferedReader, BytesIO
 from typing import (
     IO,
     TYPE_CHECKING,
-    Any,
     Callable,
     ClassVar,
     Optional,
@@ -70,11 +69,11 @@ import dulwich
 
 from .config import Config, apply_instead_of, get_xdg_config_home_path
 from .errors import GitProtocolError, NotGitRepository, SendPackError
+from .object_store import GraphWalker
 from .pack import (
     PACK_SPOOL_FILE_MAX_SIZE,
     PackChunkGenerator,
     PackData,
-    UnpackedObject,
     write_pack_from_container,
 )
 from .protocol import (
@@ -117,7 +116,6 @@ from .protocol import (
     capability_agent,
     extract_capabilities,
     extract_capability_names,
-    filter_ref_prefix,
     parse_capability,
     pkt_line,
     pkt_seq,
@@ -130,6 +128,7 @@ from .refs import (
     _set_default_branch,
     _set_head,
     _set_origin_head,
+    filter_ref_prefix,
     read_info_refs,
     split_peeled_refs,
 )
@@ -150,7 +149,7 @@ logger = logging.getLogger(__name__)
 class InvalidWants(Exception):
     """Invalid wants."""
 
-    def __init__(self, wants: Any) -> None:
+    def __init__(self, wants: set[bytes]) -> None:
         """Initialize InvalidWants exception.
 
         Args:
@@ -164,7 +163,7 @@ class InvalidWants(Exception):
 class HTTPUnauthorized(Exception):
     """Raised when authentication fails."""
 
-    def __init__(self, www_authenticate: Any, url: str) -> None:
+    def __init__(self, www_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPUnauthorized exception.
 
         Args:
@@ -179,7 +178,7 @@ class HTTPUnauthorized(Exception):
 class HTTPProxyUnauthorized(Exception):
     """Raised when proxy authentication fails."""
 
-    def __init__(self, proxy_authenticate: Any, url: str) -> None:
+    def __init__(self, proxy_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPProxyUnauthorized exception.
 
         Args:
@@ -196,17 +195,23 @@ def _fileno_can_read(fileno: int) -> bool:
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
 
-def _win32_peek_avail(handle: Any) -> int:
+def _win32_peek_avail(handle: int) -> int:
     """Wrapper around PeekNamedPipe to check how many bytes are available."""
-    from ctypes import byref, windll, wintypes
+    from ctypes import (  # type: ignore[attr-defined]
+        byref,
+        windll,  # type: ignore[attr-defined]
+        wintypes,
+    )
 
     c_avail = wintypes.DWORD()
     c_message = wintypes.DWORD()
-    success = windll.kernel32.PeekNamedPipe(
+    success = windll.kernel32.PeekNamedPipe(  # type: ignore[attr-defined]
         handle, None, 0, None, byref(c_avail), byref(c_message)
     )
     if not success:
-        raise OSError(wintypes.GetLastError())
+        from ctypes import GetLastError  # type: ignore[attr-defined]
+
+        raise OSError(GetLastError())
     return c_avail.value
 
 
@@ -231,10 +236,10 @@ class ReportStatusParser:
     def __init__(self) -> None:
         """Initialize ReportStatusParser."""
         self._done = False
-        self._pack_status = None
+        self._pack_status: Optional[bytes] = None
         self._ref_statuses: list[bytes] = []
 
-    def check(self) -> Any:
+    def check(self) -> Iterator[tuple[bytes, Optional[str]]]:
         """Check if there were any errors and, if so, raise exceptions.
 
         Raises:
@@ -277,7 +282,7 @@ class ReportStatusParser:
             self._ref_statuses.append(ref_status)
 
 
-def negotiate_protocol_version(proto: Any) -> int:
+def negotiate_protocol_version(proto: Protocol) -> int:
     pkt = proto.read_pkt_line()
     if pkt is not None and pkt.strip() == b"version 2":
         return 2
@@ -285,7 +290,7 @@ def negotiate_protocol_version(proto: Any) -> int:
     return 0
 
 
-def read_server_capabilities(pkt_seq: Any) -> set:
+def read_server_capabilities(pkt_seq: Iterable[bytes]) -> set[bytes]:
     server_capabilities = []
     for pkt in pkt_seq:
         server_capabilities.append(pkt)
@@ -293,21 +298,15 @@
 
 
 def read_pkt_refs_v2(
-    pkt_seq: Any,
-) -> tuple[dict[bytes, bytes], dict[bytes, bytes], dict[bytes, bytes]]:
-    """Read packet references in protocol v2 format.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (refs dict, symrefs dict, peeled dict)
-    """
-    refs = {}
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], dict[bytes, bytes], dict[bytes, bytes]]:
+    refs: dict[bytes, Optional[bytes]] = {}
     symrefs = {}
     peeled = {}
     # Receive refs from server
     for pkt in pkt_seq:
         parts = pkt.rstrip(b"\n").split(b" ")
-        sha = parts[0]
+        sha: Optional[bytes] = parts[0]
         if sha == b"unborn":
             sha = None
         ref = parts[1]
@@ -323,9 +322,11 @@ def read_pkt_refs_v2(
     return refs, symrefs, peeled
 
 
-def read_pkt_refs_v1(pkt_seq: Any) -> tuple[dict[bytes, bytes], set[bytes]]:
+def read_pkt_refs_v1(
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], set[bytes]]:
     server_capabilities = None
-    refs = {}
+    refs: dict[bytes, Optional[bytes]] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         (sha, ref) = pkt.rstrip(b"\n").split(None, 1)
@@ -346,6 +347,8 @@ def read_pkt_refs_v1(pkt_seq: Any) -> tuple[dict[bytes, bytes], set[bytes]]:
 class _DeprecatedDictProxy:
     """Base class for result objects that provide deprecated dict-like interface."""
 
+    refs: dict[bytes, Optional[bytes]]  # To be overridden by subclasses
+
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
         "clear",
         "copy",
@@ -376,7 +379,7 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return name in self.refs
 
-    def __getitem__(self, name: bytes) -> bytes:
+    def __getitem__(self, name: bytes) -> Optional[bytes]:
         self._warn_deprecated()
         return self.refs[name]
 
@@ -384,11 +387,11 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return len(self.refs)
 
-    def __iter__(self) -> Any:
+    def __iter__(self) -> Iterator[bytes]:
         self._warn_deprecated()
         return iter(self.refs)
 
-    def __getattribute__(self, name: str) -> Any:
+    def __getattribute__(self, name: str) -> object:
         # Avoid infinite recursion by checking against class variable directly
         if name != "_FORWARDED_ATTRS" and name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
@@ -407,8 +410,16 @@ class FetchPackResult(_DeprecatedDictProxy):
       agent: User agent string
     """
 
+    symrefs: dict[bytes, bytes]
+    agent: Optional[bytes]
+
     def __init__(
-        self, refs: dict, symrefs: dict, agent: Optional[bytes], new_shallow: Optional[Any] = None, new_unshallow: Optional[Any] = None
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        symrefs: dict[bytes, bytes],
+        agent: Optional[bytes],
+        new_shallow: Optional[set[bytes]] = None,
+        new_unshallow: Optional[set[bytes]] = None,
     ) -> None:
         """Initialize FetchPackResult.
 
@@ -425,10 +436,12 @@ class FetchPackResult(_DeprecatedDictProxy):
         self.new_shallow = new_shallow
         self.new_unshallow = new_unshallow
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, FetchPackResult):
+            return False
         return (
             self.refs == other.refs
             and self.symrefs == other.symrefs
@@ -448,7 +461,11 @@ class LsRemoteResult(_DeprecatedDictProxy):
       symrefs: Dictionary with remote symrefs
     """
 
-    def __init__(self, refs: dict, symrefs: dict) -> None:
+    symrefs: dict[bytes, bytes]
+
+    def __init__(
+        self, refs: dict[bytes, Optional[bytes]], symrefs: dict[bytes, bytes]
+    ) -> None:
         """Initialize LsRemoteResult.
 
         Args:
@@ -468,10 +485,12 @@ class LsRemoteResult(_DeprecatedDictProxy):
             stacklevel=3,
         )
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, LsRemoteResult):
+            return False
         return self.refs == other.refs and self.symrefs == other.symrefs
 
     def __repr__(self) -> str:
@@ -489,7 +508,12 @@ class SendPackResult(_DeprecatedDictProxy):
         failed to update), or None if it was updated successfully
     """
 
-    def __init__(self, refs: dict, agent: Optional[bytes] = None, ref_status: Optional[dict] = None) -> None:
+    def __init__(
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        agent: Optional[bytes] = None,
+        ref_status: Optional[dict[bytes, Optional[str]]] = None,
+    ) -> None:
         """Initialize SendPackResult.
 
         Args:
@@ -501,10 +525,12 @@ class SendPackResult(_DeprecatedDictProxy):
         self.agent = agent
         self.ref_status = ref_status
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, SendPackResult):
+            return False
         return self.refs == other.refs and self.agent == other.agent
 
     def __repr__(self) -> str:
@@ -512,7 +538,7 @@ class SendPackResult(_DeprecatedDictProxy):
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"
 
 
-def _read_shallow_updates(pkt_seq: Any) -> tuple[set, set]:
+def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[bytes]]:
     new_shallow = set()
     new_unshallow = set()
     for pkt in pkt_seq:
@@ -521,27 +547,29 @@
         try:
             cmd, sha = pkt.split(b" ", 1)
         except ValueError:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
         if cmd == COMMAND_SHALLOW:
             new_shallow.add(sha.strip())
        elif cmd == COMMAND_UNSHALLOW:
             new_unshallow.add(sha.strip())
         else:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
     return (new_shallow, new_unshallow)
 
 
 class _v1ReceivePackHeader:
     def __init__(self, capabilities: list, old_refs: dict, new_refs: dict) -> None:
-        self.want: list[bytes] = []
-        self.have: list[bytes] = []
+        self.want: set[bytes] = set()
+        self.have: set[bytes] = set()
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False
 
-    def __iter__(self) -> Any:
+    def __iter__(self) -> Iterator[Optional[bytes]]:
         return self._it
 
-    def _handle_receive_pack_head(self, capabilities: list, old_refs: dict, new_refs: dict) -> Any:
+    def _handle_receive_pack_head(
+        self, capabilities: list, old_refs: dict, new_refs: dict
+    ) -> Iterator[Optional[bytes]]:
         """Handle the head of a 'git-receive-pack' request.
 
         Args:
@@ -552,7 +580,7 @@ class _v1ReceivePackHeader:
         Returns:
           (have, want) tuple
         """
-        self.have = [x for x in old_refs.values() if not x == ZERO_SHA]
+        self.have = {x for x in old_refs.values() if not x == ZERO_SHA}
 
         for refname in new_refs:
             if not isinstance(refname, bytes):
@@ -586,7 +614,7 @@ class _v1ReceivePackHeader:
                     )
                     self.sent_capabilities = True
             if new_sha1 not in self.have and new_sha1 != ZERO_SHA:
-                self.want.append(new_sha1)
+                self.want.add(new_sha1)
         yield None
 
 
@@ -603,25 +631,28 @@ def _read_side_band64k_data(pkt_seq: Iterable[bytes]) -> Iterator[tuple[int, byt
         yield channel, pkt[1:]
 
 
-def find_capability(capabilities: list, key: bytes, value: Optional[bytes]) -> Optional[bytes]:
+def find_capability(
+    capabilities: list, key: bytes, value: Optional[bytes]
+) -> Optional[bytes]:
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k != key:
             continue
-        if value and value not in v.split(b" "):
+        if value and v and value not in v.split(b" "):
            continue
         return capability
+    return None
 
 
 def _handle_upload_pack_head(
-    proto: Any,
+    proto: Protocol,
     capabilities: list,
-    graph_walker: Any,
+    graph_walker: GraphWalker,
     wants: list,
-    can_read: Callable,
+    can_read: Optional[Callable],
     depth: Optional[int],
     protocol_version: Optional[int],
-) -> None:
+) -> tuple[Optional[set[bytes]], Optional[set[bytes]]]:
     """Handle the head of a 'git-upload-pack' request.
 
     Args:
@@ -634,6 +665,8 @@ def _handle_upload_pack_head(
      depth: Depth for request
      protocol_version: Neogiated Git protocol version.
    """
+    new_shallow: Optional[set[bytes]]
+    new_unshallow: Optional[set[bytes]]
    assert isinstance(wants, list) and isinstance(wants[0], bytes)
    wantcmd = COMMAND_WANT + b" " + wants[0]
    if protocol_version is None:
@@ -644,7 +677,9 @@ def _handle_upload_pack_head(
     proto.write_pkt_line(wantcmd)
     for want in wants[1:]:
         proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n")
-    if depth not in (0, None) or graph_walker.shallow:
+    if depth not in (0, None) or (
+        hasattr(graph_walker, "shallow") and graph_walker.shallow
+    ):
         if protocol_version == 2:
             if not find_capability(capabilities, CAPABILITY_FETCH, CAPABILITY_SHALLOW):
                 raise GitProtocolError(
@@ -654,8 +689,9 @@ def _handle_upload_pack_head(
             raise GitProtocolError(
                 "server does not support shallow capability required for depth"
             )
-        for sha in graph_walker.shallow:
-            proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
+        if hasattr(graph_walker, "shallow"):
+            for sha in graph_walker.shallow:
+                proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
         if depth is not None:
             proto.write_pkt_line(
                 COMMAND_DEEPEN + b" " + str(depth).encode("ascii") + b"\n"
@@ -668,6 +704,7 @@ def _handle_upload_pack_head(
         proto.write_pkt_line(COMMAND_HAVE + b" " + have + b"\n")
         if can_read is not None and can_read():
             pkt = proto.read_pkt_line()
+            assert pkt is not None
             parts = pkt.rstrip(b"\n").split(b" ")
             if parts[0] == b"ACK":
                 graph_walker.ack(parts[1])
@@ -677,7 +714,7 @@ def _handle_upload_pack_head(
                     break
                 else:
                     raise AssertionError(
-                        f"{parts[2]} not in ('continue', 'ready', 'common)"
+                        f"{parts[2]!r} not in ('continue', 'ready', 'common)"
                    )
         have = next(graph_walker)
     proto.write_pkt_line(COMMAND_DONE + b"\n")
@@ -688,7 +725,8 @@ def _handle_upload_pack_head(
         if can_read is not None:
             (new_shallow, new_unshallow) = _read_shallow_updates(proto.read_pkt_seq())
         else:
-            new_shallow = new_unshallow = None
+            new_shallow = None
+            new_unshallow = None
     else:
         new_shallow = new_unshallow = set()
 
@@ -767,6 +805,7 @@ def _extract_symrefs_and_agent(capabilities):
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k == CAPABILITY_SYMREF:
+            assert v is not None
             (src, dst) = v.split(b":", 1)
             symrefs[src] = dst
         if k == CAPABILITY_AGENT:
@@ -842,9 +881,7 @@ class GitClient:
         self,
         path: str,
         update_refs,
-        generate_pack_data: Callable[
-            [set[bytes], set[bytes], bool], tuple[int, Iterator[UnpackedObject]]
-        ],
+        generate_pack_data,
         progress=None,
     ) -> SendPackResult:
         """Upload a pack to a remote repository.
@@ -935,8 +972,11 @@ class GitClient:
             origin_sha = result.refs.get(b"HEAD")
             if origin is None or (origin_sha and not origin_head):
                 # set detached HEAD
-                target.refs[b"HEAD"] = origin_sha
-                head = origin_sha
+                if origin_sha is not None:
+                    target.refs[b"HEAD"] = origin_sha
+                    head = origin_sha
+                else:
+                    head = None
             else:
                 _set_origin_head(target.refs, origin.encode("utf-8"), origin_head)
                 head_ref = _set_default_branch(
@@ -1166,10 +1206,11 @@ class GitClient:
             if self.protocol_version == 2 and k == CAPABILITY_FETCH:
                 fetch_capa = CAPABILITY_FETCH
                 fetch_features = []
-                v = v.strip().split(b" ")
-                if b"shallow" in v:
+                assert v is not None
+                v_list = v.strip().split(b" ")
+                if b"shallow" in v_list:
                     fetch_features.append(CAPABILITY_SHALLOW)
-                if b"filter" in v:
+                if b"filter" in v_list:
                     fetch_features.append(CAPABILITY_FILTER)
                 for i in range(len(fetch_features)):
                     if i == 0:
@@ -1320,10 +1361,10 @@ class TraditionalGitClient(GitClient):
                 for ref, sha in orig_new_refs.items():
                     if sha == ZERO_SHA:
                         if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
+                            assert report_status_parser is not None
                             report_status_parser._ref_statuses.append(
                                 b"ng " + ref + b" remote does not support deleting refs"
                             )
-                            report_status_parser._ref_status_ok = False
                         del new_refs[ref]
 
             if new_refs is None:
@@ -1730,7 +1771,7 @@ class TCPGitClient(TraditionalGitClient):
         proto.send_cmd(
             b"git-" + cmd, path, b"host=" + self._host.encode("ascii") + version_str
         )
-        return proto, lambda: _fileno_can_read(s), None
+        return proto, lambda: _fileno_can_read(s.fileno()), None
 
 
 class SubprocessWrapper:
@@ -1960,7 +2001,7 @@ class LocalGitClient(GitClient):
                 *generate_pack_data(have, want, ofs_delta=True)
             )
 
-            ref_status = {}
+            ref_status: dict[bytes, Optional[str]] = {}
 
             for refname, new_sha1 in new_refs.items():
                 old_sha1 = old_refs.get(refname, ZERO_SHA)
@@ -2199,7 +2240,7 @@ class BundleClient(GitClient):
 
             while line.startswith(b"-"):
                 (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-                prerequisites.append((obj_id, comment.decode("utf-8")))
+                prerequisites.append((obj_id, comment))
                 line = f.readline()
 
             while line != b"\n":
@@ -2940,7 +2981,11 @@ class AbstractHttpGitClient(GitClient):
         protocol_version: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
     ) -> tuple[
-        dict[Ref, ObjectID], set[bytes], str, dict[Ref, Ref], dict[Ref, ObjectID]
+        dict[Ref, Optional[ObjectID]],
+        set[bytes],
+        str,
+        dict[Ref, Ref],
+        dict[Ref, ObjectID],
    ]:
         if (
             protocol_version is not None
@@ -3003,10 +3048,10 @@ class AbstractHttpGitClient(GitClient):
                     resp, read = self._smart_request(
                         service.decode("ascii"), base_url, body
                     )
-                    proto = Protocol(read, None)
+                    proto = Protocol(read, lambda data: None)
                     return server_capabilities, resp, read, proto
 
-                proto = Protocol(read, None)  # type: ignore
+                proto = Protocol(read, lambda data: None)
                 server_protocol_version = negotiate_protocol_version(proto)
                 if server_protocol_version not in GIT_PROTOCOL_VERSIONS:
                     raise ValueError(
@@ -3071,7 +3116,12 @@ class AbstractHttpGitClient(GitClient):
                     if not chunk:
                         break
                     data += chunk
-                (refs, peeled) = split_peeled_refs(read_info_refs(BytesIO(data)))
+                from typing import Optional, cast
+
+                info_refs = read_info_refs(BytesIO(data))
+                (refs, peeled) = split_peeled_refs(
+                    cast(dict[bytes, Optional[bytes]], info_refs)
+                )
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
                 return refs, set(), base_url, {}, peeled
@@ -3159,7 +3209,7 @@ class AbstractHttpGitClient(GitClient):
 
         resp, read = self._smart_request("git-receive-pack", url, data=body_generator())
         try:
-            resp_proto = Protocol(read, None)
+            resp_proto = Protocol(read, lambda data: None)
             ref_status = self._handle_receive_pack_tail(
                 resp_proto, negotiated_capabilities, progress
            )
@@ -3404,7 +3454,7 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
     def _http_request(self, url, headers=None, data=None, raise_for_status=True):
         import urllib3.exceptions
 
-        req_headers = self.pool_manager.headers.copy()
+        req_headers = dict(self.pool_manager.headers)
         if headers is not None:
             req_headers.update(headers)
         req_headers["Pragma"] = "no-cache"
@@ -3418,10 +3468,10 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
                 request_kwargs["timeout"] = self._timeout
 
             if data is None:
-                resp = self.pool_manager.request("GET", url, **request_kwargs)
+                resp = self.pool_manager.request("GET", url, **request_kwargs)  # type: ignore[arg-type]
             else:
                 request_kwargs["body"] = data
-                resp = self.pool_manager.request("POST", url, **request_kwargs)
+                resp = self.pool_manager.request("POST", url, **request_kwargs)  # type: ignore[arg-type]
         except urllib3.exceptions.HTTPError as e:
             raise GitProtocolError(str(e)) from e
 
@@ -3435,15 +3485,15 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
             if resp.status != 200:
                 raise GitProtocolError(f"unexpected http resp {resp.status} for {url}")
 
-        resp.content_type = resp.headers.get("Content-Type")
+        resp.content_type = resp.headers.get("Content-Type")  # type: ignore[attr-defined]
         # Check if geturl() is available (urllib3 version >= 1.23)
         try:
             resp_url = resp.geturl()
         except AttributeError:
             # get_redirect_location() is available for urllib3 >= 1.1
-            resp.redirect_location = resp.get_redirect_location()
+            resp.redirect_location = resp.get_redirect_location()  # type: ignore[attr-defined]
         else:
-            resp.redirect_location = resp_url if resp_url != url else ""
+            resp.redirect_location = resp_url if resp_url != url else ""  # type: ignore[attr-defined]
         return resp, _wrap_urllib3_exceptions(resp.read)

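The `Protocol(read, lambda data: None)` change above works because the protocol's write parameter is annotated as a callable: a no-op lambda satisfies `Callable[[bytes], None]` where a bare `None` does not. A minimal sketch of the pattern, with a hypothetical `Proto` class standing in for dulwich.protocol.Protocol:

    from typing import Callable

    class Proto:
        # The writer must be Callable[[bytes], None]; passing None here would
        # fail a strict mypy check even if the code path never writes.
        def __init__(
            self, read: Callable[[int], bytes], write: Callable[[bytes], None]
        ) -> None:
            self.read = read
            self.write = write

    resp_proto = Proto(lambda n: b"", lambda data: None)  # no-op writer instead of None
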
+ 9 - 3
dulwich/cloud/gcs.py

@@ -52,7 +52,7 @@ class GcsObjectStore(BucketBasedObjectStore):
         """Return string representation of GcsObjectStore."""
         """Return string representation of GcsObjectStore."""
         return f"{type(self).__name__}({self.bucket!r}, subpath={self.subpath!r})"
         return f"{type(self).__name__}({self.bucket!r}, subpath={self.subpath!r})"
 
 
-    def _remove_pack(self, name: str) -> None:
+    def _remove_pack_by_name(self, name: str) -> None:
         self.bucket.delete_blobs(
         self.bucket.delete_blobs(
             [posixpath.join(self.subpath, name) + "." + ext for ext in ["pack", "idx"]]
             [posixpath.join(self.subpath, name) + "." + ext for ext in ["pack", "idx"]]
         )
         )
@@ -68,10 +68,14 @@ class GcsObjectStore(BucketBasedObjectStore):
 
 
     def _load_pack_data(self, name: str) -> PackData:
     def _load_pack_data(self, name: str) -> PackData:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".pack"))
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".pack"))
+        from typing import cast
+
+        from ..file import _GitFile
+
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         b.download_to_file(f)
         b.download_to_file(f)
         f.seek(0)
         f.seek(0)
-        return PackData(name + ".pack", f)
+        return PackData(name + ".pack", cast(_GitFile, f))
 
 
     def _load_pack_index(self, name: str) -> PackIndex:
     def _load_pack_index(self, name: str) -> PackIndex:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".idx"))
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".idx"))
@@ -85,7 +89,9 @@ class GcsObjectStore(BucketBasedObjectStore):
             lambda: self._load_pack_data(name), lambda: self._load_pack_index(name)
             lambda: self._load_pack_data(name), lambda: self._load_pack_index(name)
         )
         )
 
 
-    def _upload_pack(self, basename: str, pack_file: BinaryIO, index_file: BinaryIO) -> None:
+    def _upload_pack(
+        self, basename: str, pack_file: BinaryIO, index_file: BinaryIO
+    ) -> None:
         idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + ".idx"))
         idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + ".idx"))
         datablob = self.bucket.blob(posixpath.join(self.subpath, basename + ".pack"))
         datablob = self.bucket.blob(posixpath.join(self.subpath, basename + ".pack"))
         idxblob.upload_from_file(index_file)
         idxblob.upload_from_file(index_file)

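The casts to `_GitFile` above exist because `PackData`'s annotation expects a `_GitFile`, while any seekable binary file works at runtime. A sketch of the general pattern, assuming only the standard library (the function name is hypothetical):

    import tempfile
    from typing import IO, cast

    def spooled_pack_file() -> IO[bytes]:
        # SpooledTemporaryFile behaves like a binary file at runtime, but its
        # static type is not IO[bytes] in all typeshed versions, so a cast
        # silences mypy without changing behaviour.
        f = tempfile.SpooledTemporaryFile(max_size=1024 * 1024)
        return cast(IO[bytes], f)
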
+ 1 - 3
dulwich/commit_graph.py

@@ -567,9 +567,7 @@ def write_commit_graph(
 
     graph_path = os.path.join(info_dir, b"commit-graph")
     with GitFile(graph_path, "wb") as f:
-        from typing import BinaryIO, cast
-
-        graph.write_to_file(cast(BinaryIO, f))
+        graph.write_to_file(f)
 
 
 def get_reachable_commits(

+ 39 - 12
dulwich/config.py

@@ -41,14 +41,11 @@ from contextlib import suppress
 from pathlib import Path
 from typing import (
     IO,
-    Any,
-    BinaryIO,
     Callable,
     Generic,
     Optional,
     TypeVar,
     Union,
-    cast,
     overload,
 )
 
@@ -60,7 +57,7 @@ ConfigValue = Union[str, bytes, bool, int]
 logger = logging.getLogger(__name__)
 
 # Type for file opener callback
-FileOpener = Callable[[Union[str, os.PathLike]], BinaryIO]
+FileOpener = Callable[[Union[str, os.PathLike]], IO[bytes]]
 
 # Type for includeIf condition matcher
 # Takes the condition value (e.g., "main" for onbranch:main) and returns bool
@@ -194,7 +191,7 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
           default_factory: Optional factory function for default values
         """
         self._real: list[tuple[K, V]] = []
-        self._keyed: dict[Any, V] = {}
+        self._keyed: dict[ConfigKey, V] = {}
         self._default_factory = default_factory
 
     @classmethod
@@ -239,7 +236,31 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
     def keys(self) -> KeysView[K]:
         """Return a view of the dictionary's keys."""
-        return self._keyed.keys()  # type: ignore[return-value]
+        # Return a view of the original keys (not lowercased)
+        # We need to deduplicate since _real can have duplicates
+        seen = set()
+        unique_keys = []
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                unique_keys.append(k)
+        from collections.abc import KeysView as ABCKeysView
+
+        class UniqueKeysView(ABCKeysView[K]):
+            def __init__(self, keys: list[K]):
+                self._keys = keys
+
+            def __contains__(self, key: object) -> bool:
+                return key in self._keys
+
+            def __iter__(self):
+                return iter(self._keys)
+
+            def __len__(self) -> int:
+                return len(self._keys)
+
+        return UniqueKeysView(unique_keys)
 
     def items(self) -> ItemsView[K, V]:
         """Return a view of the dictionary's (key, value) pairs in insertion order."""
@@ -267,7 +288,13 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
     def __iter__(self) -> Iterator[K]:
         """Iterate over the dictionary's keys."""
-        return iter(self._keyed)
+        # Return iterator over original keys (not lowercased), deduplicated
+        seen = set()
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                yield k
 
     def values(self) -> ValuesView[V]:
         """Return a view of the dictionary's values."""
@@ -898,7 +925,7 @@ class ConfigFile(ConfigDict):
     @classmethod
     def from_file(
         cls,
-        f: BinaryIO,
+        f: IO[bytes],
         *,
         config_dir: Optional[str] = None,
         included_paths: Optional[set[str]] = None,
@@ -1075,8 +1102,8 @@ class ConfigFile(ConfigDict):
             opener: FileOpener
             if file_opener is None:
 
-                def opener(path: Union[str, os.PathLike]) -> BinaryIO:
-                    return cast(BinaryIO, GitFile(path, "rb"))
+                def opener(path: Union[str, os.PathLike]) -> IO[bytes]:
+                    return GitFile(path, "rb")
             else:
                 opener = file_opener
 
@@ -1236,8 +1263,8 @@ class ConfigFile(ConfigDict):
         opener: FileOpener
         if file_opener is None:
 
-            def opener(p: Union[str, os.PathLike]) -> BinaryIO:
-                return cast(BinaryIO, GitFile(p, "rb"))
+            def opener(p: Union[str, os.PathLike]) -> IO[bytes]:
+                return GitFile(p, "rb")
         else:
             opener = file_opener
 

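The rewritten `keys()` and `__iter__` both preserve the first-seen spelling of each key while deduplicating case-insensitively. A standalone sketch of that dedup logic (data and names are illustrative; `lower()` stands in for dulwich's `lower_key()`):

    pairs = [(b"Core", 1), (b"core", 2), (b"User", 3)]

    seen: set[bytes] = set()
    unique: list[bytes] = []
    for key, _value in pairs:
        folded = key.lower()  # case-insensitive identity of the key
        if folded not in seen:
            seen.add(folded)
            unique.append(key)

    assert unique == [b"Core", b"User"]  # first spelling wins; later duplicates dropped
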
+ 12 - 6
dulwich/contrib/swift.py

@@ -43,6 +43,7 @@ from typing import BinaryIO, Callable, Optional, Union, cast
 
 from geventhttpclient import HTTPClient
 
+from ..file import _GitFile
 from ..greenthreads import GreenThreadsMissingObjectFinder
 from ..lru_cache import LRUSizeCache
 from ..object_store import INFODIR, PACKDIR, ObjectContainer, PackBasedObjectStore
@@ -823,7 +824,11 @@ class SwiftObjectStore(PackBasedObjectStore):
               The created SwiftPack or None if empty
             """
             f.seek(0)
-            pack = PackData(file=f, filename="")
+            from typing import cast
+
+            from ..file import _GitFile
+
+            pack = PackData(file=cast(_GitFile, f), filename="")
             entries = pack.sorted_entries()
             if entries:
                 basename = posixpath.join(
@@ -875,7 +880,8 @@ class SwiftObjectStore(PackBasedObjectStore):
         fd, path = tempfile.mkstemp(prefix="tmp_pack_")
         f = os.fdopen(fd, "w+b")
         try:
-            indexer = PackIndexer(f, resolve_ext_ref=None)
+            pack_data = PackData(file=cast(_GitFile, f), filename=path)
+            indexer = PackIndexer(pack_data, resolve_ext_ref=None)
             copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
@@ -928,7 +934,7 @@ class SwiftObjectStore(PackBasedObjectStore):
 
         # Write pack info.
         f.seek(0)
-        pack_data = PackData(filename="", file=f)
+        pack_data = PackData(filename="", file=cast(_GitFile, f))
         index_file.seek(0)
         pack_index = load_pack_index_file("", index_file)
         serialized_pack_info = pack_info_create(pack_data, pack_index)
@@ -1030,17 +1036,17 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         del self._refs[name]
         return True
 
-    def allkeys(self) -> Iterator[bytes]:
+    def allkeys(self) -> set[bytes]:
         """Get all reference names.
 
         Returns:
-          Iterator of reference names as bytes
+          Set of reference names as bytes
         """
         try:
             self._refs[b"HEAD"] = self._refs[b"refs/heads/master"]
         except KeyError:
             pass
-        return iter(self._refs.keys())
+        return set(self._refs.keys())
 
 
 class SwiftRepo(BaseRepo):

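The `allkeys()` change aligns the override with its base class: an override may narrow a return type but not widen it, and the container API here is declared to return a set. A reduced sketch of the rule (both classes hypothetical):

    class Base:
        def allkeys(self) -> set[bytes]:
            return set()

    class Derived(Base):
        # Declaring Iterator[bytes] here would widen the return type and fail
        # mypy's override check; set[bytes] (or a subtype) is required.
        def allkeys(self) -> set[bytes]:
            return {b"HEAD", b"refs/heads/master"}
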
+ 12 - 6
dulwich/diff.py

@@ -47,11 +47,11 @@ Example usage:
 import logging
 import os
 import stat
-from typing import BinaryIO, Optional, cast
+from typing import BinaryIO, Optional
 
 from .index import ConflictedIndexEntry, commit_index
 from .object_store import iter_tree_contents
-from .objects import S_ISGITLINK, Blob
+from .objects import S_ISGITLINK, Blob, Commit
 from .patch import write_blob_diff, write_object_diff
 from .repo import Repo
 
@@ -90,12 +90,16 @@ def diff_index_to_tree(
     if commit_sha is None:
         try:
             commit_sha = repo.refs[b"HEAD"]
-            old_tree = repo[commit_sha].tree
+            old_commit = repo[commit_sha]
+            assert isinstance(old_commit, Commit)
+            old_tree = old_commit.tree
         except KeyError:
             # No HEAD means no commits yet
             old_tree = None
     else:
-        old_tree = repo[commit_sha].tree
+        old_commit = repo[commit_sha]
+        assert isinstance(old_commit, Commit)
+        old_tree = old_commit.tree
 
     # Get tree from index
     index = repo.open_index()
@@ -125,7 +129,9 @@ def diff_working_tree_to_tree(
         commit_sha: SHA of commit to compare against
         paths: Optional list of paths to filter (as bytes)
     """
-    tree = repo[commit_sha].tree
+    commit = repo[commit_sha]
+    assert isinstance(commit, Commit)
+    tree = commit.tree
     normalizer = repo.get_blob_normalizer()
     filter_callback = normalizer.checkin_normalize
 
@@ -382,7 +388,7 @@ def diff_working_tree_to_index(
         old_obj = repo.object_store[old_sha]
         # Type check and cast to Blob
         if isinstance(old_obj, Blob):
-            old_blob = cast(Blob, old_obj)
+            old_blob = old_obj
         else:
             old_blob = None
 

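`repo[sha]` is typed as returning a generic `ShaFile`, so the hunks above narrow with `assert isinstance(...)` before touching `Commit`-only attributes. The same pattern against the public dulwich API (assuming the current directory is a git repository):

    from dulwich.objects import Commit
    from dulwich.repo import Repo

    repo = Repo(".")
    obj = repo[repo.refs[b"HEAD"]]  # statically a ShaFile
    assert isinstance(obj, Commit)  # narrows the type for mypy
    tree_id = obj.tree  # now allowed without a cast
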
+ 4 - 4
dulwich/dumb.py

@@ -24,7 +24,7 @@
 import os
 import tempfile
 import zlib
-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
 from io import BytesIO
 from typing import Any, Callable, Optional
 from urllib.parse import urljoin
@@ -338,9 +338,9 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
     def add_objects(
         self,
-        objects: Iterator[ShaFile],
-        progress: Optional[Callable[[int], None]] = None,
-    ) -> None:
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable[[str], None]] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store."""
         raise NotImplementedError("Cannot add objects to dumb HTTP repository")
 

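Even an override that only raises must match the base-class signature, or mypy flags it. A reduced sketch of the constraint (classes hypothetical):

    from collections.abc import Sequence
    from typing import Callable, Optional

    class BaseStore:
        def add_objects(
            self,
            objects: Sequence[tuple[object, Optional[str]]],
            progress: Optional[Callable[[str], None]] = None,
        ) -> Optional[object]:
            raise NotImplementedError

    class ReadOnlyStore(BaseStore):
        # Same parameter and return types as the base, even though the body
        # unconditionally raises: mypy checks the signature, not the body.
        def add_objects(
            self,
            objects: Sequence[tuple[object, Optional[str]]],
            progress: Optional[Callable[[str], None]] = None,
        ) -> Optional[object]:
            raise NotImplementedError("read-only store")
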
+ 53 - 9
dulwich/file.py

@@ -24,10 +24,15 @@
 import os
 import sys
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from types import TracebackType
 from typing import IO, Any, ClassVar, Literal, Optional, Union, overload
 
+if sys.version_info >= (3, 12):
+    from collections.abc import Buffer
+else:
+    Buffer = Union[bytes, bytearray, memoryview]
+
 
 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
     """Ensure a directory exists, creating if necessary."""
@@ -136,7 +141,7 @@ class FileLocked(Exception):
         super().__init__(filename, lockfilename)
 
 
-class _GitFile:
+class _GitFile(IO[bytes]):
     """File that follows the git locking protocol for writes.
 
     All writes to a file foo will be written into foo.lock in the same
@@ -148,7 +153,6 @@ class _GitFile:
     """
 
     PROXY_PROPERTIES: ClassVar[set[str]] = {
-        "closed",
         "encoding",
         "errors",
         "mode",
@@ -158,15 +162,19 @@ class _GitFile:
     }
     PROXY_METHODS: ClassVar[set[str]] = {
         "__iter__",
+        "__next__",
         "flush",
         "fileno",
         "isatty",
         "read",
+        "readable",
         "readline",
         "readlines",
         "seek",
+        "seekable",
         "tell",
         "truncate",
+        "writable",
         "write",
         "writelines",
     }
@@ -195,9 +203,6 @@ class _GitFile:
         self._file = os.fdopen(fd, mode, bufsize)
         self._closed = False
 
-        for method in self.PROXY_METHODS:
-            setattr(self, method, getattr(self._file, method))
-
     def __iter__(self) -> Iterator[bytes]:
         """Iterate over lines in the file."""
         return iter(self._file)
@@ -267,20 +272,59 @@ class _GitFile:
         else:
             self.close()
 
+    @property
+    def closed(self) -> bool:
+        """Return whether the file is closed."""
+        return self._closed
+
     def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)
         raise AttributeError(name)
 
+    # Implement IO[bytes] methods by delegating to the underlying file
+    def read(self, size: int = -1) -> bytes:
+        return self._file.read(size)
+
+    def write(self, data: Buffer, /) -> int:
+        return self._file.write(data)
+
+    def readline(self, size: int = -1) -> bytes:
+        return self._file.readline(size)
+
+    def readlines(self, hint: int = -1) -> list[bytes]:
+        return self._file.readlines(hint)
+
+    def writelines(self, lines: Iterable[Buffer], /) -> None:
+        return self._file.writelines(lines)
+
+    def seek(self, offset: int, whence: int = 0) -> int:
+        return self._file.seek(offset, whence)
+
+    def tell(self) -> int:
+        return self._file.tell()
+
+    def flush(self) -> None:
+        return self._file.flush()
+
+    def truncate(self, size: Optional[int] = None) -> int:
+        return self._file.truncate(size)
+
+    def fileno(self) -> int:
+        return self._file.fileno()
+
+    def isatty(self) -> bool:
+        return self._file.isatty()
+
     def readable(self) -> bool:
-        """Return whether the file is readable."""
         return self._file.readable()
 
     def writable(self) -> bool:
-        """Return whether the file is writable."""
         return self._file.writable()
 
     def seekable(self) -> bool:
-        """Return whether the file is seekable."""
         return self._file.seekable()
+
+    def __next__(self) -> bytes:
+        return next(iter(self._file))

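Methods attached with `setattr` in `__init__` are invisible to static analysis, which is why the diff replaces the proxy loop with explicit delegating methods. A reduced sketch (the `Wrapper` class is hypothetical):

    from typing import IO

    class Wrapper:
        def __init__(self, f: IO[bytes]) -> None:
            self._f = f
            # setattr(self, "read", f.read) would work at runtime but would
            # leave mypy unaware that Wrapper has a read() method.

        def read(self, size: int = -1) -> bytes:
            # Explicit delegation is visible to type checkers.
            return self._f.read(size)
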
+ 15 - 7
dulwich/filters.py

@@ -30,7 +30,7 @@ from .objects import Blob
 
 if TYPE_CHECKING:
     from .config import StackedConfig
-    from .repo import Repo
+    from .repo import BaseRepo
 
 
 class FilterError(Exception):
@@ -128,7 +128,9 @@ class FilterRegistry:
     """Registry for filter drivers."""
 
     def __init__(
-        self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None
+        self,
+        config: Optional["StackedConfig"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
         """Initialize FilterRegistry.
 
@@ -211,8 +213,12 @@ class FilterRegistry:
         required = self.config.get_boolean(("filter", name), "required", False)
 
         if clean_cmd or smudge_cmd:
-            # Get repository working directory
-            repo_path = self.repo.path if self.repo else None
+            # Get repository working directory (only for Repo, not BaseRepo)
+            from .repo import Repo
+
+            repo_path = (
+                self.repo.path if self.repo and isinstance(self.repo, Repo) else None
+            )
             return ProcessFilterDriver(clean_cmd, smudge_cmd, required, repo_path)
 
         return None
@@ -221,8 +227,10 @@ class FilterRegistry:
         """Create LFS filter driver."""
         from .lfs import LFSFilterDriver, LFSStore
 
-        # If we have a repo, use its LFS store
-        if registry.repo is not None:
+        # If we have a Repo (not just BaseRepo), use its LFS store
+        from .repo import Repo
+
+        if registry.repo is not None and isinstance(registry.repo, Repo):
             lfs_store = LFSStore.from_repo(registry.repo, create=True)
         else:
             # Fall back to creating a temporary LFS store
@@ -389,7 +397,7 @@ class FilterBlobNormalizer:
         config_stack: Optional["StackedConfig"],
         gitattributes: GitAttributes,
         filter_registry: Optional[FilterRegistry] = None,
-        repo: Optional["Repo"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
         """Initialize FilterBlobNormalizer.
 

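Typing the registry against the broader `BaseRepo` keeps it usable with non-disk repositories, while `isinstance` gates the attributes that only the on-disk `Repo` provides. The gating pattern in isolation (classes hypothetical):

    from typing import Optional

    class BaseRepo:
        pass

    class Repo(BaseRepo):
        path = "/tmp/example"  # illustrative working-directory attribute

    def working_dir(repo: Optional[BaseRepo]) -> Optional[str]:
        # isinstance() narrows BaseRepo to Repo, so .path type-checks.
        if repo is not None and isinstance(repo, Repo):
            return repo.path
        return None
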
+ 24 - 22
dulwich/index.py

@@ -32,13 +32,13 @@ from collections.abc import Generator, Iterable, Iterator
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,
     Optional,
     Union,
-    cast,
 )
 
 if TYPE_CHECKING:
@@ -588,7 +588,7 @@ def read_cache_time(f: BinaryIO) -> tuple[int, int]:
     return struct.unpack(">LL", f.read(8))
 
 
-def write_cache_time(f: BinaryIO, t: Union[int, float, tuple[int, int]]) -> None:
+def write_cache_time(f: IO[bytes], t: Union[int, float, tuple[int, int]]) -> None:
     """Write a cache time.
 
     Args:
@@ -664,7 +664,7 @@ def read_cache_entry(
 
 
 def write_cache_entry(
-    f: BinaryIO, entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
+    f: IO[bytes], entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
 ) -> None:
     """Write an index entry to a file.
 
@@ -745,7 +745,7 @@ def read_index_header(f: BinaryIO) -> tuple[int, int]:
     return version, num_entries
 
 
-def write_index_extension(f: BinaryIO, extension: IndexExtension) -> None:
+def write_index_extension(f: IO[bytes], extension: IndexExtension) -> None:
     """Write an index extension.
 
     Args:
@@ -868,7 +868,7 @@ def read_index_dict(
 
 
 def write_index(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: list[SerializedIndexEntry],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -909,7 +909,7 @@ def write_index(
 
 
 def write_index_dict(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: dict[bytes, Union[IndexEntry, ConflictedIndexEntry]],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -1007,8 +1007,6 @@ class Index:
 
     def write(self) -> None:
         """Write current contents of index to disk."""
-        from typing import BinaryIO, cast
-
         f = GitFile(self._filename, "wb")
         try:
             # Filter out extensions with no meaningful data
@@ -1022,7 +1020,7 @@ class Index:
             if self._skip_hash:
                 # When skipHash is enabled, write the index without computing SHA1
                 write_index_dict(
-                    cast(BinaryIO, f),
+                    f,
                     self._byname,
                     version=self._version,
                     extensions=meaningful_extensions,
@@ -1031,9 +1029,9 @@ class Index:
                 f.write(b"\x00" * 20)
                 f.close()
             else:
-                sha1_writer = SHA1Writer(cast(BinaryIO, f))
+                sha1_writer = SHA1Writer(f)
                 write_index_dict(
-                    cast(BinaryIO, sha1_writer),
+                    sha1_writer,
                     self._byname,
                     version=self._version,
                     extensions=meaningful_extensions,
@@ -1050,9 +1048,7 @@ class Index:
         f = GitFile(self._filename, "rb")
         try:
             sha1_reader = SHA1Reader(f)
-            entries, version, extensions = read_index_dict_with_version(
-                cast(BinaryIO, sha1_reader)
-            )
+            entries, version, extensions = read_index_dict_with_version(sha1_reader)
             self._version = version
             self._extensions = extensions
             self.update(entries)
@@ -1411,7 +1407,9 @@ def build_file_from_blob(
     *,
     honor_filemode: bool = True,
     tree_encoding: str = "utf-8",
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
 ) -> os.stat_result:
     """Build a file or symlink on disk based on a Git object.
 
@@ -1596,7 +1594,9 @@ def build_index_from_tree(
     tree_id: bytes,
    honor_filemode: bool = True,
     validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -1956,7 +1956,9 @@ def _transition_to_file(
     entry: IndexEntry,
     index: Index,
     honor_filemode: bool,
-    symlink_fn: Optional[Callable[[bytes, bytes], None]],
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ],
     blob_normalizer: Optional["BlobNormalizer"],
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -2208,7 +2210,9 @@ def update_working_tree(
     change_iterator: Iterator["TreeChange"],
     honor_filemode: bool = True,
     validate_path_element: Optional[Callable[[bytes], bool]] = None,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     force_remove_untracked: bool = False,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
@@ -2685,10 +2689,8 @@ class locked_index:
             self._file.abort()
             return
         try:
-            from typing import BinaryIO, cast
-
-            f = SHA1Writer(cast(BinaryIO, self._file))
-            write_index_dict(cast(BinaryIO, f), self._index._byname)
+            f = SHA1Writer(self._file)
+            write_index_dict(f, self._index._byname)
         except BaseException:
             self._file.abort()
         else:

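With the widened `symlink_fn` type, `os.symlink` itself satisfies the annotation, and a custom callback only has to accept path-like sources and targets. A sketch of a conforming callback (the function and its policy are hypothetical):

    import os
    from typing import Union

    def careful_symlink(
        source: Union[str, bytes, os.PathLike],
        target: Union[str, bytes, os.PathLike],
    ) -> None:
        # Illustrative policy: refuse to overwrite an existing target.
        if os.path.lexists(target):
            raise FileExistsError(target)
        os.symlink(source, target)
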
+ 110 - 53
dulwich/object_store.py

@@ -33,11 +33,11 @@ from collections.abc import Iterable, Iterator, Sequence
 from contextlib import suppress
 from contextlib import suppress
 from io import BytesIO
 from io import BytesIO
 from typing import (
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Callable,
     Optional,
     Optional,
     Protocol,
     Protocol,
     Union,
     Union,
-    cast,
 )
 )
 
 
 from .errors import NotTreeError
 from .errors import NotTreeError
@@ -82,6 +82,18 @@ from .pack import (
 from .protocol import DEPTH_INFINITE
 from .protocol import DEPTH_INFINITE
 from .refs import PEELED_TAG_SUFFIX, Ref
 from .refs import PEELED_TAG_SUFFIX, Ref
 
 
+if TYPE_CHECKING:
+    from .commit_graph import CommitGraph
+    from .diff_tree import RenameDetector
+
+
+class GraphWalker(Protocol):
+    """Protocol for graph walker objects."""
+
+    def __next__(self) -> Optional[bytes]: ...
+    def ack(self, sha: bytes) -> None: ...
+
+
 INFODIR = "info"
 INFODIR = "info"
 PACKDIR = "pack"
 PACKDIR = "pack"
 
 
@@ -95,7 +107,9 @@ PACK_MODE = 0o444 if sys.platform != "win32" else 0o644
 DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks
 DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks
 
 
 
 
-def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
+def find_shallow(
+    store: ObjectContainer, heads: Iterable[bytes], depth: int
+) -> tuple[set[bytes], set[bytes]]:
     """Find shallow commits according to a given depth.
     """Find shallow commits according to a given depth.
 
 
     Args:
     Args:
@@ -107,7 +121,7 @@ def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
         considered shallow and unshallow according to the arguments. Note that
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
         these sets may overlap if a commit is reachable along multiple paths.
     """
     """
-    parents = {}
+    parents: dict[bytes, list[bytes]] = {}
     commit_graph = store.get_commit_graph()
     commit_graph = store.get_commit_graph()
 
 
     def get_parents(sha: bytes) -> list[bytes]:
     def get_parents(sha: bytes) -> list[bytes]:
@@ -121,7 +135,9 @@ def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
                     parents[sha] = result
                     parents[sha] = result
                     return result
                     return result
             # Fall back to loading the object
             # Fall back to loading the object
-            result = store[sha].parents
+            commit = store[sha]
+            assert isinstance(commit, Commit)
+            result = commit.parents
             parents[sha] = result
             parents[sha] = result
         return result
         return result
 
 
@@ -150,7 +166,7 @@ def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
 
 
 
 
 def get_depth(
 def get_depth(
-    store: 'BaseObjectStore',
+    store: ObjectContainer,
     head: bytes,
     head: bytes,
     get_parents: Callable = lambda commit: commit.parents,
     get_parents: Callable = lambda commit: commit.parents,
     max_depth: Optional[int] = None,
     max_depth: Optional[int] = None,
@@ -233,7 +249,7 @@ class BaseObjectStore:
         return self.contains_loose(sha1)
         return self.contains_loose(sha1)
 
 
     @property
     @property
-    def packs(self) -> Any:
+    def packs(self) -> list[Pack]:
         """Iterable of pack objects."""
         """Iterable of pack objects."""
         raise NotImplementedError
         raise NotImplementedError
 
 
@@ -251,15 +267,19 @@ class BaseObjectStore:
         type_num, uncomp = self.get_raw(sha1)
         type_num, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
 
 
-    def __iter__(self) -> Any:
+    def __iter__(self) -> Iterator[bytes]:
         """Iterate over the SHAs that are present in this store."""
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)
         raise NotImplementedError(self.__iter__)
 
 
-    def add_object(self, obj: Any) -> None:
+    def add_object(self, obj: ShaFile) -> None:
         """Add a single object to this object store."""
         """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)
         raise NotImplementedError(self.add_object)
 
 
-    def add_objects(self, objects: Any, progress: Optional[Callable] = None) -> None:
+    def add_objects(
+        self,
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
         """Add a set of objects to this object store.
 
 
         Args:
         Args:
@@ -275,9 +295,15 @@ class BaseObjectStore:
         want_unchanged: bool = False,
         want_unchanged: bool = False,
         include_trees: bool = False,
         include_trees: bool = False,
         change_type_same: bool = False,
         change_type_same: bool = False,
-        rename_detector: Optional[Any] = None,
-        paths: Optional[Any] = None,
-    ) -> Any:
+        rename_detector: Optional["RenameDetector"] = None,
+        paths: Optional[list[bytes]] = None,
+    ) -> Iterator[
+        tuple[
+            tuple[Optional[bytes], Optional[bytes]],
+            tuple[Optional[int], Optional[int]],
+            tuple[Optional[bytes], Optional[bytes]],
+        ]
+    ]:
         """Find the differences between the contents of two trees.
         """Find the differences between the contents of two trees.
 
 
         Args:
         Args:
@@ -310,7 +336,9 @@ class BaseObjectStore:
                 (change.old.sha, change.new.sha),
                 (change.old.sha, change.new.sha),
             )
             )
 
 
-    def iter_tree_contents(self, tree_id: bytes, include_trees: bool = False) -> Any:
+    def iter_tree_contents(
+        self, tree_id: bytes, include_trees: bool = False
+    ) -> Iterator[tuple[bytes, int, bytes]]:
         """Iterate the contents of a tree and all subtrees.
         """Iterate the contents of a tree and all subtrees.
 
 
         Iteration is depth-first pre-order, as in e.g. os.walk.
         Iteration is depth-first pre-order, as in e.g. os.walk.
@@ -352,13 +380,13 @@ class BaseObjectStore:
 
 
     def find_missing_objects(
     def find_missing_objects(
         self,
         self,
-        haves: Any,
-        wants: Any,
-        shallow: Optional[Any] = None,
+        haves: Iterable[bytes],
+        wants: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
         progress: Optional[Callable] = None,
         progress: Optional[Callable] = None,
         get_tagged: Optional[Callable] = None,
         get_tagged: Optional[Callable] = None,
         get_parents: Callable = lambda commit: commit.parents,
         get_parents: Callable = lambda commit: commit.parents,
-    ) -> Any:
+    ) -> Iterator[tuple[bytes, Optional[bytes]]]:
         """Find the missing objects required for a set of revisions.
         """Find the missing objects required for a set of revisions.
 
 
         Args:
         Args:
@@ -385,7 +413,7 @@ class BaseObjectStore:
         )
         )
         return iter(finder)
         return iter(finder)
 
 
-    def find_common_revisions(self, graphwalker: Any) -> list[bytes]:
+    def find_common_revisions(self, graphwalker: GraphWalker) -> list[bytes]:
         """Find which revisions this store has in common using graphwalker.
         """Find which revisions this store has in common using graphwalker.
 
 
         Args:
         Args:
@@ -402,7 +430,12 @@ class BaseObjectStore:
         return haves
         return haves
 
 
     def generate_pack_data(
     def generate_pack_data(
-        self, have: Any, want: Any, shallow: Optional[Any] = None, progress: Optional[Callable] = None, ofs_delta: bool = True
+        self,
+        have: Iterable[bytes],
+        want: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
+        progress: Optional[Callable] = None,
+        ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
     ) -> tuple[int, Iterator[UnpackedObject]]:
         """Generate pack data objects for a set of wants/haves.
         """Generate pack data objects for a set of wants/haves.
 
 
@@ -439,7 +472,7 @@ class BaseObjectStore:
             DeprecationWarning,
             DeprecationWarning,
             stacklevel=2,
             stacklevel=2,
         )
         )
-        return peel_sha(self, sha)[1]
+        return peel_sha(self, sha)[1].id
 
 
     def _get_depth(
     def _get_depth(
         self,
         self,
@@ -486,7 +519,7 @@ class BaseObjectStore:
             if sha.startswith(prefix):
             if sha.startswith(prefix):
                 yield sha
                 yield sha
 
 
-    def get_commit_graph(self) -> Optional[Any]:
+    def get_commit_graph(self) -> Optional["CommitGraph"]:
         """Get the commit graph for this object store.
         """Get the commit graph for this object store.
 
 
         Returns:
         Returns:
@@ -494,7 +527,9 @@ class BaseObjectStore:
         """
         """
         return None
         return None
 
 
-    def write_commit_graph(self, refs: Optional[Any] = None, reachable: bool = True) -> None:
+    def write_commit_graph(
+        self, refs: Optional[list[bytes]] = None, reachable: bool = True
+    ) -> None:
         """Write a commit graph file for this object store.
         """Write a commit graph file for this object store.
 
 
         Args:
         Args:
@@ -571,8 +606,11 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         raise NotImplementedError(self.add_pack)
         raise NotImplementedError(self.add_pack)
 
 
     def add_pack_data(
     def add_pack_data(
-        self, count: int, unpacked_objects: Iterator[UnpackedObject], progress: Optional[Callable] = None
-    ) -> None:
+        self,
+        count: int,
+        unpacked_objects: Iterator[UnpackedObject],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add pack data to this object store.
         """Add pack data to this object store.
 
 
         Args:
         Args:
@@ -582,7 +620,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         """
         if count == 0:
         if count == 0:
             # Don't bother writing an empty pack file
             # Don't bother writing an empty pack file
-            return
+            return None
         f, commit, abort = self.add_pack()
         f, commit, abort = self.add_pack()
         try:
         try:
             write_pack_data(
             write_pack_data(
@@ -627,7 +665,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 return True
                 return True
         return False
         return False
 
 
-    def _add_cached_pack(self, base_name: str, pack: Any) -> None:
+    def _add_cached_pack(self, base_name: str, pack: Pack) -> None:
         """Add a newly appeared pack to the cache by path."""
         """Add a newly appeared pack to the cache by path."""
         prev_pack = self._pack_cache.get(base_name)
         prev_pack = self._pack_cache.get(base_name)
         if prev_pack is not pack:
         if prev_pack is not pack:
@@ -653,7 +691,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         remote_has = missing_objects.get_remote_has()
         remote_has = missing_objects.get_remote_has()
         object_ids = list(missing_objects)
         object_ids = list(missing_objects)
         return len(object_ids), generate_unpacked_objects(
         return len(object_ids), generate_unpacked_objects(
-            cast(PackedObjectContainer, self),
+            self,
             object_ids,
             object_ids,
             progress=progress,
             progress=progress,
             ofs_delta=ofs_delta,
             ofs_delta=ofs_delta,
@@ -667,8 +705,8 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             (name, pack) = pack_cache.popitem()
             (name, pack) = pack_cache.popitem()
             pack.close()
             pack.close()
 
 
-    def _iter_cached_packs(self) -> Any:
-        return self._pack_cache.values()
+    def _iter_cached_packs(self) -> Iterator[Pack]:
+        return iter(self._pack_cache.values())
 
 
     def _update_pack_cache(self) -> list[Pack]:
     def _update_pack_cache(self) -> list[Pack]:
         raise NotImplementedError(self._update_pack_cache)
         raise NotImplementedError(self._update_pack_cache)
@@ -681,7 +719,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self._clear_cached_packs()
         self._clear_cached_packs()
 
 
     @property
     @property
-    def packs(self) -> Any:
+    def packs(self) -> list[Pack]:
         """List with pack objects."""
         """List with pack objects."""
         return list(self._iter_cached_packs()) + list(self._update_pack_cache())
         return list(self._iter_cached_packs()) + list(self._update_pack_cache())
 
 
@@ -699,12 +737,12 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 count += 1
                 count += 1
         return count
         return count
 
 
-    def _iter_alternate_objects(self) -> Any:
+    def _iter_alternate_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all the objects in alternate stores."""
         """Iterate over the SHAs of all the objects in alternate stores."""
         for alternate in self.alternates:
         for alternate in self.alternates:
             yield from alternate
             yield from alternate
 
 
-    def _iter_loose_objects(self) -> Any:
+    def _iter_loose_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all loose objects."""
         """Iterate over the SHAs of all loose objects."""
         raise NotImplementedError(self._iter_loose_objects)
         raise NotImplementedError(self._iter_loose_objects)
 
 
@@ -719,7 +757,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         """
         raise NotImplementedError(self.delete_loose_object)
         raise NotImplementedError(self.delete_loose_object)
 
 
-    def _remove_pack(self, name: str) -> None:
+    def _remove_pack(self, pack: "Pack") -> None:
         raise NotImplementedError(self._remove_pack)
         raise NotImplementedError(self._remove_pack)
 
 
     def pack_loose_objects(self) -> int:
     def pack_loose_objects(self) -> int:
@@ -727,15 +765,17 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
 
 
         Returns: Number of objects packed
         Returns: Number of objects packed
         """
         """
-        objects = set()
+        objects: list[tuple[ShaFile, None]] = []
         for sha in self._iter_loose_objects():
         for sha in self._iter_loose_objects():
-            objects.add((self._get_loose_object(sha), None))
-        self.add_objects(list(objects))
+            obj = self._get_loose_object(sha)
+            if obj is not None:
+                objects.append((obj, None))
+        self.add_objects(objects)
         for obj, path in objects:
         for obj, path in objects:
             self.delete_loose_object(obj.id)
             self.delete_loose_object(obj.id)
         return len(objects)
         return len(objects)
 
 
-    def repack(self, exclude: Optional[set] = None) -> None:
+    def repack(self, exclude: Optional[set] = None) -> int:
         """Repack the packs in this repository.
         """Repack the packs in this repository.
 
 
         Note that this implementation is fairly naive and currently keeps all
         Note that this implementation is fairly naive and currently keeps all
@@ -751,11 +791,13 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         excluded_loose_objects = set()
         excluded_loose_objects = set()
         for sha in self._iter_loose_objects():
         for sha in self._iter_loose_objects():
             if sha not in exclude:
             if sha not in exclude:
-                loose_objects.add(self._get_loose_object(sha))
+                obj = self._get_loose_object(sha)
+                if obj is not None:
+                    loose_objects.add(obj)
             else:
             else:
                 excluded_loose_objects.add(sha)
                 excluded_loose_objects.add(sha)
 
 
-        objects = {(obj, None) for obj in loose_objects}
+        objects: set[tuple[ShaFile, None]] = {(obj, None) for obj in loose_objects}
         old_packs = {p.name(): p for p in self.packs}
         old_packs = {p.name(): p for p in self.packs}
         for name, pack in old_packs.items():
         for name, pack in old_packs.items():
             objects.update(
             objects.update(
@@ -767,12 +809,14 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             # The name of the consolidated pack might match the name of a
             # The name of the consolidated pack might match the name of a
             # pre-existing pack. Take care not to remove the newly created
             # pre-existing pack. Take care not to remove the newly created
             # consolidated pack.
             # consolidated pack.
-            consolidated = self.add_objects(objects)
-            old_packs.pop(consolidated.name(), None)
+            consolidated = self.add_objects(list(objects))
+            if consolidated is not None:
+                old_packs.pop(consolidated.name(), None)
 
 
         # Delete loose objects that were packed
         # Delete loose objects that were packed
         for obj in loose_objects:
         for obj in loose_objects:
-            self.delete_loose_object(obj.id)
+            if obj is not None:
+                self.delete_loose_object(obj.id)
         # Delete excluded loose objects
         # Delete excluded loose objects
         for sha in excluded_loose_objects:
         for sha in excluded_loose_objects:
             self.delete_loose_object(sha)
             self.delete_loose_object(sha)
@@ -928,9 +972,9 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 yield o
                 yield o
                 todo.remove(o.id)
                 todo.remove(o.id)
         for oid in todo:
         for oid in todo:
-            o = self._get_loose_object(oid)
-            if o is not None:
-                yield o
+            loose_obj: Optional[ShaFile] = self._get_loose_object(oid)
+            if loose_obj is not None:
+                yield loose_obj
             elif not allow_missing:
             elif not allow_missing:
                 raise KeyError(oid)
                 raise KeyError(oid)
 
 
@@ -978,7 +1022,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self,
         self,
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         progress: Optional[Callable[[str], None]] = None,
         progress: Optional[Callable[[str], None]] = None,
-    ) -> None:
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
         """Add a set of objects to this object store.
 
 
         Args:
         Args:
@@ -997,6 +1041,8 @@ class DiskObjectStore(PackBasedObjectStore):
 
 
     path: Union[str, os.PathLike]
     path: Union[str, os.PathLike]
     pack_dir: Union[str, os.PathLike]
     pack_dir: Union[str, os.PathLike]
+    _alternates: Optional[list["DiskObjectStore"]]
+    _commit_graph: Optional["CommitGraph"]
 
 
     def __init__(
     def __init__(
         self,
         self,
@@ -1229,7 +1275,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
 
     def _get_shafile_path(self, sha):
     def _get_shafile_path(self, sha):
         # Check from object dir
         # Check from object dir
-        return hex_to_filename(self.path, sha)
+        return hex_to_filename(os.fspath(self.path), sha)
 
 
     def _iter_loose_objects(self):
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
         for base in os.listdir(self.path):
@@ -1330,9 +1376,9 @@ class DiskObjectStore(PackBasedObjectStore):
         os.remove(pack.index.path)
         os.remove(pack.index.path)
 
 
     def _get_pack_basepath(self, entries):
     def _get_pack_basepath(self, entries):
-        suffix = iter_sha1(entry[0] for entry in entries)
+        suffix_bytes = iter_sha1(entry[0] for entry in entries)
         # TODO: Handle self.pack_dir being bytes
         # TODO: Handle self.pack_dir being bytes
-        suffix = suffix.decode("ascii")
+        suffix = suffix_bytes.decode("ascii")
         return os.path.join(self.pack_dir, "pack-" + suffix)
         return os.path.join(self.pack_dir, "pack-" + suffix)
 
 
     def _complete_pack(self, f, path, num_objects, indexer, progress=None):
     def _complete_pack(self, f, path, num_objects, indexer, progress=None):
@@ -1451,6 +1497,7 @@ class DiskObjectStore(PackBasedObjectStore):
         def commit():
         def commit():
             if f.tell() > 0:
             if f.tell() > 0:
                 f.seek(0)
                 f.seek(0)
+
                 with PackData(path, f) as pd:
                 with PackData(path, f) as pd:
                     indexer = PackIndexer.for_pack_data(
                     indexer = PackIndexer.for_pack_data(
                         pd, resolve_ext_ref=self.get_raw
                         pd, resolve_ext_ref=self.get_raw
@@ -1783,6 +1830,7 @@ class MemoryObjectStore(BaseObjectStore):
             size = f.tell()
             size = f.tell()
             if size > 0:
             if size > 0:
                 f.seek(0)
                 f.seek(0)
+
                 p = PackData.from_file(f, size)
                 p = PackData.from_file(f, size)
                 for obj in PackInflater.for_pack_data(p, self.get_raw):
                 for obj in PackInflater.for_pack_data(p, self.get_raw):
                     self.add_object(obj)
                     self.add_object(obj)
@@ -2119,7 +2167,7 @@ class ObjectStoreGraphWalker:
     heads: set[ObjectID]
     heads: set[ObjectID]
     """Revisions without descendants in the local repo."""
     """Revisions without descendants in the local repo."""
 
 
-    get_parents: Callable[[ObjectID], ObjectID]
+    get_parents: Callable[[ObjectID], list[ObjectID]]
     """Function to retrieve parents in the local repo."""
     """Function to retrieve parents in the local repo."""
 
 
     shallow: set[ObjectID]
     shallow: set[ObjectID]
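Callable[[ObjectID], ObjectID] described get_parents as returning a single ID, so iterating over its result failed under mypy; the corrected annotation returns a list. In isolation (toy parents function, not the real object-store lookup):

    from typing import Callable

    ObjectID = bytes

    def parents_of(sha: ObjectID) -> list[ObjectID]:
        return [b"p1:" + sha, b"p2:" + sha]

    get_parents: Callable[[ObjectID], list[ObjectID]] = parents_of

    for parent in get_parents(b"tip"):  # iteration needs the list return type
        print(parent)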
@@ -2215,7 +2263,7 @@ def commit_tree_changes(object_store, tree, changes):
     """
     """
     # TODO(jelmer): Save up the objects and add them using .add_objects
     # TODO(jelmer): Save up the objects and add them using .add_objects
     # rather than with individual calls to .add_object.
     # rather than with individual calls to .add_object.
-    nested_changes = {}
+    nested_changes: dict[bytes, list[tuple[bytes, Optional[int], Optional[bytes]]]] = {}
     for path, new_mode, new_sha in changes:
     for path, new_mode, new_sha in changes:
         try:
         try:
             (dirname, subpath) = path.split(b"/", 1)
             (dirname, subpath) = path.split(b"/", 1)
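A bare {} gives mypy nothing to infer, producing a "need type annotation" error at the first insert; spelling out the value type resolves it. Stand-alone sketch of the annotated pattern:

    from typing import Optional

    Change = tuple[bytes, Optional[int], Optional[bytes]]

    # The annotation tells mypy what the initially-empty dict will hold.
    nested_changes: dict[bytes, list[Change]] = {}
    nested_changes.setdefault(b"subdir", []).append((b"file", 0o100644, b"ab" * 20))
    print(nested_changes)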
@@ -2450,8 +2498,16 @@ class BucketBasedObjectStore(PackBasedObjectStore):
         """
         """
         # Doesn't exist..
         # Doesn't exist..
 
 
-    def _remove_pack(self, name) -> None:
-        raise NotImplementedError(self._remove_pack)
+    def pack_loose_objects(self) -> int:
+        """Pack loose objects. Returns number of objects packed.
+
+        BucketBasedObjectStore doesn't support loose objects, so this is a no-op.
+        """
+        return 0
+
+    def _remove_pack_by_name(self, name: str) -> None:
+        """Remove a pack by name. Subclasses should implement this."""
+        raise NotImplementedError(self._remove_pack_by_name)
 
 
     def _iter_pack_names(self) -> Iterator[str]:
     def _iter_pack_names(self) -> Iterator[str]:
         raise NotImplementedError(self._iter_pack_names)
         raise NotImplementedError(self._iter_pack_names)
@@ -2496,6 +2552,7 @@ class BucketBasedObjectStore(PackBasedObjectStore):
                 return None
                 return None
 
 
             pf.seek(0)
             pf.seek(0)
+
             p = PackData(pf.name, pf)
             p = PackData(pf.name, pf)
             entries = p.sorted_entries()
             entries = p.sorted_entries()
             basename = iter_sha1(entry[0] for entry in entries).decode("ascii")
             basename = iter_sha1(entry[0] for entry in entries).decode("ascii")

+ 2 - 1
dulwich/objects.py

@@ -430,7 +430,8 @@ class ShaFile:
             self._sha = None
             self._sha = None
             self._chunked_text = self._serialize()
             self._chunked_text = self._serialize()
             self._needs_serialization = False
             self._needs_serialization = False
-        return self._chunked_text  # type: ignore
+        assert self._chunked_text is not None
+        return self._chunked_text
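An `assert x is not None` both narrows the Optional for mypy and checks the invariant at runtime, which is why it can replace the blanket `# type: ignore`. The pattern on its own:

    from typing import Optional

    def chunked(text: Optional[list[bytes]]) -> list[bytes]:
        # After the assert, mypy narrows Optional[list[bytes]] to list[bytes].
        assert text is not None
        return text

    print(chunked([b"hello", b"world"]))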
 
 
     def as_raw_string(self) -> bytes:
     def as_raw_string(self) -> bytes:
         """Return raw string with serialization of the object.
         """Return raw string with serialization of the object.

+ 30 - 15
dulwich/objectspec.py

@@ -24,6 +24,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 from typing import TYPE_CHECKING, Optional, Union
 
 
 from .objects import Commit, ShaFile, Tag, Tree
 from .objects import Commit, ShaFile, Tag, Tree
+from .repo import BaseRepo
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
     from .object_store import BaseObjectStore
@@ -40,9 +41,9 @@ def to_bytes(text: Union[str, bytes]) -> bytes:
     Returns:
     Returns:
       Bytes representation of text
       Bytes representation of text
     """
     """
-    if getattr(text, "encode", None) is not None:
-        text = text.encode("ascii")  # type: ignore
-    return text  # type: ignore
+    if isinstance(text, str):
+        return text.encode("ascii")
+    return text
 
 
 
 
 def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
 def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
@@ -136,7 +137,9 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
                         raise ValueError(
                         raise ValueError(
                             f"Commit {commit.id.decode('ascii', 'replace')} has no parents"
                             f"Commit {commit.id.decode('ascii', 'replace')} has no parents"
                         )
                         )
-                    commit = repo[commit.parents[0]]
+                    parent_obj = repo[commit.parents[0]]
+                    assert isinstance(parent_obj, Commit)
+                    commit = parent_obj
                 obj = commit
                 obj = commit
             else:  # sep == b"^"
             else:  # sep == b"^"
                 # Get N-th parent (or commit itself if N=0)
                 # Get N-th parent (or commit itself if N=0)
@@ -157,11 +160,13 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
     return _resolve_object(repo, objectish)
     return _resolve_object(repo, objectish)
 
 
 
 
-def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "Tree":
+def parse_tree(
+    repo: "BaseRepo", treeish: Union[bytes, str, Tree, Commit, Tag]
+) -> "Tree":
     """Parse a string referring to a tree.
     """Parse a string referring to a tree.
 
 
     Args:
     Args:
-      repo: A `Repo` object
+      repo: A repository object
       treeish: A string referring to a tree, or a Tree, Commit, or Tag object
       treeish: A string referring to a tree, or a Tree, Commit, or Tag object
     Returns: A Tree object
     Returns: A Tree object
     Raises:
     Raises:
@@ -173,7 +178,9 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
 
 
     # If it's a Commit, return its tree
     # If it's a Commit, return its tree
     if isinstance(treeish, Commit):
     if isinstance(treeish, Commit):
-        return repo[treeish.tree]
+        tree = repo[treeish.tree]
+        assert isinstance(tree, Tree)
+        return tree
 
 
     # For Tag objects or strings, use the existing logic
     # For Tag objects or strings, use the existing logic
     if isinstance(treeish, Tag):
     if isinstance(treeish, Tag):
@@ -181,7 +188,7 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
     else:
     else:
         treeish = to_bytes(treeish)
         treeish = to_bytes(treeish)
     try:
     try:
-        treeish = parse_ref(repo, treeish)
+        treeish = parse_ref(repo.refs, treeish)
     except KeyError:  # treeish is commit sha
     except KeyError:  # treeish is commit sha
         pass
         pass
     try:
     try:
@@ -190,15 +197,21 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
         # Try parsing as commit (handles short hashes)
         # Try parsing as commit (handles short hashes)
         try:
         try:
             commit = parse_commit(repo, treeish)
             commit = parse_commit(repo, treeish)
-            return repo[commit.tree]
+            assert isinstance(commit, Commit)
+            tree = repo[commit.tree]
+            assert isinstance(tree, Tree)
+            return tree
         except KeyError:
         except KeyError:
             raise KeyError(treeish)
             raise KeyError(treeish)
-    if o.type_name == b"commit":
-        return repo[o.tree]
-    elif o.type_name == b"tag":
+    if isinstance(o, Commit):
+        tree = repo[o.tree]
+        assert isinstance(tree, Tree)
+        return tree
+    elif isinstance(o, Tag):
         # Tag handling - dereference and recurse
         # Tag handling - dereference and recurse
         obj_type, obj_sha = o.object
         obj_type, obj_sha = o.object
         return parse_tree(repo, obj_sha)
         return parse_tree(repo, obj_sha)
+    assert isinstance(o, Tree)
     return o
     return o
 
 
 
 
@@ -383,11 +396,13 @@ def scan_for_short_id(
     raise AmbiguousShortId(prefix, ret)
     raise AmbiguousShortId(prefix, ret)
 
 
 
 
-def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "Commit":
+def parse_commit(
+    repo: "BaseRepo", committish: Union[str, bytes, Commit, Tag]
+) -> "Commit":
     """Parse a string referring to a single commit.
     """Parse a string referring to a single commit.
 
 
     Args:
     Args:
-      repo: A` Repo` object
+      repo: A repository object
       committish: A string referring to a single commit, or a Commit or Tag object.
       committish: A string referring to a single commit, or a Commit or Tag object.
     Returns: A Commit object
     Returns: A Commit object
     Raises:
     Raises:
@@ -426,7 +441,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "C
     else:
     else:
         return dereference_tag(obj)
         return dereference_tag(obj)
     try:
     try:
-        obj = repo[parse_ref(repo, committish)]
+        obj = repo[parse_ref(repo.refs, committish)]
     except KeyError:
     except KeyError:
         pass
         pass
     else:
     else:

+ 137 - 46
dulwich/pack.py

@@ -43,6 +43,7 @@ try:
 except ModuleNotFoundError:
 except ModuleNotFoundError:
     from difflib import SequenceMatcher
     from difflib import SequenceMatcher
 
 
+import hashlib
 import os
 import os
 import struct
 import struct
 import sys
 import sys
@@ -53,6 +54,7 @@ from hashlib import sha1
 from itertools import chain
 from itertools import chain
 from os import SEEK_CUR, SEEK_END
 from os import SEEK_CUR, SEEK_END
 from struct import unpack_from
 from struct import unpack_from
+from types import TracebackType
 from typing import (
 from typing import (
     IO,
     IO,
     TYPE_CHECKING,
     TYPE_CHECKING,
@@ -64,6 +66,7 @@ from typing import (
     Protocol,
     Protocol,
     TypeVar,
     TypeVar,
     Union,
     Union,
+    cast,
 )
 )
 
 
 try:
 try:
@@ -129,12 +132,13 @@ class ObjectContainer(Protocol):
         self,
         self,
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         progress: Optional[Callable[[str], None]] = None,
         progress: Optional[Callable[[str], None]] = None,
-    ) -> None:
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
         """Add a set of objects to this object store.
 
 
         Args:
         Args:
           objects: Iterable over a list of (object, path) tuples
           objects: Iterable over a list of (object, path) tuples
           progress: Progress callback for object insertion
           progress: Progress callback for object insertion
+        Returns: Optional Pack object of the objects written.
         """
         """
 
 
     def __contains__(self, sha1: bytes) -> bool:
     def __contains__(self, sha1: bytes) -> bool:
@@ -331,6 +335,7 @@ class UnpackedObject:
     def sha(self) -> bytes:
     def sha(self) -> bytes:
         """Return the binary SHA of this object."""
         """Return the binary SHA of this object."""
         if self._sha is None:
         if self._sha is None:
+            assert self.obj_type_num is not None and self.obj_chunks is not None
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
         return self._sha
         return self._sha
 
 
@@ -643,6 +648,18 @@ class PackIndex:
         """Yield all the SHA1's of the objects in the index, sorted."""
         """Yield all the SHA1's of the objects in the index, sorted."""
         raise NotImplementedError(self._itersha)
         raise NotImplementedError(self._itersha)
 
 
+    def iter_prefix(self, prefix: bytes) -> Iterator[bytes]:
+        """Iterate over all SHA1s with the given prefix.
+
+        Args:
+            prefix: Binary prefix to match
+        Returns: Iterator of matching SHA1s
+        """
+        # Default implementation for PackIndex classes that don't override
+        for sha, _, _ in self.iterentries():
+            if sha.startswith(prefix):
+                yield sha
+
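The new default makes every PackIndex support short-hash lookups via a linear scan over iterentries(); indexes with fan-out tables are free to override it with something faster. A hypothetical usage of the same shape, outside the class:

    from collections.abc import Iterator

    def iter_prefix(shas: list[bytes], prefix: bytes) -> Iterator[bytes]:
        # Same logic as the default implementation: filter the SHA listing.
        for sha in shas:
            if sha.startswith(prefix):
                yield sha

    shas = [bytes.fromhex("ab" * 20), bytes.fromhex("abcd" + "00" * 18)]
    print([s.hex() for s in iter_prefix(shas, bytes.fromhex("ab"))])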
     def close(self) -> None:
     def close(self) -> None:
         """Close any open files."""
         """Close any open files."""
 
 
@@ -729,11 +746,12 @@ class FilePackIndex(PackIndex):
     """
     """
 
 
     _fan_out_table: list[int]
     _fan_out_table: list[int]
+    _file: Union[IO[bytes], _GitFile]
 
 
     def __init__(
     def __init__(
         self,
         self,
         filename: Union[str, os.PathLike],
         filename: Union[str, os.PathLike],
-        file: Optional[BinaryIO] = None,
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
         contents: Optional[Union[bytes, "mmap.mmap"]] = None,
         contents: Optional[Union[bytes, "mmap.mmap"]] = None,
         size: Optional[int] = None,
         size: Optional[int] = None,
     ) -> None:
     ) -> None:
@@ -924,7 +942,11 @@ class PackIndex1(FilePackIndex):
     """Version 1 Pack Index file."""
     """Version 1 Pack Index file."""
 
 
     def __init__(
     def __init__(
-        self, filename: Union[str, os.PathLike], file: Optional[BinaryIO] = None, contents: Optional[bytes] = None, size: Optional[int] = None
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
+        contents: Optional[bytes] = None,
+        size: Optional[int] = None,
     ) -> None:
     ) -> None:
         """Initialize a version 1 pack index.
         """Initialize a version 1 pack index.
 
 
@@ -959,7 +981,11 @@ class PackIndex2(FilePackIndex):
     """Version 2 Pack Index file."""
     """Version 2 Pack Index file."""
 
 
     def __init__(
     def __init__(
-        self, filename: Union[str, os.PathLike], file: Optional[BinaryIO] = None, contents: Optional[bytes] = None, size: Optional[int] = None
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
+        contents: Optional[bytes] = None,
+        size: Optional[int] = None,
     ) -> None:
     ) -> None:
         """Initialize a version 2 pack index.
         """Initialize a version 2 pack index.
 
 
@@ -1013,7 +1039,11 @@ class PackIndex3(FilePackIndex):
     """
     """
 
 
     def __init__(
     def __init__(
-        self, filename: Union[str, os.PathLike], file: Optional[BinaryIO] = None, contents: Optional[bytes] = None, size: Optional[int] = None
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[Union[IO[bytes], _GitFile]] = None,
+        contents: Optional[bytes] = None,
+        size: Optional[int] = None,
     ) -> None:
     ) -> None:
         """Initialize a version 3 pack index.
         """Initialize a version 3 pack index.
 
 
@@ -1201,7 +1231,12 @@ class PackStreamReader:
     appropriate.
     appropriate.
     """
     """
 
 
-    def __init__(self, read_all: Callable[[int], bytes], read_some: Optional[Callable[[int], bytes]] = None, zlib_bufsize: int = _ZLIB_BUFSIZE) -> None:
+    def __init__(
+        self,
+        read_all: Callable[[int], bytes],
+        read_some: Optional[Callable[[int], bytes]] = None,
+        zlib_bufsize: int = _ZLIB_BUFSIZE,
+    ) -> None:
         self.read_all = read_all
         self.read_all = read_all
         if read_some is None:
         if read_some is None:
             self.read_some = read_all
             self.read_some = read_all
@@ -1211,7 +1246,7 @@ class PackStreamReader:
         self._offset = 0
         self._offset = 0
         self._rbuf = BytesIO()
         self._rbuf = BytesIO()
         # trailer is a deque to avoid memory allocation on small reads
         # trailer is a deque to avoid memory allocation on small reads
-        self._trailer: deque[bytes] = deque()
+        self._trailer: deque[int] = deque()
         self._zlib_bufsize = zlib_bufsize
         self._zlib_bufsize = zlib_bufsize
 
 
     def _read(self, read: Callable[[int], bytes], size: int) -> bytes:
     def _read(self, read: Callable[[int], bytes], size: int) -> bytes:
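deque[bytes] was the wrong element type here: iterating or indexing a bytes object in Python 3 yields ints, and that is what accumulates in the trailer buffer. A quick demonstration:

    from collections import deque

    trailer: deque[int] = deque(maxlen=20)
    trailer.extend(b"\x01\x02\x03")  # iterating bytes yields ints
    assert list(trailer) == [1, 2, 3]
    print(bytes(trailer))  # reassembles to b'\x01\x02\x03'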
@@ -1343,7 +1378,13 @@ class PackStreamCopier(PackStreamReader):
     appropriate and written out to the given file-like object.
     appropriate and written out to the given file-like object.
     """
     """
 
 
-    def __init__(self, read_all: Callable, read_some: Callable, outfile: Any, delta_iter: Optional[Any] = None) -> None:
+    def __init__(
+        self,
+        read_all: Callable,
+        read_some: Callable,
+        outfile: IO[bytes],
+        delta_iter: Optional["DeltaChainIterator"] = None,
+    ) -> None:
         """Initialize the copier.
         """Initialize the copier.
 
 
         Args:
         Args:
@@ -1393,7 +1434,9 @@ def obj_sha(type: int, chunks: Union[bytes, Iterable[bytes]]) -> bytes:
     return sha.digest()
     return sha.digest()
 
 
 
 
-def compute_file_sha(f: IO[bytes], start_ofs: int = 0, end_ofs: int = 0, buffer_size: int = 1 << 16) -> "sha1":
+def compute_file_sha(
+    f: IO[bytes], start_ofs: int = 0, end_ofs: int = 0, buffer_size: int = 1 << 16
+) -> "hashlib._Hash":
     """Hash a portion of a file into a new SHA.
     """Hash a portion of a file into a new SHA.
 
 
     Args:
     Args:
@@ -1450,7 +1493,7 @@ class PackData:
     def __init__(
     def __init__(
         self,
         self,
         filename: Union[str, os.PathLike],
         filename: Union[str, os.PathLike],
-        file: Optional[Any] = None,
+        file: Optional[IO[bytes]] = None,
         size: Optional[int] = None,
         size: Optional[int] = None,
         *,
         *,
         delta_window_size: Optional[int] = None,
         delta_window_size: Optional[int] = None,
@@ -1477,6 +1520,7 @@ class PackData:
         self.depth = depth
         self.depth = depth
         self.threads = threads
         self.threads = threads
         self.big_file_threshold = big_file_threshold
         self.big_file_threshold = big_file_threshold
+        self._file: IO[bytes]
 
 
         if file is None:
         if file is None:
             self._file = GitFile(self._filename, "rb")
             self._file = GitFile(self._filename, "rb")
@@ -1500,7 +1544,7 @@ class PackData:
         return os.path.basename(self._filename)
         return os.path.basename(self._filename)
 
 
     @property
     @property
-    def path(self) -> str:
+    def path(self) -> Union[str, os.PathLike]:
         """Get the full path of the pack file.
         """Get the full path of the pack file.
 
 
         Returns:
         Returns:
@@ -1509,7 +1553,7 @@ class PackData:
         return self._filename
         return self._filename
 
 
     @classmethod
     @classmethod
-    def from_file(cls, file: Any, size: Optional[int] = None) -> 'PackData':
+    def from_file(cls, file: IO[bytes], size: Optional[int] = None) -> "PackData":
         """Create a PackData object from an open file.
         """Create a PackData object from an open file.
 
 
         Args:
         Args:
@@ -1522,7 +1566,7 @@ class PackData:
         return cls(str(file), file=file, size=size)
         return cls(str(file), file=file, size=size)
 
 
     @classmethod
     @classmethod
-    def from_path(cls, path: Union[str, os.PathLike]) -> 'PackData':
+    def from_path(cls, path: Union[str, os.PathLike]) -> "PackData":
         """Create a PackData object from a file path.
         """Create a PackData object from a file path.
 
 
         Args:
         Args:
@@ -1537,15 +1581,20 @@ class PackData:
         """Close the underlying pack file."""
         """Close the underlying pack file."""
         self._file.close()
         self._file.close()
 
 
-    def __enter__(self) -> 'PackData':
+    def __enter__(self) -> "PackData":
         """Enter context manager."""
         """Enter context manager."""
         return self
         return self
 
 
-    def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None:
+    def __exit__(
+        self,
+        exc_type: Optional[type],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         """Exit context manager."""
         """Exit context manager."""
         self.close()
         self.close()
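types.TracebackType is the precise type of the third argument to __exit__, replacing the earlier Any. For reference, a fully typed context-manager signature looks like this (generic sketch, not dulwich code):

    from types import TracebackType
    from typing import Optional

    class Managed:
        def __enter__(self) -> "Managed":
            return self

        def __exit__(
            self,
            exc_type: Optional[type[BaseException]],
            exc_val: Optional[BaseException],
            exc_tb: Optional[TracebackType],
        ) -> None:
            # Returning None (falsy) lets any exception propagate.
            pass

    with Managed():
        print("inside")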
 
 
-    def __eq__(self, other: Any) -> bool:
+    def __eq__(self, other: object) -> bool:
         if isinstance(other, PackData):
         if isinstance(other, PackData):
             return self.get_stored_checksum() == other.get_stored_checksum()
             return self.get_stored_checksum() == other.get_stored_checksum()
         return False
         return False
@@ -1568,9 +1617,9 @@ class PackData:
 
 
         Returns: 20-byte binary SHA1 digest
         Returns: 20-byte binary SHA1 digest
         """
         """
-        return compute_file_sha(self._file, end_ofs=-20).digest()
+        return compute_file_sha(cast(IO[bytes], self._file), end_ofs=-20).digest()
 
 
-    def iter_unpacked(self, *, include_comp: bool = False) -> Any:
+    def iter_unpacked(self, *, include_comp: bool = False) -> Iterator[UnpackedObject]:
         self._file.seek(self._header_size)
         self._file.seek(self._header_size)
 
 
         if self._num_objects is None:
         if self._num_objects is None:
@@ -1608,7 +1657,7 @@ class PackData:
         self,
         self,
         progress: Optional[ProgressFn] = None,
         progress: Optional[ProgressFn] = None,
         resolve_ext_ref: Optional[ResolveExtRefFn] = None,
         resolve_ext_ref: Optional[ResolveExtRefFn] = None,
-    ) -> Any:
+    ) -> list[tuple[bytes, int, int]]:
         """Return entries in this pack, sorted by SHA.
         """Return entries in this pack, sorted by SHA.
 
 
         Args:
         Args:
@@ -1621,7 +1670,12 @@ class PackData:
             self.iterentries(progress=progress, resolve_ext_ref=resolve_ext_ref)
             self.iterentries(progress=progress, resolve_ext_ref=resolve_ext_ref)
         )
         )
 
 
-    def create_index_v1(self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None) -> bytes:
+    def create_index_v1(
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        resolve_ext_ref: Optional[Callable] = None,
+    ) -> bytes:
         """Create a version 1 file for this data file.
         """Create a version 1 file for this data file.
 
 
         Args:
         Args:
@@ -1636,7 +1690,12 @@ class PackData:
         with GitFile(filename, "wb") as f:
         with GitFile(filename, "wb") as f:
             return write_pack_index_v1(f, entries, self.calculate_checksum())
             return write_pack_index_v1(f, entries, self.calculate_checksum())
 
 
-    def create_index_v2(self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None) -> bytes:
+    def create_index_v2(
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        resolve_ext_ref: Optional[Callable] = None,
+    ) -> bytes:
         """Create a version 2 index file for this data file.
         """Create a version 2 index file for this data file.
 
 
         Args:
         Args:
@@ -1652,7 +1711,11 @@ class PackData:
             return write_pack_index_v2(f, entries, self.calculate_checksum())
             return write_pack_index_v2(f, entries, self.calculate_checksum())
 
 
     def create_index_v3(
     def create_index_v3(
-        self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None, hash_algorithm: int = 1
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        resolve_ext_ref: Optional[Callable] = None,
+        hash_algorithm: int = 1,
     ) -> bytes:
     ) -> bytes:
         """Create a version 3 index file for this data file.
         """Create a version 3 index file for this data file.
 
 
@@ -1672,7 +1735,12 @@ class PackData:
             )
             )
 
 
     def create_index(
     def create_index(
-        self, filename: str, progress: Optional[Callable] = None, version: int = 2, resolve_ext_ref: Optional[Callable] = None, hash_algorithm: int = 1
+        self,
+        filename: str,
+        progress: Optional[Callable] = None,
+        version: int = 2,
+        resolve_ext_ref: Optional[Callable] = None,
+        hash_algorithm: int = 1,
     ) -> bytes:
     ) -> bytes:
         """Create an  index file for this data file.
         """Create an  index file for this data file.
 
 
@@ -1766,14 +1834,13 @@ class DeltaChainIterator(Generic[T]):
     _compute_crc32 = False
     _compute_crc32 = False
     _include_comp = False
     _include_comp = False
 
 
-    def __init__(self, file_obj, *, resolve_ext_ref=None) -> None:
+    def __init__(self, file_obj: Any, *, resolve_ext_ref: Optional[Callable] = None) -> None:
         """Initialize DeltaChainIterator.
         """Initialize DeltaChainIterator.
 
 
         Args:
         Args:
             file_obj: File object to read pack data from
             file_obj: File object to read pack data from
             resolve_ext_ref: Optional function to resolve external references
             resolve_ext_ref: Optional function to resolve external references
         """
         """
-    def __init__(self, file_obj: Any, *, resolve_ext_ref: Optional[Callable] = None) -> None:
         self._file = file_obj
         self._file = file_obj
         self._resolve_ext_ref = resolve_ext_ref
         self._resolve_ext_ref = resolve_ext_ref
         self._pending_ofs: dict[int, list[int]] = defaultdict(list)
         self._pending_ofs: dict[int, list[int]] = defaultdict(list)
@@ -1782,7 +1849,9 @@ class DeltaChainIterator(Generic[T]):
         self._ext_refs: list[bytes] = []
         self._ext_refs: list[bytes] = []
 
 
     @classmethod
     @classmethod
-    def for_pack_data(cls, pack_data: PackData, resolve_ext_ref: Optional[Callable] = None) -> 'DeltaChainIterator':
+    def for_pack_data(
+        cls, pack_data: PackData, resolve_ext_ref: Optional[Callable] = None
+    ) -> "DeltaChainIterator":
         """Create a DeltaChainIterator from pack data.
         """Create a DeltaChainIterator from pack data.
 
 
         Args:
         Args:
@@ -1878,7 +1947,7 @@ class DeltaChainIterator(Generic[T]):
         """
         """
         self._file = pack_data._file
         self._file = pack_data._file
 
 
-    def _walk_all_chains(self) -> Any:
+    def _walk_all_chains(self) -> Iterator[T]:
         for offset, type_num in self._full_ofs:
         for offset, type_num in self._full_ofs:
             yield from self._follow_chain(offset, type_num, None)
             yield from self._follow_chain(offset, type_num, None)
         yield from self._walk_ref_chains()
         yield from self._walk_ref_chains()
@@ -1888,7 +1957,7 @@ class DeltaChainIterator(Generic[T]):
         if self._pending_ref:
         if self._pending_ref:
             raise UnresolvedDeltas([sha_to_hex(s) for s in self._pending_ref])
             raise UnresolvedDeltas([sha_to_hex(s) for s in self._pending_ref])
 
 
-    def _walk_ref_chains(self) -> Any:
+    def _walk_ref_chains(self) -> Iterator[T]:
         if not self._resolve_ext_ref:
         if not self._resolve_ext_ref:
             self._ensure_no_pending()
             self._ensure_no_pending()
             return
             return
@@ -1910,12 +1979,13 @@ class DeltaChainIterator(Generic[T]):
 
 
         self._ensure_no_pending()
         self._ensure_no_pending()
 
 
-    def _result(self, unpacked: UnpackedObject) -> Any:
+    def _result(self, unpacked: UnpackedObject) -> T:
         raise NotImplementedError
         raise NotImplementedError
 
 
     def _resolve_object(
     def _resolve_object(
         self, offset: int, obj_type_num: int, base_chunks: Optional[list[bytes]]
         self, offset: int, obj_type_num: int, base_chunks: Optional[list[bytes]]
     ) -> UnpackedObject:
     ) -> UnpackedObject:
+        assert self._file is not None
         self._file.seek(offset)
         self._file.seek(offset)
         unpacked, _ = unpack_object(
         unpacked, _ = unpack_object(
             self._file.read,
             self._file.read,
@@ -1931,7 +2001,9 @@ class DeltaChainIterator(Generic[T]):
             unpacked.obj_chunks = apply_delta(base_chunks, unpacked.decomp_chunks)
             unpacked.obj_chunks = apply_delta(base_chunks, unpacked.decomp_chunks)
         return unpacked
         return unpacked
 
 
-    def _follow_chain(self, offset: int, obj_type_num: int, base_chunks: list[bytes]) -> Iterator[T]:
+    def _follow_chain(
+        self, offset: int, obj_type_num: int, base_chunks: Optional[list[bytes]]
+    ) -> Iterator[T]:
         # Unlike PackData.get_object_at, there is no need to cache offsets as
         # Unlike PackData.get_object_at, there is no need to cache offsets as
         # this approach by design inflates each object exactly once.
         # this approach by design inflates each object exactly once.
         todo = [(offset, obj_type_num, base_chunks)]
         todo = [(offset, obj_type_num, base_chunks)]
@@ -2004,19 +2076,19 @@ class PackInflater(DeltaChainIterator[ShaFile]):
         Returns:
         Returns:
             ShaFile object from the unpacked data
             ShaFile object from the unpacked data
         """
         """
+    def _result(self, unpacked: UnpackedObject) -> ShaFile:
         return unpacked.sha_file()
         return unpacked.sha_file()
 
 
 
 
 class SHA1Reader(BinaryIO):
 class SHA1Reader(BinaryIO):
     """Wrapper for file-like object that remembers the SHA1 of its data."""
     """Wrapper for file-like object that remembers the SHA1 of its data."""
 
 
-    def __init__(self, f) -> None:
+    def __init__(self, f: IO[bytes]) -> None:
         """Initialize SHA1Reader.
         """Initialize SHA1Reader.
 
 
         Args:
         Args:
             f: File-like object to wrap
             f: File-like object to wrap
         """
         """
-    def __init__(self, f: BinaryIO) -> None:
         self.f = f
         self.f = f
         self.sha1 = sha1(b"")
         self.sha1 = sha1(b"")
 
 
@@ -2126,19 +2198,24 @@ class SHA1Reader(BinaryIO):
         """
         """
         raise UnsupportedOperation("write")
         raise UnsupportedOperation("write")
 
 
-    def writelines(self, lines: Any) -> None:
+    def writelines(self, lines: Iterable[bytes], /) -> None:  # type: ignore[override]
         raise UnsupportedOperation("writelines")
         raise UnsupportedOperation("writelines")
 
 
-    def write(self, data: bytes) -> int:
+    def write(self, data: bytes, /) -> int:  # type: ignore[override]
         raise UnsupportedOperation("write")
         raise UnsupportedOperation("write")
 
 
-    def __enter__(self) -> 'SHA1Reader':
+    def __enter__(self) -> "SHA1Reader":
         return self
         return self
 
 
-    def __exit__(self, type: Optional[type], value: Optional[BaseException], traceback: Optional[Any]) -> None:
+    def __exit__(
+        self,
+        type: Optional[type],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
         self.close()
         self.close()
 
 
-    def __iter__(self) -> 'SHA1Reader':
+    def __iter__(self) -> "SHA1Reader":
         return self
         return self
 
 
     def __next__(self) -> bytes:
     def __next__(self) -> bytes:
@@ -2181,10 +2258,10 @@ class SHA1Writer(BinaryIO):
         Args:
         Args:
             f: File-like object to wrap
             f: File-like object to wrap
         """
         """
-    def __init__(self, f: BinaryIO) -> None:
         self.f = f
         self.f = f
         self.length = 0
         self.length = 0
         self.sha1 = sha1(b"")
         self.sha1 = sha1(b"")
+        self.digest: Optional[bytes] = None
 
 
     def write(self, data) -> int:
     def write(self, data) -> int:
         """Write data and update SHA1.
         """Write data and update SHA1.
@@ -2195,7 +2272,6 @@ class SHA1Writer(BinaryIO):
         Returns:
         Returns:
             Number of bytes written
             Number of bytes written
         """
         """
-    def write(self, data: bytes) -> int:
         self.sha1.update(data)
         self.sha1.update(data)
         self.f.write(data)
         self.f.write(data)
         self.length += len(data)
         self.length += len(data)
@@ -2220,8 +2296,9 @@ class SHA1Writer(BinaryIO):
             The SHA1 digest bytes
             The SHA1 digest bytes
         """
         """
         sha = self.write_sha()
         sha = self.write_sha()
+    def close(self) -> None:
+        self.digest = self.write_sha()
         self.f.close()
         self.f.close()
-        return sha
 
 
     def offset(self) -> int:
     def offset(self) -> int:
         """Get the total number of bytes written.
         """Get the total number of bytes written.
@@ -2281,13 +2358,12 @@ class SHA1Writer(BinaryIO):
         """
         """
         raise UnsupportedOperation("readlines")
         raise UnsupportedOperation("readlines")
 
 
-    def writelines(self, lines) -> None:
+    def writelines(self, lines: Iterable[bytes]) -> None:
         """Write multiple lines to the file.
         """Write multiple lines to the file.
 
 
         Args:
         Args:
             lines: Iterable of lines to write
             lines: Iterable of lines to write
         """
         """
-    def writelines(self, lines: Any) -> None:
         for line in lines:
         for line in lines:
             self.write(line)
             self.write(line)
 
 
@@ -2309,6 +2385,18 @@ class SHA1Writer(BinaryIO):
 
 
     def __iter__(self) -> 'SHA1Writer':
     def __iter__(self) -> 'SHA1Writer':
         """Return iterator."""
         """Return iterator."""
+    def __enter__(self) -> "SHA1Writer":
+        return self
+
+    def __exit__(
+        self,
+        type: Optional[type],
+        value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> None:
+        self.close()
+
+    def __iter__(self) -> "SHA1Writer":
         return self
         return self
 
 
     def __next__(self) -> bytes:
     def __next__(self) -> bytes:
@@ -2498,7 +2586,7 @@ def find_reusable_deltas(
 
 
 
 
 def deltify_pack_objects(
 def deltify_pack_objects(
-    objects: Union[Iterator[bytes], Iterator[tuple[ShaFile, Optional[bytes]]]],
+    objects: Union[Iterator[ShaFile], Iterator[tuple[ShaFile, Optional[bytes]]]],
     *,
     *,
     window_size: Optional[int] = None,
     window_size: Optional[int] = None,
     progress=None,
     progress=None,
@@ -2927,7 +3015,7 @@ def write_pack_index_v1(f: BinaryIO, entries: list[tuple[bytes, int, Optional[in
     Returns: The SHA of the written index file
     Returns: The SHA of the written index file
     """
     """
     f = SHA1Writer(f)
     f = SHA1Writer(f)
-    fan_out_table = defaultdict(lambda: 0)
+    fan_out_table: dict[int, int] = defaultdict(lambda: 0)
     for name, _offset, _entry_checksum in entries:
     for name, _offset, _entry_checksum in entries:
         fan_out_table[ord(name[:1])] += 1
         fan_out_table[ord(name[:1])] += 1
     # Fan-out table
     # Fan-out table
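Annotating the table as dict[int, int] pins down the key type, which defaultdict(lambda: 0) alone leaves as Any. The fan-out counting pattern, self-contained:

    from collections import defaultdict

    # Without the annotation mypy infers defaultdict[Any, int].
    fan_out: dict[int, int] = defaultdict(lambda: 0)
    for name in [b"\x00aa", b"\x00bb", b"\x01cc"]:
        fan_out[ord(name[:1])] += 1  # count by first byte value

    assert fan_out[0] == 2 and fan_out[1] == 1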
@@ -3489,6 +3577,7 @@ class Pack:
             ofs: (sha, crc32) for (sha, ofs, crc32) in self.index.iterentries()
             ofs: (sha, crc32) for (sha, ofs, crc32) in self.index.iterentries()
         }
         }
         for unpacked in self.data.iter_unpacked(include_comp=include_comp):
         for unpacked in self.data.iter_unpacked(include_comp=include_comp):
+            assert unpacked.offset is not None
             (sha, crc32) = ofs_to_entries[unpacked.offset]
             (sha, crc32) = ofs_to_entries[unpacked.offset]
             unpacked._sha = sha
             unpacked._sha = sha
             unpacked.crc32 = crc32
             unpacked.crc32 = crc32
@@ -3589,8 +3678,10 @@ class Pack:
             object count
             object count
         Returns: Iterator of tuples with (sha, offset, crc32)
         Returns: Iterator of tuples with (sha, offset, crc32)
         """
         """
-        return self.data.sorted_entries(
-            progress=progress, resolve_ext_ref=self.resolve_ext_ref
+        return iter(
+            self.data.sorted_entries(
+                progress=progress, resolve_ext_ref=self.resolve_ext_ref
+            )
         )
         )
 
 
     def get_unpacked_object(
     def get_unpacked_object(

+ 8 - 9
dulwich/patch.py

@@ -30,6 +30,7 @@ import time
 from collections.abc import Generator
 from collections.abc import Generator
 from difflib import SequenceMatcher
 from difflib import SequenceMatcher
 from typing import (
 from typing import (
+    IO,
     TYPE_CHECKING,
     TYPE_CHECKING,
     BinaryIO,
     BinaryIO,
     Optional,
     Optional,
@@ -48,7 +49,7 @@ FIRST_FEW_BYTES = 8000
 
 
 
 
 def write_commit_patch(
 def write_commit_patch(
-    f: BinaryIO,
+    f: IO[bytes],
     commit: "Commit",
     commit: "Commit",
     contents: Union[str, bytes],
     contents: Union[str, bytes],
     progress: tuple[int, int],
     progress: tuple[int, int],
@@ -231,7 +232,7 @@ def patch_filename(p: Optional[bytes], root: bytes) -> bytes:
 
 
 
 
 def write_object_diff(
 def write_object_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     store: "BaseObjectStore",
     old_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
     old_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
     new_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
     new_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
@@ -264,19 +265,17 @@ def write_object_diff(
         Returns:
         Returns:
             Blob object
             Blob object
         """
         """
-        from typing import cast
-
         if hexsha is None:
         if hexsha is None:
-            return cast(Blob, Blob.from_string(b""))
+            return Blob.from_string(b"")
         elif mode is not None and S_ISGITLINK(mode):
         elif mode is not None and S_ISGITLINK(mode):
-            return cast(Blob, Blob.from_string(b"Subproject commit " + hexsha + b"\n"))
+            return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
         else:
         else:
             obj = store[hexsha]
             obj = store[hexsha]
             if isinstance(obj, Blob):
             if isinstance(obj, Blob):
                 return obj
                 return obj
             else:
             else:
                 # Fallback for non-blob objects
                 # Fallback for non-blob objects
-                return cast(Blob, Blob.from_string(obj.as_raw_string()))
+                return Blob.from_string(obj.as_raw_string())
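The cast(Blob, ...) wrappers could go away, presumably because Blob.from_string is typed to return the class it is called on; giving a classmethod a precise return type is the general cure for casts at call sites. A minimal analogue:

    class Record:
        def __init__(self, data: bytes) -> None:
            self.data = data

        @classmethod
        def from_string(cls, data: bytes) -> "Record":
            # Precise return type: callers need no cast().
            return cls(data)

    print(Record.from_string(b"hello").data)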
 
 
     def lines(content: "Blob") -> list[bytes]:
     def lines(content: "Blob") -> list[bytes]:
         """Split blob content into lines.
         """Split blob content into lines.
@@ -356,7 +355,7 @@ def gen_diff_header(
 
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
 # TODO(jelmer): Support writing unicode, rather than bytes.
 def write_blob_diff(
 def write_blob_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     old_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
     old_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
     new_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
     new_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
 ) -> None:
 ) -> None:
@@ -403,7 +402,7 @@ def write_blob_diff(
 
 
 
 
 def write_tree_diff(
 def write_tree_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     store: "BaseObjectStore",
     old_tree: Optional[bytes],
     old_tree: Optional[bytes],
     new_tree: Optional[bytes],
     new_tree: Optional[bytes],

File diff suppressed because it is too large
+ 331 - 117
dulwich/porcelain.py


+ 0 - 12
dulwich/protocol.py

@@ -258,18 +258,6 @@ def pkt_seq(*seq: Optional[bytes]) -> bytes:
     return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)
     return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)
 
 
 
 
-def filter_ref_prefix(
-    refs: dict[bytes, bytes], prefixes: Iterable[bytes]
-) -> dict[bytes, bytes]:
-    """Filter refs to only include those with a given prefix.
-
-    Args:
-      refs: A list of refs.
-      prefixes: The prefixes to filter by.
-    """
-    return {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
-
-
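filter_ref_prefix is dropped outright, apparently as dead code; if the behaviour is ever needed again it is a one-line dict comprehension at the call site:

    refs = {b"refs/heads/main": b"a" * 40, b"refs/tags/v1": b"b" * 40}
    prefixes = [b"refs/heads/"]

    filtered = {k: v for k, v in refs.items()
                if any(k.startswith(p) for p in prefixes)}
    assert filtered == {b"refs/heads/main": b"a" * 40}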
 class Protocol:
 class Protocol:
     """Class for interacting with a remote git process over the wire.
     """Class for interacting with a remote git process over the wire.
 
 

+ 13 - 3
dulwich/rebase.py

@@ -32,7 +32,7 @@ from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
 from dulwich.merge import three_way_merge
 from dulwich.objects import Commit
 from dulwich.objects import Commit
 from dulwich.objectspec import parse_commit
 from dulwich.objectspec import parse_commit
-from dulwich.repo import Repo
+from dulwich.repo import BaseRepo, Repo
 
 
 
 
 class RebaseError(Exception):
 class RebaseError(Exception):
@@ -529,7 +529,7 @@ class DiskRebaseStateManager:
 class MemoryRebaseStateManager:
 class MemoryRebaseStateManager:
     """Manages rebase state in memory for MemoryRepo."""
     """Manages rebase state in memory for MemoryRepo."""
 
 
-    def __init__(self, repo: Repo) -> None:
+    def __init__(self, repo: BaseRepo) -> None:
         """Initialize MemoryRebaseStateManager.
         """Initialize MemoryRebaseStateManager.
 
 
         Args:
         Args:
@@ -642,6 +642,8 @@ class Rebaser:
         if branch is None:
         if branch is None:
             # Use current HEAD
             # Use current HEAD
             head_ref, head_sha = self.repo.refs.follow(b"HEAD")
             head_ref, head_sha = self.repo.refs.follow(b"HEAD")
+            if head_sha is None:
+                raise ValueError("HEAD does not point to a valid commit")
             branch_commit = self.repo[head_sha]
             branch_commit = self.repo[head_sha]
         else:
         else:
             # Parse the branch reference
             # Parse the branch reference
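refs.follow() can legitimately return None for the SHA (for example on an unborn branch), so the new guard turns a latent TypeError into a clear error message. The guard pattern by itself:

    from typing import Optional

    def resolve_head(head_sha: Optional[bytes]) -> bytes:
        # Fail loudly and early rather than passing None to repo[...].
        if head_sha is None:
            raise ValueError("HEAD does not point to a valid commit")
        return head_sha

    print(resolve_head(b"f" * 40))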
@@ -664,6 +666,7 @@ class Rebaser:
         commits = []
         commits = []
         current = branch_commit
         current = branch_commit
         while current.id != merge_base:
         while current.id != merge_base:
+            assert isinstance(current, Commit)
             commits.append(current)
             commits.append(current)
             if not current.parents:
             if not current.parents:
                 break
                 break
@@ -691,6 +694,9 @@ class Rebaser:
         parent = self.repo[commit.parents[0]]
         parent = self.repo[commit.parents[0]]
         onto_commit = self.repo[onto]
         onto_commit = self.repo[onto]
 
 
+        assert isinstance(parent, Commit)
+        assert isinstance(onto_commit, Commit)
+
         # Perform three-way merge
         # Perform three-way merge
         merged_tree, conflicts = three_way_merge(
         merged_tree, conflicts = three_way_merge(
             self.object_store, parent, onto_commit, commit
             self.object_store, parent, onto_commit, commit
@@ -798,7 +804,9 @@ class Rebaser:
 
 
         if new_sha:
         if new_sha:
             # Success - add to done list
             # Success - add to done list
-            self._done.append(self.repo[new_sha])
+            new_commit = self.repo[new_sha]
+            assert isinstance(new_commit, Commit)
+            self._done.append(new_commit)
             self._save_rebase_state()
             self._save_rebase_state()
 
 
             # Continue with next commit if any
             # Continue with next commit if any
@@ -822,6 +830,8 @@ class Rebaser:
             raise RebaseError("No rebase in progress")
             raise RebaseError("No rebase in progress")
 
 
         # Restore original HEAD
         # Restore original HEAD
+        if self._original_head is None:
+            raise RebaseError("No original HEAD to restore")
         self.repo.refs[b"HEAD"] = self._original_head
         self.repo.refs[b"HEAD"] = self._original_head
 
 
         # Clean up rebase state
         # Clean up rebase state

+ 146 - 44
dulwich/refs.py

@@ -25,9 +25,19 @@
 import os
 import os
 import types
 import types
 import warnings
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from contextlib import suppress
 from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    BinaryIO,
+    Callable,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .file import _GitFile
     from .file import _GitFile
@@ -136,7 +146,23 @@ def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
 class RefsContainer:
 class RefsContainer:
     """A container for refs."""
     """A container for refs."""
 
 
-    def __init__(self, logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None) -> None:
+    def __init__(
+        self,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
         self._logger = logger
         self._logger = logger
 
 
     def _log(
     def _log(
@@ -246,14 +272,14 @@ class RefsContainer:
         for ref in to_delete:
         for ref in to_delete:
             self.remove_if_equals(b"/".join((base, ref)), None, message=message)
             self.remove_if_equals(b"/".join((base, ref)), None, message=message)
 
 
-    def allkeys(self) -> Iterator[Ref]:
+    def allkeys(self) -> set[Ref]:
         """All refs present in this container."""
         """All refs present in this container."""
         raise NotImplementedError(self.allkeys)
         raise NotImplementedError(self.allkeys)
 
 
     def __iter__(self) -> Iterator[Ref]:
     def __iter__(self) -> Iterator[Ref]:
         return iter(self.allkeys())
         return iter(self.allkeys())
 
 
-    def keys(self, base: Optional[bytes] = None) -> Union[Iterator[Ref], set[bytes]]:
+    def keys(self, base=None):
         """Refs present in this container.
         """Refs present in this container.
 
 
         Args:
         Args:
@@ -339,16 +365,16 @@ class RefsContainer:
         """
         """
         raise NotImplementedError(self.read_loose_ref)
         raise NotImplementedError(self.read_loose_ref)
 
 
-    def follow(self, name: bytes) -> tuple[list[bytes], bytes]:
+    def follow(self, name: bytes) -> tuple[list[bytes], Optional[bytes]]:
         """Follow a reference name.
         """Follow a reference name.
 
 
         Returns: a tuple of (refnames, sha), where refnames are the names of
         Returns: a tuple of (refnames, sha), where refnames are the names of
             references in the chain
             references in the chain
         """
         """
-        contents = SYMREF + name
+        contents: Optional[bytes] = SYMREF + name
         depth = 0
         depth = 0
         refnames = []
         refnames = []
-        while contents.startswith(SYMREF):
+        while contents and contents.startswith(SYMREF):
             refname = contents[len(SYMREF) :]
             refname = contents[len(SYMREF) :]
             refnames.append(refname)
             refnames.append(refname)
             contents = self.read_ref(refname)
             contents = self.read_ref(refname)
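Since read_ref() may return None once the chain leaves known refs, contents is now Optional and the loop tests truthiness before startswith(); mypy narrows the Optional inside the body. A self-contained analogue with a plain dict as the ref store:

    from typing import Optional

    SYMREF = b"ref: "
    table = {b"HEAD": b"ref: refs/heads/main", b"refs/heads/main": b"1234"}

    contents: Optional[bytes] = SYMREF + b"HEAD"
    refnames = []
    while contents and contents.startswith(SYMREF):
        refname = contents[len(SYMREF):]
        refnames.append(refname)
        contents = table.get(refname)  # may be None; the condition handles it

    print(refnames, contents)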
@@ -404,7 +430,13 @@ class RefsContainer:
         raise NotImplementedError(self.set_if_equals)
         raise NotImplementedError(self.set_if_equals)
 
 
     def add_if_new(
     def add_if_new(
-        self, name: bytes, ref: bytes, committer: Optional[bytes] = None, timestamp: Optional[int] = None, timezone: Optional[int] = None, message: Optional[bytes] = None
+        self,
+        name: bytes,
+        ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
     ) -> bool:
         """Add a new reference only if it does not already exist.
         """Add a new reference only if it does not already exist.
 
 
@@ -486,7 +518,9 @@ class RefsContainer:
         ret = {}
         ret = {}
         for src in self.allkeys():
         for src in self.allkeys():
             try:
             try:
-                dst = parse_symref_value(self.read_ref(src))
+                ref_value = self.read_ref(src)
+                assert ref_value is not None
+                dst = parse_symref_value(ref_value)
             except ValueError:
             except ValueError:
                 pass
                 pass
             else:
             else:
@@ -509,14 +543,31 @@ class DictRefsContainer(RefsContainer):
     threadsafe.
     threadsafe.
     """
     """
 
 
-    def __init__(self, refs: dict[bytes, bytes], logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None) -> None:
+    def __init__(
+        self,
+        refs: dict[bytes, bytes],
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
         super().__init__(logger=logger)
         super().__init__(logger=logger)
         self._refs = refs
         self._refs = refs
         self._peeled: dict[bytes, ObjectID] = {}
         self._peeled: dict[bytes, ObjectID] = {}
         self._watchers: set[Any] = set()
         self._watchers: set[Any] = set()
 
 
-    def allkeys(self) -> Iterator[bytes]:
-        return self._refs.keys()
+    def allkeys(self) -> set[bytes]:
+        return set(self._refs.keys())
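Returning set(self._refs.keys()) satisfies the tightened allkeys() -> set[Ref] contract and, as a side effect, hands callers a snapshot instead of a live view, so the refs dict can be mutated while iterating. The hazard the copy avoids:

    refs = {b"refs/heads/main": b"aa", b"refs/heads/dev": b"bb"}

    for name in set(refs.keys()):  # snapshot; a live view would raise here
        if name.endswith(b"/dev"):
            del refs[name]

    print(sorted(refs))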
 
 
     def read_loose_ref(self, name: bytes) -> Optional[bytes]:
     def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         return self._refs.get(name, None)
         return self._refs.get(name, None)
@@ -707,14 +758,14 @@ class DictRefsContainer(RefsContainer):
 class InfoRefsContainer(RefsContainer):
 class InfoRefsContainer(RefsContainer):
     """Refs container that reads refs from a info/refs file."""
     """Refs container that reads refs from a info/refs file."""
 
 
-    def __init__(self, f: Any) -> None:
-        self._refs = {}
-        self._peeled = {}
+    def __init__(self, f: BinaryIO) -> None:
+        self._refs: dict[bytes, bytes] = {}
+        self._peeled: dict[bytes, bytes] = {}
         refs = read_info_refs(f)
         refs = read_info_refs(f)
         (self._refs, self._peeled) = split_peeled_refs(refs)
         (self._refs, self._peeled) = split_peeled_refs(refs)
 
 
-    def allkeys(self) -> Iterator[bytes]:
-        return self._refs.keys()
+    def allkeys(self) -> set[bytes]:
+        return set(self._refs.keys())
 
 
     def read_loose_ref(self, name: bytes) -> Optional[bytes]:
     def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         return self._refs.get(name, None)
         return self._refs.get(name, None)
@@ -736,7 +787,20 @@ class DiskRefsContainer(RefsContainer):
         self,
         self,
         path: Union[str, bytes, os.PathLike],
         path: Union[str, bytes, os.PathLike],
         worktree_path: Optional[Union[str, bytes, os.PathLike]] = None,
         worktree_path: Optional[Union[str, bytes, os.PathLike]] = None,
-        logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
     ) -> None:
     ) -> None:
         """Initialize DiskRefsContainer."""
         """Initialize DiskRefsContainer."""
         super().__init__(logger=logger)
         super().__init__(logger=logger)
@@ -746,8 +810,8 @@ class DiskRefsContainer(RefsContainer):
             self.worktree_path = self.path
             self.worktree_path = self.path
         else:
         else:
             self.worktree_path = os.fsencode(os.fspath(worktree_path))
             self.worktree_path = os.fsencode(os.fspath(worktree_path))
-        self._packed_refs = None
-        self._peeled_refs = None
+        self._packed_refs: Optional[dict[bytes, bytes]] = None
+        self._peeled_refs: Optional[dict[bytes, bytes]] = None
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Return string representation of DiskRefsContainer."""
         """Return string representation of DiskRefsContainer."""
@@ -772,7 +836,7 @@ class DiskRefsContainer(RefsContainer):
                 subkeys.add(key[len(base) :].strip(b"/"))
                 subkeys.add(key[len(base) :].strip(b"/"))
         return subkeys
         return subkeys
 
 
-    def allkeys(self) -> Iterator[bytes]:
+    def allkeys(self) -> set[bytes]:
         allkeys = set()
         allkeys = set()
         if os.path.exists(self.refpath(HEADREF)):
         if os.path.exists(self.refpath(HEADREF)):
             allkeys.add(HEADREF)
             allkeys.add(HEADREF)
@@ -878,7 +942,11 @@ class DiskRefsContainer(RefsContainer):
             to a tag, but no cached information is available, None is returned.
             to a tag, but no cached information is available, None is returned.
         """
         """
         self.get_packed_refs()
         self.get_packed_refs()
-        if self._peeled_refs is None or name not in self._packed_refs:
+        if (
+            self._peeled_refs is None
+            or self._packed_refs is None
+            or name not in self._packed_refs
+        ):
             # No cache: no peeled refs were read, or this ref is loose
             # No cache: no peeled refs were read, or this ref is loose
             return None
             return None
         if name in self._peeled_refs:
         if name in self._peeled_refs:
@@ -927,13 +995,14 @@ class DiskRefsContainer(RefsContainer):
             self._packed_refs = None
             self._packed_refs = None
             self.get_packed_refs()
             self.get_packed_refs()
 
 
-            if name not in self._packed_refs:
+            if self._packed_refs is None or name not in self._packed_refs:
                 f.abort()
                 f.abort()
                 return
                 return
 
 
             del self._packed_refs[name]
             del self._packed_refs[name]
-            with suppress(KeyError):
-                del self._peeled_refs[name]
+            if self._peeled_refs is not None:
+                with suppress(KeyError):
+                    del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
             f.close()
         except BaseException:
         except BaseException:
@@ -1240,7 +1309,7 @@ def _split_ref_line(line: bytes) -> tuple[bytes, bytes]:
     return (sha, name)
     return (sha, name)
 
 
 
 
-def read_packed_refs(f: Any) -> Iterator[tuple[bytes, bytes]]:
+def read_packed_refs(f: IO[bytes]) -> Iterator[tuple[bytes, bytes]]:
     """Read a packed refs file.
     """Read a packed refs file.
 
 
     Args:
     Args:
@@ -1256,7 +1325,9 @@ def read_packed_refs(f: Any) -> Iterator[tuple[bytes, bytes]]:
         yield _split_ref_line(line)
         yield _split_ref_line(line)
 
 
 
 
-def read_packed_refs_with_peeled(f: Any) -> Iterator[tuple[bytes, bytes, Optional[bytes]]]:
+def read_packed_refs_with_peeled(
+    f: IO[bytes],
+) -> Iterator[tuple[bytes, bytes, Optional[bytes]]]:
     """Read a packed refs file including peeled refs.
     """Read a packed refs file including peeled refs.
 
 
     Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
     Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
@@ -1288,7 +1359,11 @@ def read_packed_refs_with_peeled(f: Any) -> Iterator[tuple[bytes, bytes, Optiona
         yield (sha, name, None)
         yield (sha, name, None)
 
 
 
 
-def write_packed_refs(f: Any, packed_refs: dict[bytes, bytes], peeled_refs: Optional[dict[bytes, bytes]] = None) -> None:
+def write_packed_refs(
+    f: IO[bytes],
+    packed_refs: dict[bytes, bytes],
+    peeled_refs: Optional[dict[bytes, bytes]] = None,
+) -> None:
     """Write a packed refs file.
     """Write a packed refs file.
 
 
     Args:
     Args:
@@ -1306,7 +1381,7 @@ def write_packed_refs(f: Any, packed_refs: dict[bytes, bytes], peeled_refs: Opti
             f.write(b"^" + peeled_refs[refname] + b"\n")
             f.write(b"^" + peeled_refs[refname] + b"\n")
 
 
 
 
-def read_info_refs(f: Any) -> dict[bytes, bytes]:
+def read_info_refs(f: BinaryIO) -> dict[bytes, bytes]:
     """Read info/refs file.
     """Read info/refs file.
 
 
     Args:
     Args:
@@ -1322,7 +1397,9 @@ def read_info_refs(f: Any) -> dict[bytes, bytes]:
     return ret
     return ret
 
 
 
 
-def write_info_refs(refs: dict[bytes, bytes], store: ObjectContainer) -> Iterator[bytes]:
+def write_info_refs(
+    refs: dict[bytes, bytes], store: ObjectContainer
+) -> Iterator[bytes]:
     """Generate info refs."""
     """Generate info refs."""
     # TODO: Avoid recursive import :(
     # TODO: Avoid recursive import :(
     from .object_store import peel_sha
     from .object_store import peel_sha
@@ -1346,26 +1423,33 @@ def is_local_branch(x: bytes) -> bool:
     return x.startswith(LOCAL_BRANCH_PREFIX)
     return x.startswith(LOCAL_BRANCH_PREFIX)
 
 
 
 
-def strip_peeled_refs(refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
+T = TypeVar("T", dict[bytes, bytes], dict[bytes, Optional[bytes]])
+
+
+def strip_peeled_refs(refs: T) -> T:
     """Remove all peeled refs."""
     """Remove all peeled refs."""
     return {
     return {
         ref: sha for (ref, sha) in refs.items() if not ref.endswith(PEELED_TAG_SUFFIX)
         ref: sha for (ref, sha) in refs.items() if not ref.endswith(PEELED_TAG_SUFFIX)
     }
     }
 
 
 
 
-def split_peeled_refs(refs: dict[bytes, bytes]) -> tuple[dict[bytes, bytes], dict[bytes, bytes]]:
+def split_peeled_refs(refs: T) -> tuple[T, dict[bytes, bytes]]:
     """Split peeled refs from regular refs."""
     """Split peeled refs from regular refs."""
-    peeled = {}
-    regular = {}
+    peeled: dict[bytes, bytes] = {}
+    regular = {k: v for k, v in refs.items() if not k.endswith(PEELED_TAG_SUFFIX)}
+
     for ref, sha in refs.items():
     for ref, sha in refs.items():
         if ref.endswith(PEELED_TAG_SUFFIX):
         if ref.endswith(PEELED_TAG_SUFFIX):
-            peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
-        else:
-            regular[ref] = sha
+            # Only add to peeled dict if sha is not None
+            if sha is not None:
+                peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
+
     return regular, peeled
     return regular, peeled
 
 
 
 
-def _set_origin_head(refs: RefsContainer, origin: bytes, origin_head: Optional[bytes]) -> None:
+def _set_origin_head(
+    refs: RefsContainer, origin: bytes, origin_head: Optional[bytes]
+) -> None:
     # set refs/remotes/origin/HEAD
     # set refs/remotes/origin/HEAD
     origin_base = b"refs/remotes/" + origin + b"/"
     origin_base = b"refs/remotes/" + origin + b"/"
     if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
     if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
@@ -1409,7 +1493,9 @@ def _set_default_branch(
     return head_ref
     return head_ref
 
 
 
 
-def _set_head(refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]) -> Optional[bytes]:
+def _set_head(
+    refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]
+) -> Optional[bytes]:
     if head_ref.startswith(LOCAL_TAG_PREFIX):
     if head_ref.startswith(LOCAL_TAG_PREFIX):
         # detach HEAD at specified tag
         # detach HEAD at specified tag
         head = refs[head_ref]
         head = refs[head_ref]
@@ -1432,7 +1518,7 @@ def _set_head(refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]
 def _import_remote_refs(
 def _import_remote_refs(
     refs_container: RefsContainer,
     refs_container: RefsContainer,
     remote_name: str,
     remote_name: str,
-    refs: dict[str, str],
+    refs: dict[bytes, Optional[bytes]],
     message: Optional[bytes] = None,
     message: Optional[bytes] = None,
     prune: bool = False,
     prune: bool = False,
     prune_tags: bool = False,
     prune_tags: bool = False,
@@ -1441,7 +1527,7 @@ def _import_remote_refs(
     branches = {
     branches = {
         n[len(LOCAL_BRANCH_PREFIX) :]: v
         n[len(LOCAL_BRANCH_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_BRANCH_PREFIX)
+        if n.startswith(LOCAL_BRANCH_PREFIX) and v is not None
     }
     }
     refs_container.import_refs(
     refs_container.import_refs(
         b"refs/remotes/" + remote_name.encode(),
         b"refs/remotes/" + remote_name.encode(),
@@ -1452,14 +1538,18 @@ def _import_remote_refs(
     tags = {
     tags = {
         n[len(LOCAL_TAG_PREFIX) :]: v
         n[len(LOCAL_TAG_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_TAG_PREFIX) and not n.endswith(PEELED_TAG_SUFFIX)
+        if n.startswith(LOCAL_TAG_PREFIX)
+        and not n.endswith(PEELED_TAG_SUFFIX)
+        and v is not None
     }
     }
     refs_container.import_refs(
     refs_container.import_refs(
         LOCAL_TAG_PREFIX, tags, message=message, prune=prune_tags
         LOCAL_TAG_PREFIX, tags, message=message, prune=prune_tags
     )
     )
 
 
 
 
-def serialize_refs(store: ObjectContainer, refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
+def serialize_refs(
+    store: ObjectContainer, refs: dict[bytes, bytes]
+) -> dict[bytes, bytes]:
     """Serialize refs with peeled refs.
     """Serialize refs with peeled refs.
 
 
     Args:
     Args:
@@ -1556,6 +1646,7 @@ class locked_ref:
         if not self._file:
         if not self._file:
             raise RuntimeError("locked_ref not in context")
             raise RuntimeError("locked_ref not in context")
 
 
+        assert self._realname is not None
         current_ref = self._refs_container.read_loose_ref(self._realname)
         current_ref = self._refs_container.read_loose_ref(self._realname)
         if current_ref is None:
         if current_ref is None:
             current_ref = self._refs_container.get_packed_refs().get(
             current_ref = self._refs_container.get_packed_refs().get(
@@ -1622,3 +1713,14 @@ class locked_ref:
             self._refs_container._remove_packed_ref(self._realname)
             self._refs_container._remove_packed_ref(self._realname)
 
 
         self._deleted = True
         self._deleted = True
+
+
+def filter_ref_prefix(refs: T, prefixes: Iterable[bytes]) -> T:
+    """Filter refs to only include those with a given prefix.
+
+    Args:
+      refs: A dictionary of refs.
+      prefixes: The prefixes to filter by.
+    """
+    filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
+    return cast(T, filtered)
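
The constrained `T = TypeVar("T", dict[bytes, bytes], dict[bytes, Optional[bytes]])` is what lets strip_peeled_refs and filter_ref_prefix hand back exactly the dict flavour they were given. A standalone sketch of the mechanism (drop_suffix is a hypothetical stand-in, not dulwich code):

from typing import Optional, TypeVar

T = TypeVar("T", dict[bytes, bytes], dict[bytes, Optional[bytes]])

def drop_suffix(refs: T, suffix: bytes) -> T:
    # mypy checks the body once per constraint, so a dict[bytes, bytes]
    # argument comes back as dict[bytes, bytes], not a widened union.
    return {ref: sha for (ref, sha) in refs.items() if not ref.endswith(suffix)}

plain: dict[bytes, bytes] = {b"refs/tags/v1^{}": b"a" * 40}
kept = drop_suffix(plain, b"^{}")  # inferred as dict[bytes, bytes]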

+ 1 - 1
dulwich/reftable.py

@@ -700,7 +700,7 @@ class ReftableReader:
         # Read magic bytes
         # Read magic bytes
         magic = self.f.read(4)
         magic = self.f.read(4)
         if magic != REFTABLE_MAGIC:
         if magic != REFTABLE_MAGIC:
-            raise ValueError(f"Invalid reftable magic: {magic}")
+            raise ValueError(f"Invalid reftable magic: {magic!r}")
 
 
         # Read version + block size (4 bytes total, big-endian network order)
         # Read version + block size (4 bytes total, big-endian network order)
         # Format: uint8(version) + uint24(block_size)
         # Format: uint8(version) + uint24(block_size)

+ 62 - 38
dulwich/repo.py

@@ -34,7 +34,7 @@ import stat
 import sys
 import sys
 import time
 import time
 import warnings
 import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
 from io import BytesIO
 from io import BytesIO
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
@@ -42,6 +42,7 @@ from typing import (
     BinaryIO,
     BinaryIO,
     Callable,
     Callable,
     Optional,
     Optional,
+    TypeVar,
     Union,
     Union,
 )
 )
 
 
@@ -52,8 +53,11 @@ if TYPE_CHECKING:
     from .attrs import GitAttributes
     from .attrs import GitAttributes
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .index import Index
     from .index import Index
+    from .line_ending import BlobNormalizer
     from .notes import Notes
     from .notes import Notes
-    from .object_store import BaseObjectStore
+    from .object_store import BaseObjectStore, GraphWalker, UnpackedObject
+    from .rebase import RebaseStateManager
+    from .walk import Walker
     from .worktree import WorkTree
     from .worktree import WorkTree
 
 
 from . import replace_me
 from . import replace_me
@@ -116,6 +120,8 @@ from .refs import (
 
 
 CONTROLDIR = ".git"
 CONTROLDIR = ".git"
 OBJECTDIR = "objects"
 OBJECTDIR = "objects"
+
+T = TypeVar("T", bound="ShaFile")
 REFSDIR = "refs"
 REFSDIR = "refs"
 REFSDIR_TAGS = "tags"
 REFSDIR_TAGS = "tags"
 REFSDIR_HEADS = "heads"
 REFSDIR_HEADS = "heads"
@@ -248,11 +254,11 @@ def check_user_identity(identity: bytes) -> None:
     try:
     try:
         fst, snd = identity.split(b" <", 1)
         fst, snd = identity.split(b" <", 1)
     except ValueError as exc:
     except ValueError as exc:
-        raise InvalidUserIdentity(identity) from exc
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace")) from exc
     if b">" not in snd:
     if b">" not in snd:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
     if b"\0" in identity or b"\n" in identity:
     if b"\0" in identity or b"\n" in identity:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
 
 
 
 
 def parse_graftpoints(
 def parse_graftpoints(
@@ -333,7 +339,12 @@ def _set_filesystem_hidden(path: str) -> None:
 
 
 
 
 class ParentsProvider:
 class ParentsProvider:
-    def __init__(self, store: "BaseObjectStore", grafts: dict = {}, shallows: list = []) -> None:
+    def __init__(
+        self,
+        store: "BaseObjectStore",
+        grafts: dict = {},
+        shallows: Iterable[bytes] = [],
+    ) -> None:
         self.store = store
         self.store = store
         self.grafts = grafts
         self.grafts = grafts
         self.shallows = set(shallows)
         self.shallows = set(shallows)
@@ -341,7 +352,9 @@ class ParentsProvider:
         # Get commit graph once at initialization for performance
         # Get commit graph once at initialization for performance
         self.commit_graph = store.get_commit_graph()
         self.commit_graph = store.get_commit_graph()
 
 
-    def get_parents(self, commit_id: bytes, commit: Optional[Any] = None) -> list[bytes]:
+    def get_parents(
+        self, commit_id: bytes, commit: Optional[Commit] = None
+    ) -> list[bytes]:
         try:
         try:
             return self.grafts[commit_id]
             return self.grafts[commit_id]
         except KeyError:
         except KeyError:
@@ -357,7 +370,9 @@ class ParentsProvider:
 
 
         # Fallback to reading the commit object
         # Fallback to reading the commit object
         if commit is None:
         if commit is None:
-            commit = self.store[commit_id]
+            obj = self.store[commit_id]
+            assert isinstance(obj, Commit)
+            commit = obj
         return commit.parents
         return commit.parents
 
 
 
 
@@ -472,7 +487,11 @@ class BaseRepo:
         raise NotImplementedError(self.open_index)
         raise NotImplementedError(self.open_index)
 
 
     def fetch(
     def fetch(
-        self, target: "BaseRepo", determine_wants: Optional[Callable] = None, progress: Optional[Callable] = None, depth: Optional[int] = None
+        self,
+        target: "BaseRepo",
+        determine_wants: Optional[Callable] = None,
+        progress: Optional[Callable] = None,
+        depth: Optional[int] = None,
     ) -> dict:
     ) -> dict:
         """Fetch objects into another repository.
         """Fetch objects into another repository.
 
 
@@ -498,7 +517,7 @@ class BaseRepo:
     def fetch_pack_data(
     def fetch_pack_data(
         self,
         self,
         determine_wants: Callable,
         determine_wants: Callable,
-        graph_walker: Any,
+        graph_walker: "GraphWalker",
         progress: Optional[Callable],
         progress: Optional[Callable],
         *,
         *,
         get_tagged: Optional[Callable] = None,
         get_tagged: Optional[Callable] = None,
@@ -533,7 +552,7 @@ class BaseRepo:
     def find_missing_objects(
     def find_missing_objects(
         self,
         self,
         determine_wants: Callable,
         determine_wants: Callable,
-        graph_walker: Any,
+        graph_walker: "GraphWalker",
         progress: Optional[Callable],
         progress: Optional[Callable],
         *,
         *,
         get_tagged: Optional[Callable] = None,
         get_tagged: Optional[Callable] = None,
@@ -563,16 +582,17 @@ class BaseRepo:
         current_shallow = set(getattr(graph_walker, "shallow", set()))
         current_shallow = set(getattr(graph_walker, "shallow", set()))
 
 
         if depth not in (None, 0):
         if depth not in (None, 0):
+            assert depth is not None
             shallow, not_shallow = find_shallow(self.object_store, wants, depth)
             shallow, not_shallow = find_shallow(self.object_store, wants, depth)
             # Only update if graph_walker has shallow attribute
             # Only update if graph_walker has shallow attribute
             if hasattr(graph_walker, "shallow"):
             if hasattr(graph_walker, "shallow"):
                 graph_walker.shallow.update(shallow - not_shallow)
                 graph_walker.shallow.update(shallow - not_shallow)
                 new_shallow = graph_walker.shallow - current_shallow
                 new_shallow = graph_walker.shallow - current_shallow
-                unshallow = graph_walker.unshallow = not_shallow & current_shallow
+                unshallow = graph_walker.unshallow = not_shallow & current_shallow  # type: ignore[attr-defined]
                 if hasattr(graph_walker, "update_shallow"):
                 if hasattr(graph_walker, "update_shallow"):
                     graph_walker.update_shallow(new_shallow, unshallow)
                     graph_walker.update_shallow(new_shallow, unshallow)
         else:
         else:
-            unshallow = getattr(graph_walker, "unshallow", frozenset())
+            unshallow = getattr(graph_walker, "unshallow", set())
 
 
         if wants == []:
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
             # TODO(dborowitz): find a way to short-circuit that doesn't change
@@ -596,7 +616,7 @@ class BaseRepo:
                 def __len__(self) -> int:
                 def __len__(self) -> int:
                     return 0
                     return 0
 
 
-                def __iter__(self) -> Any:
+                def __iter__(self) -> Iterator[tuple[bytes, Optional[bytes]]]:
                     yield from []
                     yield from []
 
 
             return DummyMissingObjectFinder()  # type: ignore
             return DummyMissingObjectFinder()  # type: ignore
@@ -615,7 +635,7 @@ class BaseRepo:
 
 
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
 
 
-        def get_parents(commit: Any) -> list[bytes]:
+        def get_parents(commit: Commit) -> list[bytes]:
             """Get parents for a commit using the parents provider.
             """Get parents for a commit using the parents provider.
 
 
             Args:
             Args:
@@ -638,11 +658,11 @@ class BaseRepo:
 
 
     def generate_pack_data(
     def generate_pack_data(
         self,
         self,
-        have: list[ObjectID],
-        want: list[ObjectID],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         progress: Optional[Callable[[str], None]] = None,
         progress: Optional[Callable[[str], None]] = None,
         ofs_delta: Optional[bool] = None,
         ofs_delta: Optional[bool] = None,
-    ) -> Any:
+    ) -> tuple[int, Iterator["UnpackedObject"]]:
         """Generate pack data objects for a set of wants/haves.
         """Generate pack data objects for a set of wants/haves.
 
 
         Args:
         Args:
@@ -697,18 +717,18 @@ class BaseRepo:
         # TODO: move this method to WorkTree
         # TODO: move this method to WorkTree
         return self.refs[b"HEAD"]
         return self.refs[b"HEAD"]
 
 
-    def _get_object(self, sha: bytes, cls: Any) -> Any:
+    def _get_object(self, sha: bytes, cls: type[T]) -> T:
         assert len(sha) in (20, 40)
         assert len(sha) in (20, 40)
         ret = self.get_object(sha)
         ret = self.get_object(sha)
         if not isinstance(ret, cls):
         if not isinstance(ret, cls):
             if cls is Commit:
             if cls is Commit:
-                raise NotCommitError(ret)
+                raise NotCommitError(ret.id)
             elif cls is Blob:
             elif cls is Blob:
-                raise NotBlobError(ret)
+                raise NotBlobError(ret.id)
             elif cls is Tree:
             elif cls is Tree:
-                raise NotTreeError(ret)
+                raise NotTreeError(ret.id)
             elif cls is Tag:
             elif cls is Tag:
-                raise NotTagError(ret)
+                raise NotTagError(ret.id)
             else:
             else:
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
         return ret
         return ret
@@ -776,14 +796,14 @@ class BaseRepo:
         """
         """
         raise NotImplementedError(self.set_description)
         raise NotImplementedError(self.set_description)
 
 
-    def get_rebase_state_manager(self) -> Any:
+    def get_rebase_state_manager(self) -> "RebaseStateManager":
         """Get the appropriate rebase state manager for this repository.
         """Get the appropriate rebase state manager for this repository.
 
 
         Returns: RebaseStateManager instance
         Returns: RebaseStateManager instance
         """
         """
         raise NotImplementedError(self.get_rebase_state_manager)
         raise NotImplementedError(self.get_rebase_state_manager)
 
 
-    def get_blob_normalizer(self) -> Any:
+    def get_blob_normalizer(self) -> "BlobNormalizer":
         """Return a BlobNormalizer object for checkin/checkout operations.
         """Return a BlobNormalizer object for checkin/checkout operations.
 
 
         Returns: BlobNormalizer instance
         Returns: BlobNormalizer instance
@@ -831,7 +851,9 @@ class BaseRepo:
         with f:
         with f:
             return {line.strip() for line in f}
             return {line.strip() for line in f}
 
 
-    def update_shallow(self, new_shallow: Any, new_unshallow: Any) -> None:
+    def update_shallow(
+        self, new_shallow: Optional[set[bytes]], new_unshallow: Optional[set[bytes]]
+    ) -> None:
         """Update the list of shallow objects.
         """Update the list of shallow objects.
 
 
         Args:
         Args:
@@ -873,7 +895,7 @@ class BaseRepo:
 
 
         return Notes(self.object_store, self.refs)
         return Notes(self.object_store, self.refs)
 
 
-    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs) -> Any:
+    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs) -> "Walker":
         """Obtain a walker for this repository.
         """Obtain a walker for this repository.
 
 
         Args:
         Args:
@@ -910,7 +932,7 @@ class BaseRepo:
 
 
         return Walker(self.object_store, include, **kwargs)
         return Walker(self.object_store, include, **kwargs)
 
 
-    def __getitem__(self, name: Union[ObjectID, Ref]) -> Any:
+    def __getitem__(self, name: Union[ObjectID, Ref]) -> "ShaFile":
         """Retrieve a Git object by SHA1 or ref.
         """Retrieve a Git object by SHA1 or ref.
 
 
         Args:
         Args:
@@ -1002,7 +1024,7 @@ class BaseRepo:
         for sha in to_remove:
         for sha in to_remove:
             del self._graftpoints[sha]
             del self._graftpoints[sha]
 
 
-    def _read_heads(self, name: str) -> Any:
+    def _read_heads(self, name: str) -> list[bytes]:
         f = self.get_named_file(name)
         f = self.get_named_file(name)
         if f is None:
         if f is None:
             return []
             return []
@@ -1028,17 +1050,17 @@ class BaseRepo:
         message: Optional[bytes] = None,
         message: Optional[bytes] = None,
         committer: Optional[bytes] = None,
         committer: Optional[bytes] = None,
         author: Optional[bytes] = None,
         author: Optional[bytes] = None,
-        commit_timestamp: Optional[Any] = None,
-        commit_timezone: Optional[Any] = None,
-        author_timestamp: Optional[Any] = None,
-        author_timezone: Optional[Any] = None,
+        commit_timestamp: Optional[float] = None,
+        commit_timezone: Optional[int] = None,
+        author_timestamp: Optional[float] = None,
+        author_timezone: Optional[int] = None,
         tree: Optional[ObjectID] = None,
         tree: Optional[ObjectID] = None,
         encoding: Optional[bytes] = None,
         encoding: Optional[bytes] = None,
         ref: Optional[Ref] = b"HEAD",
         ref: Optional[Ref] = b"HEAD",
         merge_heads: Optional[list[ObjectID]] = None,
         merge_heads: Optional[list[ObjectID]] = None,
         no_verify: bool = False,
         no_verify: bool = False,
         sign: bool = False,
         sign: bool = False,
-    ) -> Any:
+    ) -> bytes:
         """Create a new commit.
         """Create a new commit.
 
 
         If not specified, committer and author default to
         If not specified, committer and author default to
@@ -1097,9 +1119,9 @@ def read_gitfile(f: BinaryIO) -> str:
     Returns: A path
     Returns: A path
     """
     """
     cs = f.read()
     cs = f.read()
-    if not cs.startswith("gitdir: "):
+    if not cs.startswith(b"gitdir: "):
         raise ValueError("Expected file to start with 'gitdir: '")
         raise ValueError("Expected file to start with 'gitdir: '")
-    return cs[len("gitdir: ") :].rstrip("\n")
+    return cs[len(b"gitdir: ") :].rstrip(b"\n").decode("utf-8")
 
 
 
 
 class UnsupportedVersion(Exception):
 class UnsupportedVersion(Exception):
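
read_gitfile now works on a binary file, so the parse stays in bytes end to end; mixing str and bytes in startswith/rstrip is both a mypy error and a runtime TypeError. The idiom in isolation, with an inline value standing in for the file contents:

cs = b"gitdir: ../real.git\n"
# Everything stays bytes; decoding happens once at the boundary.
if not cs.startswith(b"gitdir: "):
    raise ValueError("Expected file to start with 'gitdir: '")
print(cs[len(b"gitdir: "):].rstrip(b"\n").decode("utf-8"))  # ../real.git
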
@@ -1183,7 +1205,7 @@ class Repo(BaseRepo):
         self.bare = bare
         self.bare = bare
         if bare is False:
         if bare is False:
             if os.path.isfile(hidden_path):
             if os.path.isfile(hidden_path):
-                with open(hidden_path) as f:
+                with open(hidden_path, "rb") as f:
                     path = read_gitfile(f)
                     path = read_gitfile(f)
                 self._controldir = os.path.join(root, path)
                 self._controldir = os.path.join(root, path)
             else:
             else:
@@ -2018,6 +2040,7 @@ class Repo(BaseRepo):
                 if isinstance(head, Tag):
                 if isinstance(head, Tag):
                     _cls, obj = head.object
                     _cls, obj = head.object
                     head = self.get_object(obj)
                     head = self.get_object(obj)
+                assert isinstance(head, Commit)
                 tree = head.tree
                 tree = head.tree
             except KeyError:
             except KeyError:
                 # No HEAD, no attributes from tree
                 # No HEAD, no attributes from tree
@@ -2026,6 +2049,7 @@ class Repo(BaseRepo):
         if tree is not None:
         if tree is not None:
             try:
             try:
                 tree_obj = self[tree]
                 tree_obj = self[tree]
+                assert isinstance(tree_obj, Tree)
                 if b".gitattributes" in tree_obj:
                 if b".gitattributes" in tree_obj:
                     _, attrs_sha = tree_obj[b".gitattributes"]
                     _, attrs_sha = tree_obj[b".gitattributes"]
                     attrs_blob = self[attrs_sha]
                     attrs_blob = self[attrs_sha]
@@ -2114,7 +2138,7 @@ class MemoryRepo(BaseRepo):
 
 
         self._reflog: list[Any] = []
         self._reflog: list[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
-        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore[arg-type]
         self._named_files: dict[str, bytes] = {}
         self._named_files: dict[str, bytes] = {}
         self.bare = True
         self.bare = True
         self._config = ConfigFile()
         self._config = ConfigFile()
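
Much of the repo.py churn leans on `T = TypeVar("T", bound="ShaFile")`, so that `_get_object(sha, Commit)` types as Commit rather than a bare ShaFile. A minimal sketch with stand-in classes (not dulwich's real ones):

from typing import TypeVar

class ShaFile: ...
class Commit(ShaFile): ...

T = TypeVar("T", bound=ShaFile)

def get_typed(obj: ShaFile, cls: type[T]) -> T:
    # isinstance() both validates at runtime and narrows for mypy,
    # so the caller gets back the concrete subtype it asked for.
    if not isinstance(obj, cls):
        raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
    return obj

commit = get_typed(Commit(), Commit)  # inferred type: Commit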

+ 15 - 3
dulwich/server.py

@@ -52,7 +52,7 @@ import time
 import zlib
 import zlib
 from collections.abc import Iterable, Iterator
 from collections.abc import Iterable, Iterator
 from functools import partial
 from functools import partial
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional
 from typing import Protocol as TypingProtocol
 from typing import Protocol as TypingProtocol
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -551,10 +551,12 @@ def _split_proto_line(line, allowed):
             COMMAND_SHALLOW,
             COMMAND_SHALLOW,
             COMMAND_UNSHALLOW,
             COMMAND_UNSHALLOW,
         ):
         ):
+            assert fields[1] is not None
             if not valid_hexsha(fields[1]):
             if not valid_hexsha(fields[1]):
                 raise GitProtocolError("Invalid sha")
                 raise GitProtocolError("Invalid sha")
             return tuple(fields)
             return tuple(fields)
         elif command == COMMAND_DEEPEN:
         elif command == COMMAND_DEEPEN:
+            assert fields[1] is not None
             return command, int(fields[1])
             return command, int(fields[1])
     raise GitProtocolError(f"Received invalid line from client: {line!r}")
     raise GitProtocolError(f"Received invalid line from client: {line!r}")
 
 
@@ -633,6 +635,9 @@ class AckGraphWalkerImpl:
         """
         """
         raise NotImplementedError
         raise NotImplementedError
 
 
+    def handle_done(self, done_required, done_received):
+        raise NotImplementedError
+
 
 
 class _ProtocolGraphWalker:
 class _ProtocolGraphWalker:
     """A graph walker that knows the git protocol.
     """A graph walker that knows the git protocol.
@@ -784,6 +789,7 @@ class _ProtocolGraphWalker:
         """
         """
         if len(have_ref) != 40:
         if len(have_ref) != 40:
             raise ValueError(f"invalid sha {have_ref!r}")
             raise ValueError(f"invalid sha {have_ref!r}")
+        assert self._impl is not None
         return self._impl.ack(have_ref)
         return self._impl.ack(have_ref)
 
 
     def reset(self) -> None:
     def reset(self) -> None:
@@ -799,7 +805,8 @@ class _ProtocolGraphWalker:
         if not self._cached:
         if not self._cached:
             if not self._impl and self.stateless_rpc:
             if not self._impl and self.stateless_rpc:
                 return None
                 return None
-            return next(self._impl)
+            assert self._impl is not None
+            return next(self._impl)  # type: ignore[call-overload]
         self._cache_index += 1
         self._cache_index += 1
         if self._cache_index > len(self._cache):
         if self._cache_index > len(self._cache):
             return None
             return None
@@ -884,6 +891,7 @@ class _ProtocolGraphWalker:
         Returns: True if done handling succeeded
         Returns: True if done handling succeeded
         """
         """
         # Delegate this to the implementation.
         # Delegate this to the implementation.
+        assert self._impl is not None
         return self._impl.handle_done(done_required, done_received)
         return self._impl.handle_done(done_required, done_received)
 
 
     def set_wants(self, wants) -> None:
     def set_wants(self, wants) -> None:
@@ -1400,7 +1408,11 @@ class UploadArchiveHandler(Handler):
                 format = arguments[i].decode("ascii")
                 format = arguments[i].decode("ascii")
             else:
             else:
                 commit_sha = self.repo.refs[argument]
                 commit_sha = self.repo.refs[argument]
-                tree = cast(Tree, store[cast(Commit, store[commit_sha]).tree])
+                commit_obj = store[commit_sha]
+                assert isinstance(commit_obj, Commit)
+                tree_obj = store[commit_obj.tree]
+                assert isinstance(tree_obj, Tree)
+                tree = tree_obj
             i += 1
             i += 1
         self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
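
Several of these fixes insert `assert self._impl is not None` before delegating. The assert encodes a protocol invariant (the implementation is always set by the time these methods run) that mypy cannot infer across method boundaries. Distilled, with hypothetical class names:

from typing import Optional

class Impl:
    def ack(self, sha: bytes) -> bytes:
        return sha

class WalkerShim:
    def __init__(self) -> None:
        self._impl: Optional[Impl] = None

    def ack(self, sha: bytes) -> bytes:
        # The protocol guarantees _impl is set before ack() is called,
        # but mypy cannot see that; the assert narrows Optional[Impl].
        assert self._impl is not None
        return self._impl.ack(sha)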

+ 11 - 5
dulwich/stash.py

@@ -23,7 +23,7 @@
 
 
 import os
 import os
 import sys
 import sys
-from typing import TYPE_CHECKING, Optional, TypedDict
+from typing import TYPE_CHECKING, Optional, TypedDict, Union
 
 
 from .diff_tree import tree_changes
 from .diff_tree import tree_changes
 from .file import GitFile
 from .file import GitFile
@@ -162,10 +162,16 @@ class Stash:
             symlink_fn = symlink
             symlink_fn = symlink
         else:
         else:
 
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                mode = "w" + ("b" if isinstance(source, bytes) else "")
-                with open(target, mode) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: Optional[int] = None,
+            ) -> None:
+                mode = "w" + ("b" if isinstance(src, bytes) else "")
+                with open(dst, mode) as f:
+                    f.write(src)
 
 
         # Get blob normalizer for line ending conversion
         # Get blob normalizer for line ending conversion
         blob_normalizer = self._repo.get_blob_normalizer()
         blob_normalizer = self._repo.get_blob_normalizer()
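
The old two-argument fallback needed `# type: ignore` because it was used where os.symlink's callable type was expected; matching os.symlink's parameter list removes the mismatch. A sketch of the pattern (the stub, as in the stash code, just writes the link target to a file):

import os
from typing import Optional, Union

def symlink_stub(
    src: Union[str, bytes, os.PathLike],
    dst: Union[str, bytes, os.PathLike],
    target_is_directory: bool = False,
    *,
    dir_fd: Optional[int] = None,
) -> None:
    # Same parameters as os.symlink, so both branches below agree in type.
    with open(dst, "w" + ("b" if isinstance(src, bytes) else "")) as f:
        f.write(src)

symlink_fn = os.symlink if hasattr(os, "symlink") else symlink_stub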

+ 1 - 1
dulwich/tests/test_object_store.py

@@ -444,7 +444,7 @@ class FindShallowTests(TestCase):
     def make_linear_commits(self, n, message=b""):
     def make_linear_commits(self, n, message=b""):
         """Create a linear chain of commits."""
         """Create a linear chain of commits."""
         commits = []
         commits = []
-        parents = []
+        parents: list[bytes] = []
         for _ in range(n):
         for _ in range(n):
             commits.append(self.make_commit(parents=parents, message=message))
             commits.append(self.make_commit(parents=parents, message=message))
             parents = [commits[-1].id]
             parents = [commits[-1].id]
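
The one-line test fix is mypy's classic complaint about empty literals: a bare `[]` carries no element type to infer, hence "Need type annotation for 'parents'". For example:

parents: list[bytes] = []  # without the annotation mypy has no element type
parents.append(b"a" * 40)  # fine: bytes matches the declared element type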

+ 19 - 4
dulwich/tests/utils.py

@@ -165,7 +165,9 @@ def make_tag(target: ShaFile, **attrs: Any) -> Tag:
     return make_object(Tag, **all_attrs)
     return make_object(Tag, **all_attrs)
 
 
 
 
-def functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[[Any], None]:
+def functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given function."""
     """Generate a test method that tests the given function."""
 
 
     def do_test(self: Any) -> None:
     def do_test(self: Any) -> None:
@@ -174,7 +176,9 @@ def functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[
     return do_test
     return do_test
 
 
 
 
-def ext_functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[[Any], None]:
+def ext_functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given extension function.
     """Generate a test method that tests the given extension function.
 
 
     This is intended to generate test methods that test both a pure-Python
     This is intended to generate test methods that test both a pure-Python
@@ -204,7 +208,11 @@ def ext_functest_builder(method: Callable[[Any, Any], None], func: Any) -> Calla
     return do_test
     return do_test
 
 
 
 
-def build_pack(f: BinaryIO, objects_spec: list[tuple[int, Any]], store: Optional[BaseObjectStore] = None) -> list[tuple[int, int, bytes, bytes, int]]:
+def build_pack(
+    f: BinaryIO,
+    objects_spec: list[tuple[int, Any]],
+    store: Optional[BaseObjectStore] = None,
+) -> list[tuple[int, int, bytes, bytes, int]]:
     """Write test pack data from a concise spec.
     """Write test pack data from a concise spec.
 
 
     Args:
     Args:
@@ -282,7 +290,14 @@ def build_pack(f: BinaryIO, objects_spec: list[tuple[int, Any]], store: Optional
     return expected
     return expected
 
 
 
 
-def build_commit_graph(object_store: BaseObjectStore, commit_spec: list[list[int]], trees: Optional[dict[int, list[Union[tuple[bytes, ShaFile], tuple[bytes, ShaFile, int]]]]] = None, attrs: Optional[dict[int, dict[str, Any]]] = None) -> list[Commit]:
+def build_commit_graph(
+    object_store: BaseObjectStore,
+    commit_spec: list[list[int]],
+    trees: Optional[
+        dict[int, list[Union[tuple[bytes, ShaFile], tuple[bytes, ShaFile, int]]]]
+    ] = None,
+    attrs: Optional[dict[int, dict[str, Any]]] = None,
+) -> list[Commit]:
     """Build a commit graph from a concise specification.
     """Build a commit graph from a concise specification.
 
 
     Sample usage:
     Sample usage:

+ 12 - 5
dulwich/walk.py

@@ -87,7 +87,9 @@ class WalkEntry:
                 parent = None
                 parent = None
             elif len(self._get_parents(commit)) == 1:
             elif len(self._get_parents(commit)) == 1:
                 changes_func = tree_changes
                 changes_func = tree_changes
-                parent = cast(Commit, self._store[self._get_parents(commit)[0]]).tree
+                parent_commit = self._store[self._get_parents(commit)[0]]
+                assert isinstance(parent_commit, Commit)
+                parent = parent_commit.tree
                 if path_prefix:
                 if path_prefix:
                     mode, subtree_sha = parent.lookup_path(
                     mode, subtree_sha = parent.lookup_path(
                         self._store.__getitem__,
                         self._store.__getitem__,
@@ -96,9 +98,12 @@ class WalkEntry:
                     parent = self._store[subtree_sha]
                     parent = self._store[subtree_sha]
             else:
             else:
                 # For merge commits, we need to handle multiple parents differently
                 # For merge commits, we need to handle multiple parents differently
-                parent = [
-                    cast(Commit, self._store[p]).tree for p in self._get_parents(commit)
-                ]
+                parent_trees = []
+                for p in self._get_parents(commit):
+                    parent_commit = self._store[p]
+                    assert isinstance(parent_commit, Commit)
+                    parent_trees.append(parent_commit.tree)
+                parent = parent_trees
                 # Use a lambda to adapt the signature
                 # Use a lambda to adapt the signature
                 changes_func = cast(
                 changes_func = cast(
                     Any,
                     Any,
@@ -200,7 +205,9 @@ class _CommitTimeQueue:
                     # some caching (which DiskObjectStore currently does not).
                     # some caching (which DiskObjectStore currently does not).
                     # We could either add caching in this class or pass around
                     # We could either add caching in this class or pass around
                     # parsed queue entry objects instead of commits.
                     # parsed queue entry objects instead of commits.
-                    todo.append(cast(Commit, self._store[parent]))
+                    parent_commit = self._store[parent]
+                    assert isinstance(parent_commit, Commit)
+                    todo.append(parent_commit)
                 excluded.add(parent)
                 excluded.add(parent)
 
 
     def next(self) -> Optional[WalkEntry]:
     def next(self) -> Optional[WalkEntry]:
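
walk.py trades `cast(Commit, ...)` for isinstance asserts. cast() is a purely static claim with no runtime effect, while the assert fails loudly if a non-commit ever shows up. Side by side, with stand-in classes rather than dulwich's:

from typing import cast

class ShaFile: ...
class Commit(ShaFile):
    tree: bytes = b"t" * 40

def tree_unchecked(obj: ShaFile) -> bytes:
    return cast(Commit, obj).tree  # no runtime check; mypy just trusts it

def tree_checked(obj: ShaFile) -> bytes:
    assert isinstance(obj, Commit)  # checked at runtime, narrows for mypy
    return obj.tree

print(tree_checked(Commit()))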

+ 86 - 23
dulwich/web.py

@@ -26,9 +26,10 @@ import os
 import re
 import re
 import sys
 import sys
 import time
 import time
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from io import BytesIO
 from io import BytesIO
-from typing import Any, BinaryIO, Callable, ClassVar, Optional, cast
+from types import TracebackType
+from typing import Any, BinaryIO, Callable, ClassVar, Optional, Union, cast
 from urllib.parse import parse_qs
 from urllib.parse import parse_qs
 from wsgiref.simple_server import (
 from wsgiref.simple_server import (
     ServerHandler,
     ServerHandler,
@@ -36,6 +37,7 @@ from wsgiref.simple_server import (
     WSGIServer,
     WSGIServer,
     make_server,
     make_server,
 )
 )
+from wsgiref.types import StartResponse, WSGIApplication, WSGIEnvironment
 
 
 from dulwich import log_utils
 from dulwich import log_utils
 
 
@@ -45,6 +47,7 @@ from .server import (
     DEFAULT_HANDLERS,
     DEFAULT_HANDLERS,
     Backend,
     Backend,
     DictBackend,
     DictBackend,
+    Handler,
     generate_info_refs,
     generate_info_refs,
     generate_objects_info_packs,
     generate_objects_info_packs,
 )
 )
@@ -292,13 +295,21 @@ def get_info_refs(
         yield req.not_found(str(e))
         yield req.not_found(str(e))
         return
         return
     if service and not req.dumb:
     if service and not req.dumb:
+        if req.handlers is None:
+            yield req.forbidden("No handlers configured")
+            return
         handler_cls = req.handlers.get(service.encode("ascii"), None)
         handler_cls = req.handlers.get(service.encode("ascii"), None)
         if handler_cls is None:
         if handler_cls is None:
             yield req.forbidden("Unsupported service")
             yield req.forbidden("Unsupported service")
             return
             return
         req.nocache()
         req.nocache()
         write = req.respond(HTTP_OK, f"application/x-{service}-advertisement")
         write = req.respond(HTTP_OK, f"application/x-{service}-advertisement")
-        proto = ReceivableProtocol(BytesIO().read, write)
+
+        def write_fn(data: bytes) -> Optional[int]:
+            result = write(data)
+            return len(data) if result is not None else None
+
+        proto = ReceivableProtocol(BytesIO().read, write_fn)
         handler = handler_cls(
         handler = handler_cls(
             backend,
             backend,
             [url_prefix(mat)],
             [url_prefix(mat)],
@@ -425,6 +436,9 @@ def handle_service_request(
     """
     """
     service = mat.group().lstrip("/")
     service = mat.group().lstrip("/")
     logger.info("Handling service request for %s", service)
     logger.info("Handling service request for %s", service)
+    if req.handlers is None:
+        yield req.forbidden("No handlers configured")
+        return
     handler_cls = req.handlers.get(service.encode("ascii"), None)
     handler_cls = req.handlers.get(service.encode("ascii"), None)
     if handler_cls is None:
     if handler_cls is None:
         yield req.forbidden("Unsupported service")
         yield req.forbidden("Unsupported service")
@@ -436,11 +450,16 @@ def handle_service_request(
         return
         return
     req.nocache()
     req.nocache()
     write = req.respond(HTTP_OK, f"application/x-{service}-result")
     write = req.respond(HTTP_OK, f"application/x-{service}-result")
+
+    def write_fn(data: bytes) -> Optional[int]:
+        result = write(data)
+        return len(data) if result is not None else None
+
     if req.environ.get("HTTP_TRANSFER_ENCODING") == "chunked":
     if req.environ.get("HTTP_TRANSFER_ENCODING") == "chunked":
         read = ChunkReader(req.environ["wsgi.input"]).read
         read = ChunkReader(req.environ["wsgi.input"]).read
     else:
     else:
         read = req.environ["wsgi.input"].read
         read = req.environ["wsgi.input"].read
-    proto = ReceivableProtocol(read, write)
+    proto = ReceivableProtocol(read, write_fn)
     # TODO(jelmer): Find a way to pass in repo, rather than having handler_cls
     # TODO(jelmer): Find a way to pass in repo, rather than having handler_cls
     # reopen.
     # reopen.
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
@@ -455,7 +474,11 @@ class HTTPGitRequest:
     """
     """
 
 
     def __init__(
     def __init__(
-        self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any], dumb: bool = False, handlers: Optional[dict[bytes, Callable]] = None
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
     ) -> None:
     ) -> None:
         """Initialize HTTPGitRequest.
         """Initialize HTTPGitRequest.
 
 
@@ -481,7 +504,7 @@ class HTTPGitRequest:
         status: str = HTTP_OK,
         status: str = HTTP_OK,
         content_type: Optional[str] = None,
         content_type: Optional[str] = None,
         headers: Optional[list[tuple[str, str]]] = None,
         headers: Optional[list[tuple[str, str]]] = None,
-    ) -> Any:
+    ) -> Callable[[bytes], object]:
         """Begin a response with the given status and other headers."""
         """Begin a response with the given status and other headers."""
         if headers:
         if headers:
             self._headers.extend(headers)
             self._headers.extend(headers)
@@ -556,7 +579,11 @@ class HTTPGitApplication:
     }
     }
 
 
     def __init__(
     def __init__(
-        self, backend: Backend, dumb: bool = False, handlers: Optional[dict[bytes, Callable]] = None, fallback_app: Optional[Callable[[dict[str, Any], Callable], Any]] = None
+        self,
+        backend: Backend,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
+        fallback_app: Optional[WSGIApplication] = None,
     ) -> None:
     ) -> None:
         """Initialize HTTPGitApplication.
         """Initialize HTTPGitApplication.
 
 
@@ -568,12 +595,18 @@ class HTTPGitApplication:
         """
         """
         self.backend = backend
         self.backend = backend
         self.dumb = dumb
         self.dumb = dumb
-        self.handlers = dict(DEFAULT_HANDLERS)
+        self.handlers: dict[bytes, Union[type[Handler], Callable[..., Any]]] = dict(
+            DEFAULT_HANDLERS
+        )
         self.fallback_app = fallback_app
         self.fallback_app = fallback_app
         if handlers is not None:
         if handlers is not None:
             self.handlers.update(handlers)
             self.handlers.update(handlers)
 
 
-    def __call__(self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any]) -> list[bytes]:
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         path = environ["PATH_INFO"]
         path = environ["PATH_INFO"]
         method = environ["REQUEST_METHOD"]
         method = environ["REQUEST_METHOD"]
         req = HTTPGitRequest(
         req = HTTPGitRequest(
@@ -581,6 +614,7 @@ class HTTPGitApplication:
         )
         )
         # environ['QUERY_STRING'] has qs args
         # environ['QUERY_STRING'] has qs args
         handler = None
         handler = None
+        mat = None
         for smethod, spath in self.services.keys():
         for smethod, spath in self.services.keys():
             if smethod != method:
             if smethod != method:
                 continue
                 continue
@@ -589,7 +623,7 @@ class HTTPGitApplication:
                 handler = self.services[smethod, spath]
                 handler = self.services[smethod, spath]
                 break
                 break
 
 
-        if handler is None:
+        if handler is None or mat is None:
             if self.fallback_app is not None:
             if self.fallback_app is not None:
                 return self.fallback_app(environ, start_response)
                 return self.fallback_app(environ, start_response)
             else:
             else:
@@ -601,10 +635,14 @@ class HTTPGitApplication:
 class GunzipFilter:
 class GunzipFilter:
     """WSGI middleware that unzips gzip-encoded requests before passing on to the underlying application."""
     """WSGI middleware that unzips gzip-encoded requests before passing on to the underlying application."""
 
 
-    def __init__(self, application: Any) -> None:
+    def __init__(self, application: WSGIApplication) -> None:
         self.app = application
         self.app = application
 
 
-    def __call__(self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any]) -> Any:
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         import gzip
         import gzip
 
 
         if environ.get("HTTP_CONTENT_ENCODING", "") == "gzip":
         if environ.get("HTTP_CONTENT_ENCODING", "") == "gzip":
@@ -620,10 +658,14 @@ class GunzipFilter:
 class LimitedInputFilter:
 class LimitedInputFilter:
     """WSGI middleware that limits the input length of a request to that specified in Content-Length."""
     """WSGI middleware that limits the input length of a request to that specified in Content-Length."""
 
 
-    def __init__(self, application: Any) -> None:
+    def __init__(self, application: WSGIApplication) -> None:
         self.app = application
         self.app = application
 
 
-    def __call__(self, environ: dict[str, Any], start_response: Callable[[str, list[tuple[str, str]]], Any]) -> Any:
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         # This is not necessary if this app is run from a conforming WSGI
         # This is not necessary if this app is run from a conforming WSGI
         # server. Unfortunately, there's no way to tell that at this point.
         # server. Unfortunately, there's no way to tell that at this point.
         # TODO: git may used HTTP/1.1 chunked encoding instead of specifying
         # TODO: git may used HTTP/1.1 chunked encoding instead of specifying
@@ -636,11 +678,18 @@ class LimitedInputFilter:
         return self.app(environ, start_response)
         return self.app(environ, start_response)
 
 
 
 
-def make_wsgi_chain(*args: Any, **kwargs: Any) -> Any:
+def make_wsgi_chain(
+    backend: Backend,
+    dumb: bool = False,
+    handlers: Optional[dict[bytes, Callable[..., Any]]] = None,
+    fallback_app: Optional[WSGIApplication] = None,
+) -> WSGIApplication:
     """Factory function to create an instance of HTTPGitApplication,
     """Factory function to create an instance of HTTPGitApplication,
     correctly wrapped with needed middleware.
     correctly wrapped with needed middleware.
     """
     """
-    app = HTTPGitApplication(*args, **kwargs)
+    app = HTTPGitApplication(
+        backend, dumb=dumb, handlers=handlers, fallback_app=fallback_app
+    )
     wrapped_app = LimitedInputFilter(GunzipFilter(app))
     wrapped_app = LimitedInputFilter(GunzipFilter(app))
     return wrapped_app
     return wrapped_app
 
 
@@ -648,32 +697,46 @@ def make_wsgi_chain(*args: Any, **kwargs: Any) -> Any:
 class ServerHandlerLogger(ServerHandler):
 class ServerHandlerLogger(ServerHandler):
     """ServerHandler that uses dulwich's logger for logging exceptions."""
     """ServerHandler that uses dulwich's logger for logging exceptions."""
 
 
-    def log_exception(self, exc_info: Any) -> None:
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
         logger.exception(
         logger.exception(
             "Exception happened during processing of request",
             "Exception happened during processing of request",
             exc_info=exc_info,
             exc_info=exc_info,
         )
         )
 
 
-    def log_message(self, format: str, *args: Any) -> None:
+    def log_message(self, format: str, *args: object) -> None:
         logger.info(format, *args)
         logger.info(format, *args)
 
 
-    def log_error(self, *args: Any) -> None:
+    def log_error(self, *args: object) -> None:
         logger.error(*args)
         logger.error(*args)
 
 
 
 
 class WSGIRequestHandlerLogger(WSGIRequestHandler):
 class WSGIRequestHandlerLogger(WSGIRequestHandler):
     """WSGIRequestHandler that uses dulwich's logger for logging exceptions."""
     """WSGIRequestHandler that uses dulwich's logger for logging exceptions."""
 
 
-    def log_exception(self, exc_info: Any) -> None:
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
         logger.exception(
         logger.exception(
             "Exception happened during processing of request",
             "Exception happened during processing of request",
             exc_info=exc_info,
             exc_info=exc_info,
         )
         )
 
 
-    def log_message(self, format: str, *args: Any) -> None:
+    def log_message(self, format: str, *args: object) -> None:
         logger.info(format, *args)
         logger.info(format, *args)
 
 
-    def log_error(self, *args: Any) -> None:
+    def log_error(self, *args: object) -> None:
         logger.error(*args)
         logger.error(*args)
 
 
     def handle(self) -> None:
     def handle(self) -> None:
@@ -695,7 +758,7 @@ class WSGIRequestHandlerLogger(WSGIRequestHandler):
 class WSGIServerLogger(WSGIServer):
 class WSGIServerLogger(WSGIServer):
     """WSGIServer that uses dulwich's logger for error handling."""
     """WSGIServer that uses dulwich's logger for error handling."""
 
 
-    def handle_error(self, request: Any, client_address: Any) -> None:
+    def handle_error(self, request: object, client_address: tuple[str, int]) -> None:
         """Handle an error."""
         """Handle an error."""
         logger.exception(
         logger.exception(
             f"Exception happened during processing of request from {client_address!s}"
             f"Exception happened during processing of request from {client_address!s}"

+ 40 - 15
dulwich/worktree.py

@@ -34,9 +34,10 @@ import warnings
 from collections.abc import Iterable, Iterator
 from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from contextlib import contextmanager
 from pathlib import Path
 from pathlib import Path
+from typing import Any, Callable, Union
 
 
 from .errors import CommitError, HookError
 from .errors import CommitError, HookError
-from .objects import Commit, ObjectID, Tag, Tree
+from .objects import Blob, Commit, ObjectID, Tag, Tree
 from .refs import SYMREF, Ref
 from .refs import SYMREF, Ref
 from .repo import (
 from .repo import (
     GITDIR,
     GITDIR,
@@ -335,7 +336,7 @@ class WorkTree:
 
 
         index = self._repo.open_index()
         index = self._repo.open_index()
         try:
         try:
-            tree_id = self._repo[b"HEAD"].tree
+            commit = self._repo[b"HEAD"]
         except KeyError:
         except KeyError:
             # no head mean no commit in the repo
             # no head mean no commit in the repo
             for fs_path in fs_paths:
             for fs_path in fs_paths:
@@ -343,6 +344,9 @@ class WorkTree:
                 del index[tree_path]
                 del index[tree_path]
             index.write()
             index.write()
             return
             return
+        else:
+            assert isinstance(commit, Commit), "HEAD must be a commit"
+            tree_id = commit.tree
 
 
         for fs_path in fs_paths:
         for fs_path in fs_paths:
             tree_path = _fs_to_tree_path(fs_path)
             tree_path = _fs_to_tree_path(fs_path)
@@ -367,15 +371,19 @@ class WorkTree:
             except FileNotFoundError:
             except FileNotFoundError:
                 pass
                 pass
 
 
+            blob_obj = self._repo[tree_entry[1]]
+            assert isinstance(blob_obj, Blob)
+            blob_size = len(blob_obj.data)
+
             index_entry = IndexEntry(
             index_entry = IndexEntry(
-                ctime=(self._repo[b"HEAD"].commit_time, 0),
-                mtime=(self._repo[b"HEAD"].commit_time, 0),
+                ctime=(commit.commit_time, 0),
+                mtime=(commit.commit_time, 0),
                 dev=st.st_dev if st else 0,
                 dev=st.st_dev if st else 0,
                 ino=st.st_ino if st else 0,
                 ino=st.st_ino if st else 0,
                 mode=tree_entry[0],
                 mode=tree_entry[0],
                 uid=st.st_uid if st else 0,
                 uid=st.st_uid if st else 0,
                 gid=st.st_gid if st else 0,
                 gid=st.st_gid if st else 0,
-                size=len(self._repo[tree_entry[1]].data),
+                size=blob_size,
                 sha=tree_entry[1],
                 sha=tree_entry[1],
                 flags=0,
                 flags=0,
                 extended_flags=0,
                 extended_flags=0,
@@ -386,7 +394,7 @@ class WorkTree:
 
 
     def commit(
     def commit(
         self,
         self,
-        message: bytes | None = None,
+        message: Union[str, bytes, Callable[[Any, Commit], bytes], None] = None,
         committer: bytes | None = None,
         committer: bytes | None = None,
         author: bytes | None = None,
         author: bytes | None = None,
         commit_timestamp: float | None = None,
         commit_timestamp: float | None = None,
@@ -541,13 +549,18 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.set_if_equals(
                     ref,
                     old_head,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             except KeyError:
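
Reflog entries take integer timestamps, while commit_timestamp is declared float | None; coercing with int() only when a value is present truncates fractional seconds and leaves None (which the refs layer can default) untouched. The guard in isolation:

    commit_timestamp = 1718000000.75  # float | None in the signature
    timestamp = int(commit_timestamp) if commit_timestamp is not None else None
    # timestamp == 1718000000; a None input stays None instead of raising TypeError
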
@@ -555,12 +568,17 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.add_if_new(
                     ref,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             if not ok:
@@ -603,6 +621,9 @@ class WorkTree:
             if isinstance(head, Tag):
                 _cls, obj = head.object
                 head = self._repo.get_object(obj)
+            from .objects import Commit
+
+            assert isinstance(head, Commit)
             tree = head.tree
         config = self._repo.get_config()
         honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
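
HEAD can legitimately point at an annotated tag, which wraps a commit rather than being one, so it is peeled before .tree is read; the added assert then proves to mypy that peeling ended at a Commit. The peel-then-narrow step in isolation (single-level peel, matching the hunk; a chain of nested tags would need a loop):

    from dulwich.objects import Commit, Tag

    head = repo[b"HEAD"]
    if isinstance(head, Tag):
        _cls, obj_id = head.object    # (object class, object sha)
        head = repo.get_object(obj_id)
    assert isinstance(head, Commit)   # narrows ShaFile to Commit for mypy
    tree = head.tree
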
@@ -616,11 +637,15 @@ class WorkTree:
             symlink_fn = symlink
         else:
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                with open(
-                    target, "w" + ("b" if isinstance(source, bytes) else "")
-                ) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: int | None = None,
+            ) -> None:
+                with open(dst, "w" + ("b" if isinstance(src, bytes) else "")) as f:
+                    f.write(src)
 
         blob_normalizer = self._repo.get_blob_normalizer()
         return build_index_from_tree(
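
The fallback symlink_fn (used when real symlinks are unavailable) previously needed a # type: ignore because its signature diverged from os.symlink; matching the full stdlib signature os.symlink(src, dst, target_is_directory=False, *, dir_fd=None), with the extra parameters accepted but unused, lets it type-check as a drop-in substitute. The same logic as a standalone sketch:

    def symlink_fn(src, dst, target_is_directory=False, *, dir_fd=None):
        # Degrade gracefully: record the link target as the file's contents.
        with open(dst, "w" + ("b" if isinstance(src, bytes) else "")) as f:
            f.write(src)

    symlink_fn(b"target.txt", "link.txt")  # writes b"target.txt" into link.txt
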

+ 2 - 1
pyproject.toml

@@ -25,7 +25,7 @@ classifiers = [
 requires-python = ">=3.9"
 dependencies = [
     "urllib3>=1.25",
-    'typing_extensions >=4.0 ; python_version < "3.11"',
+    'typing_extensions >=4.0 ; python_version < "3.12"',
 ]
 dynamic = ["version"]
 license-files = ["COPYING"]
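
Raising the environment marker from python_version < "3.11" to < "3.12" implies the code now relies on a typing feature that only reached the stdlib in 3.12 — typing.override is the most likely candidate, though that is an inference from the marker, not something this hunk shows. The conventional import shim for such a feature:

    import sys

    if sys.version_info >= (3, 12):
        from typing import override  # in the stdlib from 3.12
    else:
        from typing_extensions import override  # backport for 3.9-3.11
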
@@ -99,6 +99,7 @@ ignore = [
     "ANN205",
     "ANN205",
     "ANN206",
     "ANN206",
     "E501",  # line too long
     "E501",  # line too long
+    "UP007",  # Use X | Y for type annotations (Python 3.10+ syntax, but we support 3.9+)
 ]
 ]
 
 
 [tool.ruff.lint.pydocstyle]
 [tool.ruff.lint.pydocstyle]
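
Suppressing UP007 keeps ruff from rewriting Union[X, Y]/Optional[X] annotations into the PEP 604 X | Y form, which raises TypeError when evaluated at runtime on Python 3.9 (the types.UnionType behind | only arrived in 3.10). Both spellings side by side (fetch is a made-up example function):

    from typing import Optional, Union

    def fetch(url: Union[str, bytes], timeout: Optional[float] = None) -> None: ...
    # UP007 would rewrite the line above into the 3.10+ form:
    # def fetch(url: str | bytes, timeout: float | None = None) -> None: ...
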

+ 0 - 1
tests/contrib/__init__.py

@@ -23,7 +23,6 @@ import unittest
 
 
 def test_suite() -> unittest.TestSuite:
-
     names = [
         "diffstat",
         "paramiko_vendor",

+ 20 - 4
tests/test_annotate.py

@@ -110,7 +110,9 @@ class UpdateLinesTestCase(TestCase):
 
     def test_update_lines_empty_new(self) -> None:
         """Test update_lines with empty new blob."""
-        old_lines: list[tuple[tuple[Any, Any], bytes]] = [(("commit1", "entry1"), b"line1")]
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
+            (("commit1", "entry1"), b"line1")
+        ]
         new_blob = b""
         new_history_data = ("commit2", "entry2")
 
@@ -131,7 +133,9 @@ class AnnotateLinesTestCase(TestCase):
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit(self, blob_content: bytes, message: str, parent: Optional[bytes] = None) -> bytes:
+    def _make_commit(
+        self, blob_content: bytes, message: str, parent: Optional[bytes] = None
+    ) -> bytes:
         """Helper to create a commit with a single file."""
         # Create blob
         blob = Blob()
@@ -222,7 +226,13 @@ class PorcelainAnnotateTestCase(TestCase):
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit_with_file(self, filename: str, content: bytes, message: str, parent: Optional[bytes] = None) -> bytes:
+    def _make_commit_with_file(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
         """Helper to create a commit with a file."""
         # Create blob
         blob = Blob()
         blob = Blob()
@@ -315,7 +325,13 @@ class IntegrationTestCase(TestCase):
 
         shutil.rmtree(self.temp_dir)
 
-    def _create_file_commit(self, filename: str, content: bytes, message: str, parent: Optional[bytes] = None) -> bytes:
+    def _create_file_commit(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
         """Helper to create a commit with file content."""
         # Write file to working directory
         filepath = os.path.join(self.temp_dir, filename)

+ 3 - 1
tests/test_archive.py

@@ -46,7 +46,9 @@ class ArchiveTests(TestCase):
         self.addCleanup(tf.close)
         self.assertEqual([], tf.getnames())
 
-    def _get_example_tar_stream(self, mtime: int, prefix: bytes = b"", format: str = "") -> BytesIO:
+    def _get_example_tar_stream(
+        self, mtime: int, prefix: bytes = b"", format: str = ""
+    ) -> BytesIO:
         store = MemoryObjectStore()
         b1 = Blob.from_string(b"somedata")
         store.add_object(b1)

+ 2 - 2
tests/test_cloud_gcs.py

@@ -39,8 +39,8 @@ class GcsObjectStoreTests(unittest.TestCase):
         self.assertIn("git", repr(self.store))
 
     def test_remove_pack(self):
-        """Test _remove_pack method."""
-        self.store._remove_pack("pack-1234")
+        """Test _remove_pack_by_name method."""
+        self.store._remove_pack_by_name("pack-1234")
         self.mock_bucket.delete_blobs.assert_called_once()
         args = self.mock_bucket.delete_blobs.call_args[0][0]
         self.assertEqual(
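
The rename from _remove_pack to _remove_pack_by_name suggests the object-store base class now distinguishes removal by pack name from removal by pack object; the test only pins down the name-based variant's observable effect, a single delete_blobs call on the bucket. A hedged sketch of the mock wiring (constructor arguments are assumptions for illustration; the real test builds the store in setUp()):

    from unittest.mock import MagicMock

    from dulwich.cloud.gcs import GcsObjectStore

    mock_bucket = MagicMock()
    store = GcsObjectStore(mock_bucket, subpath="git")  # hypothetical arguments
    store._remove_pack_by_name("pack-1234")
    mock_bucket.delete_blobs.assert_called_once()
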

+ 2 - 1
tests/test_commit_graph.py

@@ -736,7 +736,8 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertEqual(parents_provider.get_parents(commit1.id), [])  # type: ignore[no-untyped-call]
         self.assertEqual(parents_provider.get_parents(commit2.id), [commit1.id])  # type: ignore[no-untyped-call]
         self.assertEqual(
-            parents_provider.get_parents(commit5.id), [commit3.id, commit4.id]  # type: ignore[no-untyped-call]
+            parents_provider.get_parents(commit5.id),
+            [commit3.id, commit4.id],  # type: ignore[no-untyped-call]
         )
 
     def test_performance_with_commit_graph(self) -> None:

+ 12 - 3
tests/test_dumb.py

@@ -32,7 +32,12 @@ from dulwich.objects import Blob, Commit, ShaFile, Tag, Tree, sha_to_hex
 
 
 class MockResponse:
-    def __init__(self, status: int = 200, content: bytes = b"", headers: Optional[dict[str, str]] = None) -> None:
+    def __init__(
+        self,
+        status: int = 200,
+        content: bytes = b"",
+        headers: Optional[dict[str, str]] = None,
+    ) -> None:
         self.status = status
         self.content = content
         self.headers = headers or {}
@@ -50,7 +55,9 @@ class DumbHTTPObjectStoreTests(TestCase):
         self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]
@@ -183,7 +190,9 @@ class DumbRemoteHTTPRepoTests(TestCase):
         self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url: str, headers: dict[str, str]) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]

+ 1 - 1
tests/test_protocol.py

@@ -36,10 +36,10 @@ from dulwich.protocol import (
     ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
-    filter_ref_prefix,
     pkt_line,
     pkt_seq,
 )
+from dulwich.refs import filter_ref_prefix
 
 from . import TestCase
 

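filter_ref_prefix now lives in dulwich.refs rather than dulwich.protocol, which is why only the import line changes while the tests themselves stay put. A hedged usage sketch (argument shapes inferred from the function's name and typical dulwich refs mappings; verify against dulwich.refs):

    from dulwich.refs import filter_ref_prefix

    refs = {
        b"refs/heads/main": b"1" * 40,
        b"refs/tags/v1.0": b"2" * 40,
    }
    heads = filter_ref_prefix(refs, [b"refs/heads/"])
    # expected: only the refs/heads/ entries survive
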
Some files were not shown because too many files were changed