Add more typing (#1750)

Jelmer Vernooij 5 months ago
parent commit 7db9043418

+ 6 - 5
dulwich/annotate.py

@@ -27,8 +27,9 @@ Python's difflib.
 """
 
 import difflib
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional
 
+from dulwich.objects import Blob
 from dulwich.walk import (
     ORDER_DATE,
     Walker,
@@ -37,7 +38,7 @@ from dulwich.walk import (
 if TYPE_CHECKING:
     from dulwich.diff_tree import TreeChange, TreeEntry
     from dulwich.object_store import BaseObjectStore
-    from dulwich.objects import Blob, Commit
+    from dulwich.objects import Commit
 
 # Walk over ancestry graph breadth-first
 # When checking each revision, find lines that according to difflib.Differ()
@@ -108,7 +109,7 @@ def annotate_lines(
 
     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
-        lines_annotated = update_lines(
-            lines_annotated, (commit, entry), cast("Blob", store[entry.sha])
-        )
+        blob_obj = store[entry.sha]
+        assert isinstance(blob_obj, Blob)
+        lines_annotated = update_lines(lines_annotated, (commit, entry), blob_obj)
     return lines_annotated
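
A side note on the pattern above: replacing cast("Blob", ...) with an isinstance assertion narrows the type for the checker and also fails fast at runtime if the object store hands back something other than a blob. A minimal, self-contained sketch of that narrowing (hypothetical Blob/Tree/lookup stand-ins, not dulwich's API):

from typing import Union


class Blob:
    def __init__(self, data: bytes) -> None:
        self.data = data


class Tree:
    pass


def lookup(sha: str) -> Union[Blob, Tree]:
    # Stand-in for store[entry.sha]: statically the result is a union of object kinds.
    return Blob(b"first line\n")


obj = lookup("abc123")
assert isinstance(obj, Blob)  # narrows Union[Blob, Tree] to Blob; raises if it is not one
print(obj.data)  # the checker now accepts .data without a cast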

+ 4 - 2
dulwich/bisect.py

@@ -49,7 +49,7 @@ class BisectState:
         self,
         bad: Optional[bytes] = None,
         good: Optional[list[bytes]] = None,
-        paths: Optional[list[str]] = None,
+        paths: Optional[list[bytes]] = None,
         no_checkout: bool = False,
         term_bad: str = "bad",
         term_good: str = "good",
@@ -103,7 +103,9 @@ class BisectState:
         names_file = os.path.join(self.repo.controldir(), "BISECT_NAMES")
         with open(names_file, "w") as f:
             if paths:
-                f.write("\n".join(paths) + "\n")
+                f.write(
+                    "\n".join(path.decode("utf-8", "replace") for path in paths) + "\n"
+                )
             else:
                 f.write("\n")
 
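Since BISECT_NAMES is a text file, the bytes paths above are decoded with errors="replace" so undecodable bytes cannot raise while writing. A tiny standalone illustration of that decoding step (made-up paths, not dulwich code):

paths = [b"src/main.py", b"data/caf\xc3\xa9.txt", b"odd/\xff\xfe"]

# Invalid UTF-8 turns into U+FFFD replacement characters instead of raising.
lines = "\n".join(p.decode("utf-8", "replace") for p in paths) + "\n"
print(lines)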

+ 32 - 7
dulwich/bundle.py

@@ -22,9 +22,30 @@
 """Bundle format support."""
 
 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional
+from typing import (
+    TYPE_CHECKING,
+    BinaryIO,
+    Callable,
+    Optional,
+    Protocol,
+    runtime_checkable,
+)
+
+from .pack import PackData, UnpackedObject, write_pack_data
+
+
+@runtime_checkable
+class PackDataLike(Protocol):
+    """Protocol for objects that behave like PackData."""
+
+    def __len__(self) -> int:
+        """Return the number of objects in the pack."""
+        ...
+
+    def iter_unpacked(self) -> Iterator[UnpackedObject]:
+        """Iterate over unpacked objects in the pack."""
+        ...
 
-from .pack import PackData, write_pack_data
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
@@ -39,7 +60,7 @@ class Bundle:
     capabilities: dict[str, Optional[str]]
     prerequisites: list[tuple[bytes, bytes]]
     references: dict[bytes, bytes]
-    pack_data: PackData
+    pack_data: Optional[PackDataLike]
 
     def __repr__(self) -> str:
         """Return string representation of Bundle."""
@@ -79,10 +100,12 @@ class Bundle:
         """
         from .objects import ShaFile
 
+        if self.pack_data is None:
+            raise ValueError("pack_data is not loaded")
         count = 0
         for unpacked in self.pack_data.iter_unpacked():
             # Convert the unpacked object to a proper git object
-            if unpacked.decomp_chunks:
+            if unpacked.decomp_chunks and unpacked.obj_type_num is not None:
                 git_obj = ShaFile.from_raw_chunks(
                     unpacked.obj_type_num, unpacked.decomp_chunks
                 )
@@ -187,6 +210,8 @@ def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
     for ref, obj_id in bundle.references.items():
         f.write(obj_id + b" " + ref + b"\n")
     f.write(b"\n")
+    if bundle.pack_data is None:
+        raise ValueError("bundle.pack_data is not loaded")
     write_pack_data(
         f.write,
         num_records=len(bundle.pack_data),
@@ -283,14 +308,14 @@ def create_bundle_from_repo(
     # Store the pack objects directly, we'll write them when saving the bundle
     # For now, create a simple wrapper to hold the data
     class _BundlePackData:
-        def __init__(self, count: int, objects: Iterator[Any]) -> None:
+        def __init__(self, count: int, objects: Iterator[UnpackedObject]) -> None:
             self._count = count
             self._objects = list(objects)  # Materialize the iterator
 
         def __len__(self) -> int:
             return self._count
 
-        def iter_unpacked(self) -> Iterator[Any]:
+        def iter_unpacked(self) -> Iterator[UnpackedObject]:
             return iter(self._objects)
 
     pack_data = _BundlePackData(pack_count, pack_objects)
@@ -301,6 +326,6 @@ def create_bundle_from_repo(
     bundle.capabilities = capabilities
     bundle.prerequisites = bundle_prerequisites
     bundle.references = bundle_refs
-    bundle.pack_data = pack_data  # type: ignore[assignment]
+    bundle.pack_data = pack_data
 
     return bundle
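
The new PackDataLike protocol lets Bundle.pack_data hold either a real PackData or the in-memory _BundlePackData wrapper without the earlier type: ignore, and because it is @runtime_checkable an isinstance() check can verify the shape. A rough sketch of the idea with placeholder classes (UnpackedObject replaced by plain int for brevity):

from collections.abc import Iterator
from typing import Protocol, runtime_checkable


@runtime_checkable
class PackDataLike(Protocol):
    def __len__(self) -> int: ...
    def iter_unpacked(self) -> Iterator[int]: ...


class InMemoryPack:
    """Materialized stand-in, analogous to _BundlePackData."""

    def __init__(self, objects: list[int]) -> None:
        self._objects = objects

    def __len__(self) -> int:
        return len(self._objects)

    def iter_unpacked(self) -> Iterator[int]:
        return iter(self._objects)


pack = InMemoryPack([1, 2, 3])
print(isinstance(pack, PackDataLike))  # True: structural check, no inheritance needed
print(len(pack), list(pack.iter_unpacked()))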

+ 118 - 65
dulwich/cli.py

@@ -37,7 +37,7 @@ import subprocess
 import sys
 import tempfile
 from pathlib import Path
-from typing import Callable, ClassVar, Optional, Union
+from typing import BinaryIO, Callable, ClassVar, Optional, Union
 
 from dulwich import porcelain
 
@@ -45,17 +45,31 @@ from .bundle import create_bundle_from_repo, read_bundle, write_bundle
 from .client import GitProtocolError, get_transport_and_path
 from .errors import ApplyDeltaError
 from .index import Index
-from .objects import valid_hexsha
+from .objects import Commit, valid_hexsha
 from .objectspec import parse_commit_range
 from .pack import Pack, sha_to_hex
 from .repo import Repo
 
 
+def to_display_str(value: Union[bytes, str]) -> str:
+    """Convert a bytes or string value to a display string.
+
+    Args:
+        value: The value to convert (bytes or str)
+
+    Returns:
+        A string suitable for display
+    """
+    if isinstance(value, bytes):
+        return value.decode("utf-8", "replace")
+    return value
+
+
 class CommitMessageError(Exception):
     """Raised when there's an issue with the commit message."""
 
 
-def signal_int(signal, frame) -> None:
+def signal_int(signal: int, frame) -> None:
     """Handle interrupt signal by exiting.
 
     Args:
@@ -65,7 +79,7 @@ def signal_int(signal, frame) -> None:
     sys.exit(1)
 
 
-def signal_quit(signal, frame) -> None:
+def signal_quit(signal: int, frame) -> None:
     """Handle quit signal by entering debugger.
 
     Args:
@@ -77,7 +91,7 @@ def signal_quit(signal, frame) -> None:
     pdb.set_trace()
 
 
-def parse_relative_time(time_str):
+def parse_relative_time(time_str: str) -> int:
     """Parse a relative time string like '2 weeks ago' into seconds.
 
     Args:
@@ -126,7 +140,7 @@ def parse_relative_time(time_str):
         raise
 
 
-def format_bytes(bytes):
+def format_bytes(bytes: float) -> str:
     """Format bytes as human-readable string.
 
     Args:
@@ -142,7 +156,7 @@ def format_bytes(bytes):
     return f"{bytes:.1f} TB"
 
 
-def launch_editor(template_content=b""):
+def launch_editor(template_content: bytes = b"") -> bytes:
     """Launch an editor for the user to enter text.
 
     Args:
@@ -176,7 +190,7 @@ def launch_editor(template_content=b""):
 class PagerBuffer:
     """Binary buffer wrapper for Pager to mimic sys.stdout.buffer."""
 
-    def __init__(self, pager):
+    def __init__(self, pager: "Pager") -> None:
         """Initialize PagerBuffer.
 
         Args:
@@ -184,40 +198,40 @@ class PagerBuffer:
         """
         self.pager = pager
 
-    def write(self, data: bytes):
+    def write(self, data: bytes) -> int:
         """Write bytes to pager."""
         if isinstance(data, bytes):
             text = data.decode("utf-8", errors="replace")
             return self.pager.write(text)
         return self.pager.write(data)
 
-    def flush(self):
+    def flush(self) -> None:
         """Flush the pager."""
         return self.pager.flush()
 
-    def writelines(self, lines):
+    def writelines(self, lines) -> None:
         """Write multiple lines to pager."""
         for line in lines:
             self.write(line)
 
-    def readable(self):
+    def readable(self) -> bool:
         """Return whether the buffer is readable (it's not)."""
         return False
 
-    def writable(self):
+    def writable(self) -> bool:
         """Return whether the buffer is writable."""
         return not self.pager._closed
 
-    def seekable(self):
+    def seekable(self) -> bool:
         """Return whether the buffer is seekable (it's not)."""
         return False
 
-    def close(self):
+    def close(self) -> None:
         """Close the pager."""
         return self.pager.close()
 
     @property
-    def closed(self):
+    def closed(self) -> bool:
         """Return whether the buffer is closed."""
         return self.pager.closed
 
@@ -225,13 +239,13 @@ class PagerBuffer:
 class Pager:
     """File-like object that pages output through external pager programs."""
 
-    def __init__(self, pager_cmd="cat"):
+    def __init__(self, pager_cmd: str = "cat") -> None:
         """Initialize Pager.
 
         Args:
             pager_cmd: Command to use for paging (default: "cat")
         """
-        self.pager_process = None
+        self.pager_process: Optional[subprocess.Popen] = None
         self.buffer = PagerBuffer(self)
         self._closed = False
         self.pager_cmd = pager_cmd
@@ -241,7 +255,7 @@ class Pager:
         """Get the pager command to use."""
         return self.pager_cmd
 
-    def _ensure_pager_started(self):
+    def _ensure_pager_started(self) -> None:
         """Start the pager process if not already started."""
         if self.pager_process is None and not self._closed:
             try:
@@ -280,7 +294,7 @@ class Pager:
             # No pager available, write directly to stdout
             return sys.stdout.write(text)
 
-    def flush(self):
+    def flush(self) -> None:
         """Flush the pager."""
         if self._closed or self._pager_died:
             return
@@ -293,7 +307,7 @@ class Pager:
         else:
             sys.stdout.flush()
 
-    def close(self):
+    def close(self) -> None:
         """Close the pager."""
         if self._closed:
             return
@@ -308,16 +322,16 @@ class Pager:
                 pass
             self.pager_process = None
 
-    def __enter__(self):
+    def __enter__(self) -> "Pager":
         """Context manager entry."""
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         """Context manager exit."""
         self.close()
 
     # Additional file-like methods for compatibility
-    def writelines(self, lines):
+    def writelines(self, lines) -> None:
         """Write a list of lines to the pager."""
         if self._pager_died:
             return
@@ -325,19 +339,19 @@ class Pager:
             self.write(line)
 
     @property
-    def closed(self):
+    def closed(self) -> bool:
         """Return whether the pager is closed."""
         return self._closed
 
-    def readable(self):
+    def readable(self) -> bool:
         """Return whether the pager is readable (it's not)."""
         return False
 
-    def writable(self):
+    def writable(self) -> bool:
         """Return whether the pager is writable."""
         return not self._closed
 
-    def seekable(self):
+    def seekable(self) -> bool:
         """Return whether the pager is seekable (it's not)."""
         return False
 
@@ -345,7 +359,7 @@ class Pager:
 class _StreamContextAdapter:
     """Adapter to make streams work with context manager protocol."""
 
-    def __init__(self, stream):
+    def __init__(self, stream) -> None:
         self.stream = stream
         # Expose buffer if it exists
         if hasattr(stream, "buffer"):
@@ -356,15 +370,15 @@ class _StreamContextAdapter:
     def __enter__(self):
         return self.stream
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         # For stdout/stderr, we don't close them
         pass
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str):
         return getattr(self.stream, name)
 
 
-def get_pager(config=None, cmd_name=None):
+def get_pager(config=None, cmd_name: Optional[str] = None):
     """Get a pager instance if paging should be used, otherwise return sys.stdout.
 
     Args:
@@ -447,14 +461,14 @@ def get_pager(config=None, cmd_name=None):
     return Pager(pager_cmd)
 
 
-def disable_pager():
+def disable_pager() -> None:
     """Disable pager for this session."""
-    get_pager._disabled = True
+    get_pager._disabled = True  # type: ignore[attr-defined]
 
 
-def enable_pager():
+def enable_pager() -> None:
     """Enable pager for this session."""
-    get_pager._disabled = False
+    get_pager._disabled = False  # type: ignore[attr-defined]
 
 
 class Command:
@@ -491,10 +505,14 @@ class cmd_archive(Command):
                 write_error=sys.stderr.write,
             )
         else:
-            # Use buffer if available (for binary output), otherwise use stdout
-            outstream = getattr(sys.stdout, "buffer", sys.stdout)
+            # Use binary buffer for archive output
+            outstream: BinaryIO = sys.stdout.buffer
+            errstream: BinaryIO = sys.stderr.buffer
             porcelain.archive(
-                ".", args.committish, outstream=outstream, errstream=sys.stderr
+                ".",
+                args.committish,
+                outstream=outstream,
+                errstream=errstream,
             )
 
 
@@ -642,10 +660,11 @@ class cmd_fetch(Command):
         def progress(msg: bytes) -> None:
             sys.stdout.buffer.write(msg)
 
-        refs = client.fetch(path, r, progress=progress)
+        result = client.fetch(path, r, progress=progress)
         print("Remote refs:")
-        for item in refs.items():
-            print("{} -> {}".format(*item))
+        for ref, sha in result.refs.items():
+            if sha is not None:
+                print(f"{ref.decode()} -> {sha.decode()}")
 
 
 class cmd_for_each_ref(Command):
@@ -676,7 +695,7 @@ class cmd_fsck(Command):
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
         for obj, msg in porcelain.fsck("."):
-            print(f"{obj}: {msg}")
+            print(f"{obj.decode() if isinstance(obj, bytes) else obj}: {msg}")
 
 
 class cmd_log(Command):
@@ -838,7 +857,7 @@ class cmd_dump_pack(Command):
 
         basename, _ = os.path.splitext(args.filename)
         x = Pack(basename)
-        print(f"Object names checksum: {x.name()}")
+        print(f"Object names checksum: {x.name().decode('ascii', 'replace')}")
         print(f"Checksum: {sha_to_hex(x.get_stored_checksum())!r}")
         x.check()
         print(f"Length: {len(x)}")
@@ -846,9 +865,13 @@ class cmd_dump_pack(Command):
             try:
                 print(f"\t{x[name]}")
             except KeyError as k:
-                print(f"\t{name}: Unable to resolve base {k}")
+                print(
+                    f"\t{name.decode('ascii', 'replace')}: Unable to resolve base {k!r}"
+                )
             except ApplyDeltaError as e:
-                print(f"\t{name}: Unable to apply delta: {e!r}")
+                print(
+                    f"\t{name.decode('ascii', 'replace')}: Unable to apply delta: {e!r}"
+                )
 
 
 class cmd_dump_index(Command):
@@ -1302,9 +1325,17 @@ class cmd_reflog(Command):
 
                     for i, entry in enumerate(porcelain.reflog(repo, ref)):
                         # Format similar to git reflog
+                        from dulwich.reflog import Entry
+
+                        assert isinstance(entry, Entry)
                         short_new = entry.new_sha[:8].decode("ascii")
+                        message = (
+                            entry.message.decode("utf-8", "replace")
+                            if entry.message
+                            else ""
+                        )
                         outstream.write(
-                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {entry.message.decode('utf-8', 'replace')}\n"
+                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {message}\n"
                         )
 
 
@@ -1538,11 +1569,14 @@ class cmd_ls_remote(Command):
         if args.symref:
             # Show symrefs first, like git does
             for ref, target in sorted(result.symrefs.items()):
-                sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
+                if target:
+                    sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
 
         # Show regular refs
         for ref in sorted(result.refs):
-            sys.stdout.write(f"{result.refs[ref].decode()}\t{ref.decode()}\n")
+            sha = result.refs[ref]
+            if sha is not None:
+                sys.stdout.write(f"{sha.decode()}\t{ref.decode()}\n")
 
 
 class cmd_ls_tree(Command):
@@ -1601,12 +1635,13 @@ class cmd_pack_objects(Command):
         if not args.stdout and not args.basename:
             parser.error("basename required when not using --stdout")
 
-        object_ids = [line.strip() for line in sys.stdin.readlines()]
+        object_ids = [line.strip().encode() for line in sys.stdin.readlines()]
         deltify = args.deltify
         reuse_deltas = not args.no_reuse_deltas
 
         if args.stdout:
             packf = getattr(sys.stdout, "buffer", sys.stdout)
+            assert isinstance(packf, BinaryIO)
             idxf = None
             close = []
         else:
@@ -2022,8 +2057,17 @@ class cmd_stash_list(Command):
         """
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
-        for i, entry in porcelain.stash_list("."):
-            print("stash@{{{}}}: {}".format(i, entry.message.rstrip("\n")))
+        from .repo import Repo
+        from .stash import Stash
+
+        with Repo(".") as r:
+            stash = Stash.from_repo(r)
+            for i, entry in enumerate(stash.stashes()):
+                print(
+                    "stash@{{{}}}: {}".format(
+                        i, entry.message.decode("utf-8", "replace").rstrip("\n")
+                    )
+                )
 
 
 class cmd_stash_push(Command):
@@ -2145,6 +2189,7 @@ class cmd_bisect(SuperCommand):
                         with open(bad_ref, "rb") as f:
                             bad_sha = f.read().strip()
                         commit = r.object_store[bad_sha]
+                        assert isinstance(commit, Commit)
                         message = commit.message.decode(
                             "utf-8", errors="replace"
                         ).split("\n")[0]
@@ -2173,7 +2218,7 @@ class cmd_bisect(SuperCommand):
                 print(log, end="")
 
             elif parsed_args.subcommand == "replay":
-                porcelain.bisect_replay(log_file=parsed_args.logfile)
+                porcelain.bisect_replay(".", log_file=parsed_args.logfile)
                 print(f"Replayed bisect log from {parsed_args.logfile}")
 
             elif parsed_args.subcommand == "help":
@@ -2270,6 +2315,7 @@ class cmd_merge(Command):
             elif args.no_commit:
                 print("Automatic merge successful; not committing as requested.")
             else:
+                assert merge_commit_id is not None
                 print(
                     f"Merge successful. Created merge commit {merge_commit_id.decode()}"
                 )
@@ -3129,12 +3175,14 @@ class cmd_lfs(Command):
             tracked = porcelain.lfs_untrack(patterns=args.patterns)
             print("Remaining tracked patterns:")
             for pattern in tracked:
-                print(f"  {pattern}")
+                print(f"  {to_display_str(pattern)}")
 
         elif args.subcommand == "ls-files":
             files = porcelain.lfs_ls_files(ref=args.ref)
             for path, oid, size in files:
-                print(f"{oid[:12]} * {path} ({format_bytes(size)})")
+                print(
+                    f"{to_display_str(oid[:12])} * {to_display_str(path)} ({format_bytes(size)})"
+                )
 
         elif args.subcommand == "migrate":
             count = porcelain.lfs_migrate(
@@ -3145,13 +3193,13 @@ class cmd_lfs(Command):
         elif args.subcommand == "pointer":
             if args.paths is not None:
                 results = porcelain.lfs_pointer_check(paths=args.paths or None)
-                for path, pointer in results.items():
+                for file_path, pointer in results.items():
                     if pointer:
                         print(
-                            f"{path}: LFS pointer (oid: {pointer.oid[:12]}, size: {format_bytes(pointer.size)})"
+                            f"{to_display_str(file_path)}: LFS pointer (oid: {to_display_str(pointer.oid[:12])}, size: {format_bytes(pointer.size)})"
                         )
                     else:
-                        print(f"{path}: Not an LFS pointer")
+                        print(f"{to_display_str(file_path)}: Not an LFS pointer")
 
         elif args.subcommand == "clean":
             pointer = porcelain.lfs_clean(path=args.path)
@@ -3188,13 +3236,13 @@ class cmd_lfs(Command):
 
             if status["missing"]:
                 print("\nMissing LFS objects:")
-                for path in status["missing"]:
-                    print(f"  {path}")
+                for file_path in status["missing"]:
+                    print(f"  {to_display_str(file_path)}")
 
             if status["not_staged"]:
                 print("\nModified LFS files not staged:")
-                for path in status["not_staged"]:
-                    print(f"  {path}")
+                for file_path in status["not_staged"]:
+                    print(f"  {to_display_str(file_path)}")
 
             if not any(status.values()):
                 print("No LFS files found.")
@@ -3273,14 +3321,19 @@ class cmd_format_patch(Command):
         args = parser.parse_args(args)
 
         # Parse committish using the new function
-        committish = None
+        committish: Optional[Union[bytes, tuple[bytes, bytes]]] = None
         if args.committish:
             with Repo(".") as r:
                 range_result = parse_commit_range(r, args.committish)
                 if range_result:
-                    committish = range_result
+                    # Convert Commit objects to their SHAs
+                    committish = (range_result[0].id, range_result[1].id)
                 else:
-                    committish = args.committish
+                    committish = (
+                        args.committish.encode()
+                        if isinstance(args.committish, str)
+                        else args.committish
+                    )
 
         filenames = porcelain.format_patch(
             ".",

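Several of the cli.py call sites above now route values through to_display_str() before printing, because porcelain can hand back either bytes or str depending on the code path. The same helper from this diff, reduced to a standalone snippet with made-up values:

from typing import Union


def to_display_str(value: Union[bytes, str]) -> str:
    # Bytes are decoded permissively for display; str passes through unchanged.
    if isinstance(value, bytes):
        return value.decode("utf-8", "replace")
    return value


for pattern in [b"*.bin", "*.psd"]:
    print(f"  {to_display_str(pattern)}")
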
+ 144 - 123
dulwich/client.py

@@ -69,11 +69,11 @@ import dulwich
 
 from .config import Config, apply_instead_of, get_xdg_config_home_path
 from .errors import GitProtocolError, NotGitRepository, SendPackError
+from .object_store import GraphWalker
 from .pack import (
     PACK_SPOOL_FILE_MAX_SIZE,
     PackChunkGenerator,
     PackData,
-    UnpackedObject,
     write_pack_from_container,
 )
 from .protocol import (
@@ -116,7 +116,6 @@ from .protocol import (
     capability_agent,
     extract_capabilities,
     extract_capability_names,
-    filter_ref_prefix,
     parse_capability,
     pkt_line,
     pkt_seq,
@@ -129,6 +128,7 @@ from .refs import (
     _set_default_branch,
     _set_head,
     _set_origin_head,
+    filter_ref_prefix,
     read_info_refs,
     split_peeled_refs,
 )
@@ -149,7 +149,7 @@ logger = logging.getLogger(__name__)
 class InvalidWants(Exception):
     """Invalid wants."""
 
-    def __init__(self, wants) -> None:
+    def __init__(self, wants: set[bytes]) -> None:
         """Initialize InvalidWants exception.
 
         Args:
@@ -163,7 +163,7 @@ class InvalidWants(Exception):
 class HTTPUnauthorized(Exception):
     """Raised when authentication fails."""
 
-    def __init__(self, www_authenticate, url) -> None:
+    def __init__(self, www_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPUnauthorized exception.
 
         Args:
@@ -178,7 +178,7 @@ class HTTPUnauthorized(Exception):
 class HTTPProxyUnauthorized(Exception):
     """Raised when proxy authentication fails."""
 
-    def __init__(self, proxy_authenticate, url) -> None:
+    def __init__(self, proxy_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPProxyUnauthorized exception.
 
         Args:
@@ -190,22 +190,28 @@ class HTTPProxyUnauthorized(Exception):
         self.url = url
 
 
-def _fileno_can_read(fileno):
+def _fileno_can_read(fileno: int) -> bool:
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
 
-def _win32_peek_avail(handle):
+def _win32_peek_avail(handle: int) -> int:
     """Wrapper around PeekNamedPipe to check how many bytes are available."""
-    from ctypes import byref, windll, wintypes
+    from ctypes import (  # type: ignore[attr-defined]
+        byref,
+        windll,  # type: ignore[attr-defined]
+        wintypes,
+    )
 
     c_avail = wintypes.DWORD()
     c_message = wintypes.DWORD()
-    success = windll.kernel32.PeekNamedPipe(
+    success = windll.kernel32.PeekNamedPipe(  # type: ignore[attr-defined]
         handle, None, 0, None, byref(c_avail), byref(c_message)
     )
     if not success:
-        raise OSError(wintypes.GetLastError())
+        from ctypes import GetLastError  # type: ignore[attr-defined]
+
+        raise OSError(GetLastError())
     return c_avail.value
 
 
@@ -230,10 +236,10 @@ class ReportStatusParser:
     def __init__(self) -> None:
         """Initialize ReportStatusParser."""
         self._done = False
-        self._pack_status = None
+        self._pack_status: Optional[bytes] = None
         self._ref_statuses: list[bytes] = []
 
-    def check(self):
+    def check(self) -> Iterator[tuple[bytes, Optional[str]]]:
         """Check if there were any errors and, if so, raise exceptions.
 
         Raises:
@@ -257,7 +263,7 @@ class ReportStatusParser:
             else:
                 raise GitProtocolError(f"invalid ref status {status!r}")
 
-    def handle_packet(self, pkt) -> None:
+    def handle_packet(self, pkt: Optional[bytes]) -> None:
         """Handle a packet.
 
         Raises:
@@ -276,13 +282,8 @@ class ReportStatusParser:
             self._ref_statuses.append(ref_status)
 
 
-def negotiate_protocol_version(proto) -> int:
-    """Negotiate the protocol version to use.
-
-    Args:
-      proto: Protocol instance to negotiate with
-    Returns: Protocol version (0, 1, or 2)
-    """
+def negotiate_protocol_version(proto: Protocol) -> int:
+    """Negotiate protocol version with the server."""
     pkt = proto.read_pkt_line()
     if pkt is not None and pkt.strip() == b"version 2":
         return 2
@@ -290,13 +291,8 @@ def negotiate_protocol_version(proto) -> int:
     return 0
 
 
-def read_server_capabilities(pkt_seq):
-    """Read server capabilities from a packet sequence.
-
-    Args:
-      pkt_seq: Sequence of packets from server
-    Returns: Set of server capabilities
-    """
+def read_server_capabilities(pkt_seq: Iterable[bytes]) -> set[bytes]:
+    """Read server capabilities from packet sequence."""
     server_capabilities = []
     for pkt in pkt_seq:
         server_capabilities.append(pkt)
@@ -304,21 +300,16 @@ def read_server_capabilities(pkt_seq):
 
 
 def read_pkt_refs_v2(
-    pkt_seq,
-) -> tuple[dict[bytes, bytes], dict[bytes, bytes], dict[bytes, bytes]]:
-    """Read packet references in protocol v2 format.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (refs dict, symrefs dict, peeled dict)
-    """
-    refs = {}
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], dict[bytes, bytes], dict[bytes, bytes]]:
+    """Read references using protocol version 2."""
+    refs: dict[bytes, Optional[bytes]] = {}
     symrefs = {}
     peeled = {}
     # Receive refs from server
     for pkt in pkt_seq:
         parts = pkt.rstrip(b"\n").split(b" ")
-        sha = parts[0]
+        sha: Optional[bytes] = parts[0]
         if sha == b"unborn":
             sha = None
         ref = parts[1]
@@ -334,15 +325,12 @@ def read_pkt_refs_v2(
     return refs, symrefs, peeled
 
 
-def read_pkt_refs_v1(pkt_seq) -> tuple[dict[bytes, bytes], set[bytes]]:
-    """Read packet references in protocol v1 format.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (refs dict, server capabilities set)
-    """
+def read_pkt_refs_v1(
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], set[bytes]]:
+    """Read references using protocol version 1."""
     server_capabilities = None
-    refs = {}
+    refs: dict[bytes, Optional[bytes]] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         (sha, ref) = pkt.rstrip(b"\n").split(None, 1)
@@ -363,6 +351,8 @@ def read_pkt_refs_v1(pkt_seq) -> tuple[dict[bytes, bytes], set[bytes]]:
 class _DeprecatedDictProxy:
     """Base class for result objects that provide deprecated dict-like interface."""
 
+    refs: dict[bytes, Optional[bytes]]  # To be overridden by subclasses
+
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
         "clear",
         "copy",
@@ -389,11 +379,11 @@ class _DeprecatedDictProxy:
             stacklevel=3,
         )
 
-    def __contains__(self, name) -> bool:
+    def __contains__(self, name: bytes) -> bool:
         self._warn_deprecated()
         return name in self.refs
 
-    def __getitem__(self, name):
+    def __getitem__(self, name: bytes) -> Optional[bytes]:
         self._warn_deprecated()
         return self.refs[name]
 
@@ -401,11 +391,11 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return len(self.refs)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
         self._warn_deprecated()
         return iter(self.refs)
 
-    def __getattribute__(self, name):
+    def __getattribute__(self, name: str) -> object:
         # Avoid infinite recursion by checking against class variable directly
         if name != "_FORWARDED_ATTRS" and name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
@@ -424,8 +414,16 @@ class FetchPackResult(_DeprecatedDictProxy):
       agent: User agent string
     """
 
+    symrefs: dict[bytes, bytes]
+    agent: Optional[bytes]
+
     def __init__(
-        self, refs, symrefs, agent, new_shallow=None, new_unshallow=None
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        symrefs: dict[bytes, bytes],
+        agent: Optional[bytes],
+        new_shallow: Optional[set[bytes]] = None,
+        new_unshallow: Optional[set[bytes]] = None,
     ) -> None:
         """Initialize FetchPackResult.
 
@@ -442,11 +440,13 @@ class FetchPackResult(_DeprecatedDictProxy):
         self.new_shallow = new_shallow
         self.new_unshallow = new_unshallow
 
-    def __eq__(self, other):
-        """Check equality with another FetchPackResult."""
+    def __eq__(self, other: object) -> bool:
+        """Check equality with another object."""
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, FetchPackResult):
+            return False
         return (
             self.refs == other.refs
             and self.symrefs == other.symrefs
@@ -466,7 +466,11 @@ class LsRemoteResult(_DeprecatedDictProxy):
       symrefs: Dictionary with remote symrefs
     """
 
-    def __init__(self, refs, symrefs) -> None:
+    symrefs: dict[bytes, bytes]
+
+    def __init__(
+        self, refs: dict[bytes, Optional[bytes]], symrefs: dict[bytes, bytes]
+    ) -> None:
         """Initialize LsRemoteResult.
 
         Args:
@@ -486,11 +490,13 @@ class LsRemoteResult(_DeprecatedDictProxy):
             stacklevel=3,
         )
 
-    def __eq__(self, other):
-        """Check equality with another LsRemoteResult."""
+    def __eq__(self, other: object) -> bool:
+        """Check equality with another object."""
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, LsRemoteResult):
+            return False
         return self.refs == other.refs and self.symrefs == other.symrefs
 
     def __repr__(self) -> str:
@@ -508,7 +514,12 @@ class SendPackResult(_DeprecatedDictProxy):
         failed to update), or None if it was updated successfully
     """
 
-    def __init__(self, refs, agent=None, ref_status=None) -> None:
+    def __init__(
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        agent: Optional[bytes] = None,
+        ref_status: Optional[dict[bytes, Optional[str]]] = None,
+    ) -> None:
         """Initialize SendPackResult.
 
         Args:
@@ -520,11 +531,13 @@ class SendPackResult(_DeprecatedDictProxy):
         self.agent = agent
         self.ref_status = ref_status
 
-    def __eq__(self, other):
-        """Check equality with another SendPackResult."""
+    def __eq__(self, other: object) -> bool:
+        """Check equality with another object."""
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, SendPackResult):
+            return False
         return self.refs == other.refs and self.agent == other.agent
 
     def __repr__(self) -> str:
@@ -532,13 +545,7 @@ class SendPackResult(_DeprecatedDictProxy):
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"
 
 
-def _read_shallow_updates(pkt_seq):
-    """Read shallow/unshallow updates from a packet sequence.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (new_shallow set, new_unshallow set)
-    """
+def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[bytes]]:
     new_shallow = set()
     new_unshallow = set()
     for pkt in pkt_seq:
@@ -547,30 +554,29 @@ def _read_shallow_updates(pkt_seq):
         try:
             cmd, sha = pkt.split(b" ", 1)
         except ValueError:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
         if cmd == COMMAND_SHALLOW:
             new_shallow.add(sha.strip())
         elif cmd == COMMAND_UNSHALLOW:
             new_unshallow.add(sha.strip())
         else:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
     return (new_shallow, new_unshallow)
 
 
 class _v1ReceivePackHeader:
-    """Handler for v1 receive-pack header."""
-
-    def __init__(self, capabilities, old_refs, new_refs) -> None:
-        self.want: list[bytes] = []
-        self.have: list[bytes] = []
+    def __init__(self, capabilities: list, old_refs: dict, new_refs: dict) -> None:
+        self.want: set[bytes] = set()
+        self.have: set[bytes] = set()
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False
 
-    def __iter__(self):
-        """Iterate over the receive-pack header lines."""
+    def __iter__(self) -> Iterator[Optional[bytes]]:
         return self._it
 
-    def _handle_receive_pack_head(self, capabilities, old_refs, new_refs):
+    def _handle_receive_pack_head(
+        self, capabilities: list, old_refs: dict, new_refs: dict
+    ) -> Iterator[Optional[bytes]]:
         """Handle the head of a 'git-receive-pack' request.
 
         Args:
@@ -581,7 +587,7 @@ class _v1ReceivePackHeader:
         Returns:
           (have, want) tuple
         """
-        self.have = [x for x in old_refs.values() if not x == ZERO_SHA]
+        self.have = {x for x in old_refs.values() if not x == ZERO_SHA}
 
         for refname in new_refs:
             if not isinstance(refname, bytes):
@@ -615,7 +621,7 @@ class _v1ReceivePackHeader:
                     )
                     self.sent_capabilities = True
             if new_sha1 not in self.have and new_sha1 != ZERO_SHA:
-                self.want.append(new_sha1)
+                self.want.add(new_sha1)
         yield None
 
 
@@ -632,33 +638,29 @@ def _read_side_band64k_data(pkt_seq: Iterable[bytes]) -> Iterator[tuple[int, byt
         yield channel, pkt[1:]
 
 
-def find_capability(capabilities, key, value):
-    """Find a capability in the list of capabilities.
-
-    Args:
-      capabilities: List of capabilities
-      key: Capability key to search for
-      value: Optional specific value to match
-    Returns: The matching capability or None
-    """
+def find_capability(
+    capabilities: list, key: bytes, value: Optional[bytes]
+) -> Optional[bytes]:
+    """Find a capability with a specific key and value."""
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k != key:
             continue
-        if value and value not in v.split(b" "):
+        if value and v and value not in v.split(b" "):
             continue
         return capability
+    return None
 
 
 def _handle_upload_pack_head(
-    proto,
-    capabilities,
-    graph_walker,
-    wants,
-    can_read,
+    proto: Protocol,
+    capabilities: list,
+    graph_walker: GraphWalker,
+    wants: list,
+    can_read: Optional[Callable],
     depth: Optional[int],
-    protocol_version,
-):
+    protocol_version: Optional[int],
+) -> tuple[Optional[set[bytes]], Optional[set[bytes]]]:
     """Handle the head of a 'git-upload-pack' request.
 
     Args:
@@ -671,6 +673,8 @@ def _handle_upload_pack_head(
       depth: Depth for request
       protocol_version: Negotiated Git protocol version.
     """
+    new_shallow: Optional[set[bytes]]
+    new_unshallow: Optional[set[bytes]]
     assert isinstance(wants, list) and isinstance(wants[0], bytes)
     wantcmd = COMMAND_WANT + b" " + wants[0]
     if protocol_version is None:
@@ -681,7 +685,9 @@ def _handle_upload_pack_head(
     proto.write_pkt_line(wantcmd)
     for want in wants[1:]:
         proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n")
-    if depth not in (0, None) or graph_walker.shallow:
+    if depth not in (0, None) or (
+        hasattr(graph_walker, "shallow") and graph_walker.shallow
+    ):
         if protocol_version == 2:
             if not find_capability(capabilities, CAPABILITY_FETCH, CAPABILITY_SHALLOW):
                 raise GitProtocolError(
@@ -691,8 +697,9 @@ def _handle_upload_pack_head(
             raise GitProtocolError(
                 "server does not support shallow capability required for depth"
             )
-        for sha in graph_walker.shallow:
-            proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
+        if hasattr(graph_walker, "shallow"):
+            for sha in graph_walker.shallow:
+                proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
         if depth is not None:
             proto.write_pkt_line(
                 COMMAND_DEEPEN + b" " + str(depth).encode("ascii") + b"\n"
@@ -705,6 +712,7 @@ def _handle_upload_pack_head(
         proto.write_pkt_line(COMMAND_HAVE + b" " + have + b"\n")
         if can_read is not None and can_read():
             pkt = proto.read_pkt_line()
+            assert pkt is not None
             parts = pkt.rstrip(b"\n").split(b" ")
             if parts[0] == b"ACK":
                 graph_walker.ack(parts[1])
@@ -714,7 +722,7 @@ def _handle_upload_pack_head(
                     break
                 else:
                     raise AssertionError(
-                        f"{parts[2]} not in ('continue', 'ready', 'common)"
+                        f"{parts[2]!r} not in ('continue', 'ready', 'common)"
                     )
         have = next(graph_walker)
     proto.write_pkt_line(COMMAND_DONE + b"\n")
@@ -725,7 +733,8 @@ def _handle_upload_pack_head(
         if can_read is not None:
             (new_shallow, new_unshallow) = _read_shallow_updates(proto.read_pkt_seq())
         else:
-            new_shallow = new_unshallow = None
+            new_shallow = None
+            new_unshallow = None
     else:
         new_shallow = new_unshallow = set()
 
@@ -773,7 +782,7 @@ def _handle_upload_pack_tail(
         if progress is None:
             # Just ignore progress data
 
-            def progress(x) -> None:
+            def progress(x: bytes) -> None:
                 pass
 
         for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
@@ -804,6 +813,7 @@ def _extract_symrefs_and_agent(capabilities):
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k == CAPABILITY_SYMREF:
+            assert v is not None
             (src, dst) = v.split(b":", 1)
             symrefs[src] = dst
         if k == CAPABILITY_AGENT:
@@ -879,9 +889,7 @@ class GitClient:
         self,
         path: str,
         update_refs,
-        generate_pack_data: Callable[
-            [set[bytes], set[bytes], bool], tuple[int, Iterator[UnpackedObject]]
-        ],
+        generate_pack_data,
         progress=None,
     ) -> SendPackResult:
         """Upload a pack to a remote repository.
@@ -972,8 +980,11 @@ class GitClient:
             origin_sha = result.refs.get(b"HEAD")
             if origin is None or (origin_sha and not origin_head):
                 # set detached HEAD
-                target.refs[b"HEAD"] = origin_sha
-                head = origin_sha
+                if origin_sha is not None:
+                    target.refs[b"HEAD"] = origin_sha
+                    head = origin_sha
+                else:
+                    head = None
             else:
                 _set_origin_head(target.refs, origin.encode("utf-8"), origin_head)
                 head_ref = _set_default_branch(
@@ -1203,10 +1214,11 @@ class GitClient:
             if self.protocol_version == 2 and k == CAPABILITY_FETCH:
                 fetch_capa = CAPABILITY_FETCH
                 fetch_features = []
-                v = v.strip().split(b" ")
-                if b"shallow" in v:
+                assert v is not None
+                v_list = v.strip().split(b" ")
+                if b"shallow" in v_list:
                     fetch_features.append(CAPABILITY_SHALLOW)
-                if b"filter" in v:
+                if b"filter" in v_list:
                     fetch_features.append(CAPABILITY_FILTER)
                 for i in range(len(fetch_features)):
                     if i == 0:
@@ -1357,10 +1369,10 @@ class TraditionalGitClient(GitClient):
                 for ref, sha in orig_new_refs.items():
                     if sha == ZERO_SHA:
                         if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
+                            assert report_status_parser is not None
                             report_status_parser._ref_statuses.append(
                                 b"ng " + ref + b" remote does not support deleting refs"
                             )
-                            report_status_parser._ref_status_ok = False
                         del new_refs[ref]
 
             if new_refs is None:
@@ -1767,7 +1779,7 @@ class TCPGitClient(TraditionalGitClient):
         proto.send_cmd(
             b"git-" + cmd, path, b"host=" + self._host.encode("ascii") + version_str
         )
-        return proto, lambda: _fileno_can_read(s), None
+        return proto, lambda: _fileno_can_read(s.fileno()), None
 
 
 class SubprocessWrapper:
@@ -1997,7 +2009,7 @@ class LocalGitClient(GitClient):
                 *generate_pack_data(have, want, ofs_delta=True)
             )
 
-            ref_status = {}
+            ref_status: dict[bytes, Optional[str]] = {}
 
             for refname, new_sha1 in new_refs.items():
                 old_sha1 = old_refs.get(refname, ZERO_SHA)
@@ -2236,7 +2248,7 @@ class BundleClient(GitClient):
 
             while line.startswith(b"-"):
                 (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-                prerequisites.append((obj_id, comment.decode("utf-8")))
+                prerequisites.append((obj_id, comment))
                 line = f.readline()
 
             while line != b"\n":
@@ -2977,7 +2989,11 @@ class AbstractHttpGitClient(GitClient):
         protocol_version: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
     ) -> tuple[
-        dict[Ref, ObjectID], set[bytes], str, dict[Ref, Ref], dict[Ref, ObjectID]
+        dict[Ref, Optional[ObjectID]],
+        set[bytes],
+        str,
+        dict[Ref, Ref],
+        dict[Ref, ObjectID],
     ]:
         if (
             protocol_version is not None
@@ -3040,10 +3056,10 @@ class AbstractHttpGitClient(GitClient):
                     resp, read = self._smart_request(
                         service.decode("ascii"), base_url, body
                     )
-                    proto = Protocol(read, None)
+                    proto = Protocol(read, lambda data: None)
                     return server_capabilities, resp, read, proto
 
-                proto = Protocol(read, None)  # type: ignore
+                proto = Protocol(read, lambda data: None)
                 server_protocol_version = negotiate_protocol_version(proto)
                 if server_protocol_version not in GIT_PROTOCOL_VERSIONS:
                     raise ValueError(
@@ -3108,7 +3124,12 @@ class AbstractHttpGitClient(GitClient):
                     if not chunk:
                         break
                     data += chunk
-                (refs, peeled) = split_peeled_refs(read_info_refs(BytesIO(data)))
+                from typing import Optional, cast
+
+                info_refs = read_info_refs(BytesIO(data))
+                (refs, peeled) = split_peeled_refs(
+                    cast(dict[bytes, Optional[bytes]], info_refs)
+                )
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
                 return refs, set(), base_url, {}, peeled
@@ -3196,7 +3217,7 @@ class AbstractHttpGitClient(GitClient):
 
         resp, read = self._smart_request("git-receive-pack", url, data=body_generator())
         try:
-            resp_proto = Protocol(read, None)
+            resp_proto = Protocol(read, lambda data: None)
             ref_status = self._handle_receive_pack_tail(
                 resp_proto, negotiated_capabilities, progress
             )
@@ -3441,7 +3462,7 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
     def _http_request(self, url, headers=None, data=None, raise_for_status=True):
         import urllib3.exceptions
 
-        req_headers = self.pool_manager.headers.copy()
+        req_headers = dict(self.pool_manager.headers)
         if headers is not None:
             req_headers.update(headers)
         req_headers["Pragma"] = "no-cache"
@@ -3455,10 +3476,10 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
                 request_kwargs["timeout"] = self._timeout
 
             if data is None:
-                resp = self.pool_manager.request("GET", url, **request_kwargs)
+                resp = self.pool_manager.request("GET", url, **request_kwargs)  # type: ignore[arg-type]
             else:
                 request_kwargs["body"] = data
-                resp = self.pool_manager.request("POST", url, **request_kwargs)
+                resp = self.pool_manager.request("POST", url, **request_kwargs)  # type: ignore[arg-type]
         except urllib3.exceptions.HTTPError as e:
             raise GitProtocolError(str(e)) from e
 
@@ -3472,15 +3493,15 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
             if resp.status != 200:
                 raise GitProtocolError(f"unexpected http resp {resp.status} for {url}")
 
-        resp.content_type = resp.headers.get("Content-Type")
+        resp.content_type = resp.headers.get("Content-Type")  # type: ignore[attr-defined]
         # Check if geturl() is available (urllib3 version >= 1.23)
         try:
             resp_url = resp.geturl()
         except AttributeError:
             # get_redirect_location() is available for urllib3 >= 1.1
-            resp.redirect_location = resp.get_redirect_location()
+            resp.redirect_location = resp.get_redirect_location()  # type: ignore[attr-defined]
         else:
-            resp.redirect_location = resp_url if resp_url != url else ""
+            resp.redirect_location = resp_url if resp_url != url else ""  # type: ignore[attr-defined]
         return resp, _wrap_urllib3_exceptions(resp.read)
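
The __eq__ methods in client.py now take object and explicitly return False for unrelated types, which keeps the signature compatible with object.__eq__ as mypy expects while preserving the deprecated dict comparison. A toy version of the same pattern (simplified Result class, not the dulwich one):

class Result:
    def __init__(self, refs: dict[bytes, bytes]) -> None:
        self.refs = refs

    def __eq__(self, other: object) -> bool:
        if isinstance(other, dict):
            # Backwards-compatible comparison against a plain refs dict.
            return self.refs == other
        if not isinstance(other, Result):
            return False  # explicit False keeps the annotated bool return type
        return self.refs == other.refs


a = Result({b"refs/heads/master": b"a" * 40})
print(a == {b"refs/heads/master": b"a" * 40})  # True
print(a == 42)  # False rather than a type error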
 
 

+ 32 - 15
dulwich/cloud/gcs.py

@@ -24,22 +24,33 @@
 
 import posixpath
 import tempfile
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, BinaryIO
 
 from ..object_store import BucketBasedObjectStore
-from ..pack import PACK_SPOOL_FILE_MAX_SIZE, Pack, PackData, load_pack_index_file
+from ..pack import (
+    PACK_SPOOL_FILE_MAX_SIZE,
+    Pack,
+    PackData,
+    PackIndex,
+    load_pack_index_file,
+)
+
+if TYPE_CHECKING:
+    from google.cloud.storage import Bucket
 
 # TODO(jelmer): For performance, read ranges?
 
 
 class GcsObjectStore(BucketBasedObjectStore):
-    """Object store implementation for Google Cloud Storage."""
+    """Object store implementation using Google Cloud Storage."""
 
-    def __init__(self, bucket, subpath="") -> None:
-        """Initialize GcsObjectStore.
+    def __init__(self, bucket: "Bucket", subpath: str = "") -> None:
+        """Initialize GCS object store.
 
         Args:
-          bucket: GCS bucket instance
-          subpath: Subpath within the bucket
+            bucket: GCS bucket instance
+            subpath: Optional subpath within the bucket
         """
         super().__init__()
         self.bucket = bucket
@@ -49,13 +60,13 @@ class GcsObjectStore(BucketBasedObjectStore):
         """Return string representation of GcsObjectStore."""
         return f"{type(self).__name__}({self.bucket!r}, subpath={self.subpath!r})"
 
-    def _remove_pack(self, name) -> None:
+    def _remove_pack_by_name(self, name: str) -> None:
         self.bucket.delete_blobs(
             [posixpath.join(self.subpath, name) + "." + ext for ext in ["pack", "idx"]]
         )
 
-    def _iter_pack_names(self):
-        packs = {}
+    def _iter_pack_names(self) -> Iterator[str]:
+        packs: dict[str, set[str]] = {}
         for blob in self.bucket.list_blobs(prefix=self.subpath):
             name, ext = posixpath.splitext(posixpath.basename(blob.name))
             packs.setdefault(name, set()).add(ext)
@@ -63,26 +74,32 @@ class GcsObjectStore(BucketBasedObjectStore):
             if exts == {".pack", ".idx"}:
                 yield name
 
-    def _load_pack_data(self, name):
+    def _load_pack_data(self, name: str) -> PackData:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".pack"))
+        from typing import cast
+
+        from ..file import _GitFile
+
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         b.download_to_file(f)
         f.seek(0)
-        return PackData(name + ".pack", f)
+        return PackData(name + ".pack", cast(_GitFile, f))
 
-    def _load_pack_index(self, name):
+    def _load_pack_index(self, name: str) -> PackIndex:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".idx"))
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         b.download_to_file(f)
         f.seek(0)
         return load_pack_index_file(name + ".idx", f)
 
-    def _get_pack(self, name):
-        return Pack.from_lazy_objects(
+    def _get_pack(self, name: str) -> Pack:
+        return Pack.from_lazy_objects(  # type: ignore[no-untyped-call]
             lambda: self._load_pack_data(name), lambda: self._load_pack_index(name)
         )
 
-    def _upload_pack(self, basename, pack_file, index_file) -> None:
+    def _upload_pack(
+        self, basename: str, pack_file: BinaryIO, index_file: BinaryIO
+    ) -> None:
         idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + ".idx"))
         datablob = self.bucket.blob(posixpath.join(self.subpath, basename + ".pack"))
         idxblob.upload_from_file(index_file)
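
The Bucket import in gcs.py is guarded by TYPE_CHECKING so the annotation does not turn google-cloud-storage into a hard runtime dependency. A generic sketch of that pattern (heavy_sdk is a hypothetical module name):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only imported while type checking; nothing is imported at runtime.
    from heavy_sdk import Client  # hypothetical third-party SDK


class Store:
    def __init__(self, client: "Client") -> None:  # string forward reference
        self.client = client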

+ 1 - 3
dulwich/commit_graph.py

@@ -567,9 +567,7 @@ def write_commit_graph(
 
     graph_path = os.path.join(info_dir, b"commit-graph")
     with GitFile(graph_path, "wb") as f:
-        from typing import BinaryIO, cast
-
-        graph.write_to_file(cast(BinaryIO, f))
+        graph.write_to_file(f)
 
 
 def get_reachable_commits(
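
Dropping the local cast in write_commit_graph works presumably because write_to_file only needs a binary writer and the GitFile returned by the context manager already provides one. A tiny sketch of accepting any object with a write() method via a Protocol, so no cast is needed (toy names, not the real commit-graph writer):

import io
from typing import Protocol


class Writable(Protocol):
    def write(self, data: bytes) -> int: ...


def dump(f: Writable) -> None:
    # Only write() is required, so any binary writer is accepted without a cast.
    f.write(b"CGPH")  # made-up magic bytes, not the actual commit-graph header
    f.write(b"\x01")


buf = io.BytesIO()
dump(buf)
print(buf.getvalue())  # b'CGPH\x01'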

+ 39 - 12
dulwich/config.py

@@ -41,14 +41,11 @@ from contextlib import suppress
 from pathlib import Path
 from typing import (
     IO,
-    Any,
-    BinaryIO,
     Callable,
     Generic,
     Optional,
     TypeVar,
     Union,
-    cast,
     overload,
 )
 
@@ -60,7 +57,7 @@ ConfigValue = Union[str, bytes, bool, int]
 logger = logging.getLogger(__name__)
 
 # Type for file opener callback
-FileOpener = Callable[[Union[str, os.PathLike]], BinaryIO]
+FileOpener = Callable[[Union[str, os.PathLike]], IO[bytes]]
 
 # Type for includeIf condition matcher
 # Takes the condition value (e.g., "main" for onbranch:main) and returns bool
@@ -194,7 +191,7 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
           default_factory: Optional factory function for default values
         """
         self._real: list[tuple[K, V]] = []
-        self._keyed: dict[Any, V] = {}
+        self._keyed: dict[ConfigKey, V] = {}
         self._default_factory = default_factory
 
     @classmethod
@@ -239,7 +236,31 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
     def keys(self) -> KeysView[K]:
         """Return a view of the dictionary's keys."""
-        return self._keyed.keys()  # type: ignore[return-value]
+        # Return a view of the original keys (not lowercased)
+        # We need to deduplicate since _real can have duplicates
+        seen = set()
+        unique_keys = []
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                unique_keys.append(k)
+        from collections.abc import KeysView as ABCKeysView
+
+        class UniqueKeysView(ABCKeysView[K]):
+            def __init__(self, keys: list[K]):
+                self._keys = keys
+
+            def __contains__(self, key: object) -> bool:
+                return key in self._keys
+
+            def __iter__(self):
+                return iter(self._keys)
+
+            def __len__(self) -> int:
+                return len(self._keys)
+
+        return UniqueKeysView(unique_keys)
 
     def items(self) -> ItemsView[K, V]:
         """Return a view of the dictionary's (key, value) pairs in insertion order."""
@@ -267,7 +288,13 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
     def __iter__(self) -> Iterator[K]:
         """Iterate over the dictionary's keys."""
-        return iter(self._keyed)
+        # Return iterator over original keys (not lowercased), deduplicated
+        seen = set()
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                yield k
 
     def values(self) -> ValuesView[V]:
         """Return a view of the dictionary's values."""
@@ -898,7 +925,7 @@ class ConfigFile(ConfigDict):
     @classmethod
     def from_file(
         cls,
-        f: BinaryIO,
+        f: IO[bytes],
         *,
         config_dir: Optional[str] = None,
         included_paths: Optional[set[str]] = None,
@@ -1075,8 +1102,8 @@ class ConfigFile(ConfigDict):
             opener: FileOpener
             if file_opener is None:
 
-                def opener(path: Union[str, os.PathLike]) -> BinaryIO:
-                    return cast(BinaryIO, GitFile(path, "rb"))
+                def opener(path: Union[str, os.PathLike]) -> IO[bytes]:
+                    return GitFile(path, "rb")
             else:
                 opener = file_opener
 
@@ -1236,8 +1263,8 @@ class ConfigFile(ConfigDict):
         opener: FileOpener
         if file_opener is None:
 
-            def opener(p: Union[str, os.PathLike]) -> BinaryIO:
-                return cast(BinaryIO, GitFile(p, "rb"))
+            def opener(p: Union[str, os.PathLike]) -> IO[bytes]:
+                return GitFile(p, "rb")
         else:
             opener = file_opener
 

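The new keys()/__iter__ implementations deduplicate by lowercased key while keeping the first-seen original spelling. A minimal self-contained sketch of that pattern, not part of the commit (lower_key here is a simplified stand-in for the helper in dulwich.config; the tuple keys are illustrative):

def lower_key(key):
    # assumption: lowercase bytes/str, recurse into tuples, pass anything else through
    if isinstance(key, (bytes, str)):
        return key.lower()
    if isinstance(key, tuple):
        return tuple(lower_key(k) for k in key)
    return key

pairs = [((b"Core",), b"a"), ((b"CORE",), b"b"), ((b"remote", b"origin"), b"c")]
seen, unique_keys = set(), []
for k, _ in pairs:                 # mirrors the loop over self._real
    lowered = lower_key(k)
    if lowered not in seen:        # first occurrence wins; later case-variants are skipped
        seen.add(lowered)
        unique_keys.append(k)
assert unique_keys == [(b"Core",), (b"remote", b"origin")]
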
+ 14 - 8
dulwich/contrib/swift.py

@@ -43,6 +43,7 @@ from typing import BinaryIO, Callable, Optional, Union, cast
 
 from geventhttpclient import HTTPClient
 
+from ..file import _GitFile
 from ..greenthreads import GreenThreadsMissingObjectFinder
 from ..lru_cache import LRUSizeCache
 from ..object_store import INFODIR, PACKDIR, ObjectContainer, PackBasedObjectStore
@@ -823,7 +824,11 @@ class SwiftObjectStore(PackBasedObjectStore):
               The created SwiftPack or None if empty
             """
             f.seek(0)
-            pack = PackData(file=f, filename="")
+            from typing import cast
+
+            from ..file import _GitFile
+
+            pack = PackData(file=cast(_GitFile, f), filename="")
             entries = pack.sorted_entries()
             if entries:
                 basename = posixpath.join(
@@ -875,7 +880,8 @@ class SwiftObjectStore(PackBasedObjectStore):
         fd, path = tempfile.mkstemp(prefix="tmp_pack_")
         f = os.fdopen(fd, "w+b")
         try:
-            indexer = PackIndexer(f, resolve_ext_ref=None)
+            pack_data = PackData(file=cast(_GitFile, f), filename=path)
+            indexer = PackIndexer(cast(BinaryIO, pack_data._file), resolve_ext_ref=None)
             copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
@@ -890,7 +896,7 @@ class SwiftObjectStore(PackBasedObjectStore):
 
         # Update the header with the new number of objects.
         f.seek(0)
-        write_pack_header(f, len(entries) + len(indexer.ext_refs()))  # type: ignore
+        write_pack_header(f, len(entries) + len(indexer.ext_refs))  # type: ignore
 
         # Must flush before reading (http://bugs.python.org/issue3207)
         f.flush()
@@ -902,7 +908,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         f.seek(0, os.SEEK_CUR)
 
         # Complete the pack.
-        for ext_sha in indexer.ext_refs():  # type: ignore
+        for ext_sha in indexer.ext_refs:  # type: ignore
             assert len(ext_sha) == 20
             type_num, data = self.get_raw(ext_sha)
             offset = f.tell()
@@ -928,7 +934,7 @@ class SwiftObjectStore(PackBasedObjectStore):
 
         # Write pack info.
         f.seek(0)
-        pack_data = PackData(filename="", file=f)
+        pack_data = PackData(filename="", file=cast(_GitFile, f))
         index_file.seek(0)
         pack_index = load_pack_index_file("", index_file)
         serialized_pack_info = pack_info_create(pack_data, pack_index)
@@ -1030,17 +1036,17 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         del self._refs[name]
         return True
 
-    def allkeys(self) -> Iterator[bytes]:
+    def allkeys(self) -> set[bytes]:
         """Get all reference names.
 
         Returns:
-          Iterator of reference names as bytes
+          Set of reference names as bytes
         """
         try:
             self._refs[b"HEAD"] = self._refs[b"refs/heads/master"]
         except KeyError:
             pass
-        return iter(self._refs.keys())
+        return set(self._refs.keys())
 
 
 class SwiftRepo(BaseRepo):

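allkeys() now materialises a set rather than handing back a one-shot iterator, so the result survives repeated iteration and membership tests. A tiny sketch of the difference, with an illustrative refs dict (not part of the commit):

refs = {b"HEAD": b"1" * 40, b"refs/heads/master": b"1" * 40}

keys = set(refs.keys())            # mirrors the new `return set(self._refs.keys())`
assert b"HEAD" in keys             # membership test does not consume anything
assert sorted(keys) == [b"HEAD", b"refs/heads/master"]   # still iterable afterwards

it = iter(refs.keys())             # the old iterator-returning behaviour
assert list(it) == [b"HEAD", b"refs/heads/master"]
assert list(it) == []              # exhausted after a single pass
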
+ 12 - 6
dulwich/diff.py

@@ -47,11 +47,11 @@ Example usage:
 import logging
 import os
 import stat
-from typing import BinaryIO, Optional, cast
+from typing import BinaryIO, Optional
 
 from .index import ConflictedIndexEntry, commit_index
 from .object_store import iter_tree_contents
-from .objects import S_ISGITLINK, Blob
+from .objects import S_ISGITLINK, Blob, Commit
 from .patch import write_blob_diff, write_object_diff
 from .repo import Repo
 
@@ -90,12 +90,16 @@ def diff_index_to_tree(
     if commit_sha is None:
         try:
             commit_sha = repo.refs[b"HEAD"]
-            old_tree = repo[commit_sha].tree
+            old_commit = repo[commit_sha]
+            assert isinstance(old_commit, Commit)
+            old_tree = old_commit.tree
         except KeyError:
             # No HEAD means no commits yet
             old_tree = None
     else:
-        old_tree = repo[commit_sha].tree
+        old_commit = repo[commit_sha]
+        assert isinstance(old_commit, Commit)
+        old_tree = old_commit.tree
 
     # Get tree from index
     index = repo.open_index()
@@ -125,7 +129,9 @@ def diff_working_tree_to_tree(
         commit_sha: SHA of commit to compare against
         paths: Optional list of paths to filter (as bytes)
     """
-    tree = repo[commit_sha].tree
+    commit = repo[commit_sha]
+    assert isinstance(commit, Commit)
+    tree = commit.tree
     normalizer = repo.get_blob_normalizer()
     filter_callback = normalizer.checkin_normalize
 
@@ -382,7 +388,7 @@ def diff_working_tree_to_index(
         old_obj = repo.object_store[old_sha]
         # Type check and cast to Blob
         if isinstance(old_obj, Blob):
-            old_blob = cast(Blob, old_obj)
+            old_blob = old_obj
         else:
             old_blob = None
 

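Several call sites in this change swap cast() for a runtime assert isinstance(...), which narrows the type for mypy and fails fast if the SHA resolves to an unexpected object type. A minimal runnable sketch of the pattern, not part of the commit (the helper is illustrative; Blob and ShaFile are the real dulwich classes):

from dulwich.objects import Blob, ShaFile

def get_object() -> ShaFile:
    # stand-in for `repo[sha]` / `store[sha]`, which are typed as returning ShaFile
    return Blob.from_string(b"hello\n")

obj = get_object()
assert isinstance(obj, Blob)   # narrows ShaFile -> Blob; raises if the lookup returned something else
print(obj.data)                # Blob-only attribute now type-checks without cast()
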
+ 4 - 4
dulwich/dumb.py

@@ -24,7 +24,7 @@
 import os
 import tempfile
 import zlib
-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
 from io import BytesIO
 from typing import Any, Callable, Optional
 from urllib.parse import urljoin
@@ -338,9 +338,9 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
     def add_objects(
         self,
-        objects: Iterator[ShaFile],
-        progress: Optional[Callable[[int], None]] = None,
-    ) -> None:
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable[[str], None]] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store."""
         raise NotImplementedError("Cannot add objects to dumb HTTP repository")
 

+ 57 - 9
dulwich/file.py

@@ -24,10 +24,15 @@
 import os
 import sys
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from types import TracebackType
 from typing import IO, Any, ClassVar, Literal, Optional, Union, overload
 
+if sys.version_info >= (3, 12):
+    from collections.abc import Buffer
+else:
+    Buffer = Union[bytes, bytearray, memoryview]
+
 
 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
     """Ensure a directory exists, creating if necessary."""
@@ -136,7 +141,7 @@ class FileLocked(Exception):
         super().__init__(filename, lockfilename)
 
 
-class _GitFile:
+class _GitFile(IO[bytes]):
     """File that follows the git locking protocol for writes.
 
     All writes to a file foo will be written into foo.lock in the same
@@ -148,7 +153,6 @@ class _GitFile:
     """
 
     PROXY_PROPERTIES: ClassVar[set[str]] = {
-        "closed",
         "encoding",
         "errors",
         "mode",
@@ -158,15 +162,19 @@ class _GitFile:
     }
     PROXY_METHODS: ClassVar[set[str]] = {
         "__iter__",
+        "__next__",
         "flush",
         "fileno",
         "isatty",
         "read",
+        "readable",
         "readline",
         "readlines",
         "seek",
+        "seekable",
         "tell",
         "truncate",
+        "writable",
         "write",
         "writelines",
     }
@@ -195,9 +203,6 @@ class _GitFile:
         self._file = os.fdopen(fd, mode, bufsize)
         self._closed = False
 
-        for method in self.PROXY_METHODS:
-            setattr(self, method, getattr(self._file, method))
-
     def __iter__(self) -> Iterator[bytes]:
         """Iterate over lines in the file."""
         return iter(self._file)
@@ -267,20 +272,63 @@ class _GitFile:
         else:
             self.close()
 
+    @property
+    def closed(self) -> bool:
+        """Return whether the file is closed."""
+        return self._closed
+
     def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)
         raise AttributeError(name)
 
+    # Implement IO[bytes] methods by delegating to the underlying file
+    def read(self, size: int = -1) -> bytes:
+        return self._file.read(size)
+
+    # TODO: Remove type: ignore when Python 3.10 support is dropped (Oct 2026)
+    # Python 3.9/3.10 have issues with IO[bytes] overload signatures
+    def write(self, data: Buffer, /) -> int:  # type: ignore[override]
+        return self._file.write(data)
+
+    def readline(self, size: int = -1) -> bytes:
+        return self._file.readline(size)
+
+    def readlines(self, hint: int = -1) -> list[bytes]:
+        return self._file.readlines(hint)
+
+    # TODO: Remove type: ignore when Python 3.10 support is dropped (Oct 2026)
+    # Python 3.9/3.10 have issues with IO[bytes] overload signatures
+    def writelines(self, lines: Iterable[Buffer], /) -> None:  # type: ignore[override]
+        return self._file.writelines(lines)
+
+    def seek(self, offset: int, whence: int = 0) -> int:
+        return self._file.seek(offset, whence)
+
+    def tell(self) -> int:
+        return self._file.tell()
+
+    def flush(self) -> None:
+        return self._file.flush()
+
+    def truncate(self, size: Optional[int] = None) -> int:
+        return self._file.truncate(size)
+
+    def fileno(self) -> int:
+        return self._file.fileno()
+
+    def isatty(self) -> bool:
+        return self._file.isatty()
+
     def readable(self) -> bool:
-        """Return whether the file is readable."""
         return self._file.readable()
 
     def writable(self) -> bool:
-        """Return whether the file is writable."""
         return self._file.writable()
 
     def seekable(self) -> bool:
-        """Return whether the file is seekable."""
         return self._file.seekable()
+
+    def __next__(self) -> bytes:
+        return next(iter(self._file))

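With _GitFile subclassing IO[bytes] and spelling out the read/write surface, a GitFile handle can be passed to helpers annotated with IO[bytes] without cast(). A hedged usage sketch, not part of the commit (the path is hypothetical; "wb" follows the usual lock-and-rename protocol described in the class docstring):

from typing import IO
from dulwich.file import GitFile

def write_greeting(f: IO[bytes]) -> int:
    # any IO[bytes] writer will do; _GitFile now satisfies this statically
    return f.write(b"hello, git\n")

with GitFile("/tmp/gitfile-demo.txt", "wb") as f:   # hypothetical path
    write_greeting(f)                                # writes land in /tmp/gitfile-demo.txt.lock
# on clean exit the lock file is renamed over the target path
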
+ 15 - 7
dulwich/filters.py

@@ -30,7 +30,7 @@ from .objects import Blob
 
 if TYPE_CHECKING:
     from .config import StackedConfig
-    from .repo import Repo
+    from .repo import BaseRepo
 
 
 class FilterError(Exception):
@@ -128,7 +128,9 @@ class FilterRegistry:
     """Registry for filter drivers."""
 
     def __init__(
-        self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None
+        self,
+        config: Optional["StackedConfig"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
         """Initialize FilterRegistry.
 
@@ -211,8 +213,12 @@ class FilterRegistry:
         required = self.config.get_boolean(("filter", name), "required", False)
 
         if clean_cmd or smudge_cmd:
-            # Get repository working directory
-            repo_path = self.repo.path if self.repo else None
+            # Get repository working directory (only for Repo, not BaseRepo)
+            from .repo import Repo
+
+            repo_path = (
+                self.repo.path if self.repo and isinstance(self.repo, Repo) else None
+            )
             return ProcessFilterDriver(clean_cmd, smudge_cmd, required, repo_path)
 
         return None
@@ -221,8 +227,10 @@ class FilterRegistry:
         """Create LFS filter driver."""
         from .lfs import LFSFilterDriver, LFSStore
 
-        # If we have a repo, use its LFS store
-        if registry.repo is not None:
+        # If we have a Repo (not just BaseRepo), use its LFS store
+        from .repo import Repo
+
+        if registry.repo is not None and isinstance(registry.repo, Repo):
             lfs_store = LFSStore.from_repo(registry.repo, create=True)
         else:
             # Fall back to creating a temporary LFS store
@@ -389,7 +397,7 @@ class FilterBlobNormalizer:
         config_stack: Optional["StackedConfig"],
         gitattributes: GitAttributes,
         filter_registry: Optional[FilterRegistry] = None,
-        repo: Optional["Repo"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
         """Initialize FilterBlobNormalizer.
 

+ 24 - 22
dulwich/index.py

@@ -32,13 +32,13 @@ from collections.abc import Generator, Iterable, Iterator
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,
     Optional,
     Union,
-    cast,
 )
 
 if TYPE_CHECKING:
@@ -588,7 +588,7 @@ def read_cache_time(f: BinaryIO) -> tuple[int, int]:
     return struct.unpack(">LL", f.read(8))
 
 
-def write_cache_time(f: BinaryIO, t: Union[int, float, tuple[int, int]]) -> None:
+def write_cache_time(f: IO[bytes], t: Union[int, float, tuple[int, int]]) -> None:
     """Write a cache time.
 
     Args:
@@ -664,7 +664,7 @@ def read_cache_entry(
 
 
 def write_cache_entry(
-    f: BinaryIO, entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
+    f: IO[bytes], entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
 ) -> None:
     """Write an index entry to a file.
 
@@ -745,7 +745,7 @@ def read_index_header(f: BinaryIO) -> tuple[int, int]:
     return version, num_entries
 
 
-def write_index_extension(f: BinaryIO, extension: IndexExtension) -> None:
+def write_index_extension(f: IO[bytes], extension: IndexExtension) -> None:
     """Write an index extension.
 
     Args:
@@ -868,7 +868,7 @@ def read_index_dict(
 
 
 def write_index(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: list[SerializedIndexEntry],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -909,7 +909,7 @@ def write_index(
 
 
 def write_index_dict(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: dict[bytes, Union[IndexEntry, ConflictedIndexEntry]],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -1007,8 +1007,6 @@ class Index:
 
     def write(self) -> None:
         """Write current contents of index to disk."""
-        from typing import BinaryIO, cast
-
         f = GitFile(self._filename, "wb")
         try:
             # Filter out extensions with no meaningful data
@@ -1022,7 +1020,7 @@ class Index:
             if self._skip_hash:
                 # When skipHash is enabled, write the index without computing SHA1
                 write_index_dict(
-                    cast(BinaryIO, f),
+                    f,
                     self._byname,
                     version=self._version,
                     extensions=meaningful_extensions,
@@ -1031,9 +1029,9 @@ class Index:
                 f.write(b"\x00" * 20)
                 f.close()
             else:
-                sha1_writer = SHA1Writer(cast(BinaryIO, f))
+                sha1_writer = SHA1Writer(f)
                 write_index_dict(
-                    cast(BinaryIO, sha1_writer),
+                    sha1_writer,
                     self._byname,
                     version=self._version,
                     extensions=meaningful_extensions,
@@ -1050,9 +1048,7 @@ class Index:
         f = GitFile(self._filename, "rb")
         try:
             sha1_reader = SHA1Reader(f)
-            entries, version, extensions = read_index_dict_with_version(
-                cast(BinaryIO, sha1_reader)
-            )
+            entries, version, extensions = read_index_dict_with_version(sha1_reader)
             self._version = version
             self._extensions = extensions
             self.update(entries)
@@ -1411,7 +1407,9 @@ def build_file_from_blob(
     *,
     honor_filemode: bool = True,
     tree_encoding: str = "utf-8",
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
 ) -> os.stat_result:
     """Build a file or symlink on disk based on a Git object.
 
@@ -1596,7 +1594,9 @@ def build_index_from_tree(
     tree_id: bytes,
     honor_filemode: bool = True,
     validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -1956,7 +1956,9 @@ def _transition_to_file(
     entry: IndexEntry,
     index: Index,
     honor_filemode: bool,
-    symlink_fn: Optional[Callable[[bytes, bytes], None]],
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ],
     blob_normalizer: Optional["BlobNormalizer"],
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -2208,7 +2210,9 @@ def update_working_tree(
     change_iterator: Iterator["TreeChange"],
     honor_filemode: bool = True,
     validate_path_element: Optional[Callable[[bytes], bool]] = None,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     force_remove_untracked: bool = False,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
@@ -2685,10 +2689,8 @@ class locked_index:
             self._file.abort()
             return
         try:
-            from typing import BinaryIO, cast
-
-            f = SHA1Writer(cast(BinaryIO, self._file))
-            write_index_dict(cast(BinaryIO, f), self._index._byname)
+            f = SHA1Writer(self._file)
+            write_index_dict(f, self._index._byname)
         except BaseException:
             self._file.abort()
         else:

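The symlink_fn parameters throughout index.py now share one concrete shape: a callable taking two path-like arguments and returning None. A hedged sketch of a custom symlink_fn matching that signature, recording links instead of creating them (names and argument order are illustrative, not part of the commit):

import os
from typing import Union

PathLike = Union[str, bytes, os.PathLike]
recorded: list[tuple[PathLike, PathLike]] = []

def recording_symlink(source: PathLike, target: PathLike) -> None:
    # same Callable[[PathLike, PathLike], None] shape as the new annotations
    recorded.append((source, target))

recording_symlink(b"actual-target", "link-name")
assert recorded == [(b"actual-target", "link-name")]
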
+ 153 - 103
dulwich/object_store.py

@@ -33,11 +33,11 @@ from collections.abc import Iterable, Iterator, Sequence
 from contextlib import suppress
 from io import BytesIO
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Optional,
     Protocol,
     Union,
-    cast,
 )
 
 from .errors import NotTreeError
@@ -82,6 +82,23 @@ from .pack import (
 from .protocol import DEPTH_INFINITE
 from .refs import PEELED_TAG_SUFFIX, Ref
 
+if TYPE_CHECKING:
+    from .commit_graph import CommitGraph
+    from .diff_tree import RenameDetector
+
+
+class GraphWalker(Protocol):
+    """Protocol for graph walker objects."""
+
+    def __next__(self) -> Optional[bytes]:
+        """Return the next object SHA to visit."""
+        ...
+
+    def ack(self, sha: bytes) -> None:
+        """Acknowledge that an object has been received."""
+        ...
+
+
 INFODIR = "info"
 PACKDIR = "pack"
 
@@ -95,7 +112,9 @@ PACK_MODE = 0o444 if sys.platform != "win32" else 0o644
 DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks
 
 
-def find_shallow(store, heads, depth):
+def find_shallow(
+    store: ObjectContainer, heads: Iterable[bytes], depth: int
+) -> tuple[set[bytes], set[bytes]]:
     """Find shallow commits according to a given depth.
 
     Args:
@@ -107,10 +126,10 @@ def find_shallow(store, heads, depth):
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
     """
-    parents = {}
+    parents: dict[bytes, list[bytes]] = {}
     commit_graph = store.get_commit_graph()
 
-    def get_parents(sha):
+    def get_parents(sha: bytes) -> list[bytes]:
         result = parents.get(sha, None)
         if not result:
             # Try to use commit graph first if available
@@ -121,7 +140,9 @@ def find_shallow(store, heads, depth):
                     parents[sha] = result
                     return result
             # Fall back to loading the object
-            result = store[sha].parents
+            commit = store[sha]
+            assert isinstance(commit, Commit)
+            result = commit.parents
             parents[sha] = result
         return result
 
@@ -150,11 +171,11 @@ def find_shallow(store, heads, depth):
 
 
 def get_depth(
-    store,
-    head,
-    get_parents=lambda commit: commit.parents,
-    max_depth=None,
-):
+    store: ObjectContainer,
+    head: bytes,
+    get_parents: Callable = lambda commit: commit.parents,
+    max_depth: Optional[int] = None,
+) -> int:
     """Return the current available depth for the given head.
 
     For commits with multiple parents, the largest possible depth will be
@@ -206,17 +227,9 @@ class BaseObjectStore:
     def determine_wants_all(
         self, refs: dict[Ref, ObjectID], depth: Optional[int] = None
     ) -> list[ObjectID]:
-        """Determine all objects that are wanted by the client.
+        """Determine which objects are wanted based on refs."""
 
-        Args:
-          refs: Dictionary mapping ref names to object IDs
-          depth: Shallow fetch depth (None for full fetch)
-
-        Returns:
-          List of object IDs that are wanted
-        """
-
-        def _want_deepen(sha):
+        def _want_deepen(sha: bytes) -> bool:
             if not depth:
                 return False
             if depth == DEPTH_INFINITE:
@@ -231,7 +244,7 @@ class BaseObjectStore:
             and not sha == ZERO_SHA
         ]
 
-    def contains_loose(self, sha) -> bool:
+    def contains_loose(self, sha: bytes) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         raise NotImplementedError(self.contains_loose)
 
@@ -243,11 +256,11 @@ class BaseObjectStore:
         return self.contains_loose(sha1)
 
     @property
-    def packs(self):
+    def packs(self) -> list[Pack]:
         """Iterable of pack objects."""
         raise NotImplementedError
 
-    def get_raw(self, name) -> tuple[int, bytes]:
+    def get_raw(self, name: bytes) -> tuple[int, bytes]:
         """Obtain the raw text for an object.
 
         Args:
@@ -261,15 +274,19 @@ class BaseObjectStore:
         type_num, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)
 
-    def add_object(self, obj) -> None:
+    def add_object(self, obj: ShaFile) -> None:
         """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)
 
-    def add_objects(self, objects, progress=None) -> None:
+    def add_objects(
+        self,
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
 
         Args:
@@ -280,14 +297,20 @@ class BaseObjectStore:
 
     def tree_changes(
         self,
-        source,
-        target,
-        want_unchanged=False,
-        include_trees=False,
-        change_type_same=False,
-        rename_detector=None,
-        paths=None,
-    ):
+        source: Optional[bytes],
+        target: Optional[bytes],
+        want_unchanged: bool = False,
+        include_trees: bool = False,
+        change_type_same: bool = False,
+        rename_detector: Optional["RenameDetector"] = None,
+        paths: Optional[list[bytes]] = None,
+    ) -> Iterator[
+        tuple[
+            tuple[Optional[bytes], Optional[bytes]],
+            tuple[Optional[int], Optional[int]],
+            tuple[Optional[bytes], Optional[bytes]],
+        ]
+    ]:
         """Find the differences between the contents of two trees.
 
         Args:
@@ -320,7 +343,9 @@ class BaseObjectStore:
                 (change.old.sha, change.new.sha),
             )
 
-    def iter_tree_contents(self, tree_id, include_trees=False):
+    def iter_tree_contents(
+        self, tree_id: bytes, include_trees: bool = False
+    ) -> Iterator[tuple[bytes, int, bytes]]:
         """Iterate the contents of a tree and all subtrees.
 
         Iteration is depth-first pre-order, as in e.g. os.walk.
@@ -362,13 +387,13 @@ class BaseObjectStore:
 
     def find_missing_objects(
         self,
-        haves,
-        wants,
-        shallow=None,
-        progress=None,
-        get_tagged=None,
-        get_parents=lambda commit: commit.parents,
-    ):
+        haves: Iterable[bytes],
+        wants: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
+        progress: Optional[Callable] = None,
+        get_tagged: Optional[Callable] = None,
+        get_parents: Callable = lambda commit: commit.parents,
+    ) -> Iterator[tuple[bytes, Optional[bytes]]]:
         """Find the missing objects required for a set of revisions.
 
         Args:
@@ -395,7 +420,7 @@ class BaseObjectStore:
         )
         return iter(finder)
 
-    def find_common_revisions(self, graphwalker):
+    def find_common_revisions(self, graphwalker: GraphWalker) -> list[bytes]:
         """Find which revisions this store has in common using graphwalker.
 
         Args:
@@ -412,7 +437,12 @@ class BaseObjectStore:
         return haves
 
     def generate_pack_data(
-        self, have, want, shallow=None, progress=None, ofs_delta=True
+        self,
+        have: Iterable[bytes],
+        want: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
+        progress: Optional[Callable] = None,
+        ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
         """Generate pack data objects for a set of wants/haves.
 
@@ -435,7 +465,7 @@ class BaseObjectStore:
             progress=progress,
         )
 
-    def peel_sha(self, sha):
+    def peel_sha(self, sha: bytes) -> bytes:
         """Peel all tags from a SHA.
 
         Args:
@@ -449,14 +479,14 @@ class BaseObjectStore:
             DeprecationWarning,
             stacklevel=2,
         )
-        return peel_sha(self, sha)[1]
+        return peel_sha(self, sha)[1].id
 
     def _get_depth(
         self,
-        head,
-        get_parents=lambda commit: commit.parents,
-        max_depth=None,
-    ):
+        head: bytes,
+        get_parents: Callable = lambda commit: commit.parents,
+        max_depth: Optional[int] = None,
+    ) -> int:
         """Return the current available depth for the given head.
 
         For commits with multiple parents, the largest possible depth will be
@@ -496,7 +526,7 @@ class BaseObjectStore:
             if sha.startswith(prefix):
                 yield sha
 
-    def get_commit_graph(self):
+    def get_commit_graph(self) -> Optional["CommitGraph"]:
         """Get the commit graph for this object store.
 
         Returns:
@@ -504,7 +534,9 @@ class BaseObjectStore:
         """
         return None
 
-    def write_commit_graph(self, refs=None, reachable=True) -> None:
+    def write_commit_graph(
+        self, refs: Optional[list[bytes]] = None, reachable: bool = True
+    ) -> None:
         """Write a commit graph file for this object store.
 
         Args:
@@ -518,7 +550,7 @@ class BaseObjectStore:
         """
         raise NotImplementedError(self.write_commit_graph)
 
-    def get_object_mtime(self, sha):
+    def get_object_mtime(self, sha: bytes) -> float:
         """Get the modification time of an object.
 
         Args:
@@ -545,14 +577,14 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
 
     def __init__(
         self,
-        pack_compression_level=-1,
-        pack_index_version=None,
-        pack_delta_window_size=None,
-        pack_window_memory=None,
-        pack_delta_cache_size=None,
-        pack_depth=None,
-        pack_threads=None,
-        pack_big_file_threshold=None,
+        pack_compression_level: int = -1,
+        pack_index_version: Optional[int] = None,
+        pack_delta_window_size: Optional[int] = None,
+        pack_window_memory: Optional[int] = None,
+        pack_delta_cache_size: Optional[int] = None,
+        pack_depth: Optional[int] = None,
+        pack_threads: Optional[int] = None,
+        pack_big_file_threshold: Optional[int] = None,
     ) -> None:
         """Initialize a PackBasedObjectStore.
 
@@ -581,8 +613,11 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         raise NotImplementedError(self.add_pack)
 
     def add_pack_data(
-        self, count: int, unpacked_objects: Iterator[UnpackedObject], progress=None
-    ) -> None:
+        self,
+        count: int,
+        unpacked_objects: Iterator[UnpackedObject],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add pack data to this object store.
 
         Args:
@@ -592,7 +627,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         if count == 0:
             # Don't bother writing an empty pack file
-            return
+            return None
         f, commit, abort = self.add_pack()
         try:
             write_pack_data(
@@ -609,15 +644,11 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             return commit()
 
     @property
-    def alternates(self):
-        """Get the list of alternate object stores.
-
-        Returns:
-          List of alternate BaseObjectStore instances
-        """
+    def alternates(self) -> list:
+        """Return list of alternate object stores."""
         return []
 
-    def contains_packed(self, sha) -> bool:
+    def contains_packed(self, sha: bytes) -> bool:
         """Check if a particular object is present by SHA1 and is packed.
 
         This does not check alternates.
@@ -642,7 +673,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 return True
         return False
 
-    def _add_cached_pack(self, base_name, pack) -> None:
+    def _add_cached_pack(self, base_name: str, pack: Pack) -> None:
         """Add a newly appeared pack to the cache by path."""
         prev_pack = self._pack_cache.get(base_name)
         if prev_pack is not pack:
@@ -668,7 +699,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         remote_has = missing_objects.get_remote_has()
         object_ids = list(missing_objects)
         return len(object_ids), generate_unpacked_objects(
-            cast(PackedObjectContainer, self),
+            self,
             object_ids,
             progress=progress,
             ofs_delta=ofs_delta,
@@ -682,8 +713,8 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             (name, pack) = pack_cache.popitem()
             pack.close()
 
-    def _iter_cached_packs(self):
-        return self._pack_cache.values()
+    def _iter_cached_packs(self) -> Iterator[Pack]:
+        return iter(self._pack_cache.values())
 
     def _update_pack_cache(self) -> list[Pack]:
         raise NotImplementedError(self._update_pack_cache)
@@ -696,7 +727,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self._clear_cached_packs()
 
     @property
-    def packs(self):
+    def packs(self) -> list[Pack]:
         """List with pack objects."""
         return list(self._iter_cached_packs()) + list(self._update_pack_cache())
 
@@ -714,19 +745,19 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 count += 1
         return count
 
-    def _iter_alternate_objects(self):
+    def _iter_alternate_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all the objects in alternate stores."""
         for alternate in self.alternates:
             yield from alternate
 
-    def _iter_loose_objects(self):
+    def _iter_loose_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all loose objects."""
         raise NotImplementedError(self._iter_loose_objects)
 
-    def _get_loose_object(self, sha) -> Optional[ShaFile]:
+    def _get_loose_object(self, sha: bytes) -> Optional[ShaFile]:
         raise NotImplementedError(self._get_loose_object)
 
-    def delete_loose_object(self, sha) -> None:
+    def delete_loose_object(self, sha: bytes) -> None:
         """Delete a loose object.
 
         This method only handles loose objects. For packed objects,
@@ -734,23 +765,25 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         raise NotImplementedError(self.delete_loose_object)
 
-    def _remove_pack(self, name) -> None:
+    def _remove_pack(self, pack: "Pack") -> None:
         raise NotImplementedError(self._remove_pack)
 
-    def pack_loose_objects(self):
+    def pack_loose_objects(self) -> int:
         """Pack loose objects.
 
         Returns: Number of objects packed
         """
-        objects = set()
+        objects: list[tuple[ShaFile, None]] = []
         for sha in self._iter_loose_objects():
-            objects.add((self._get_loose_object(sha), None))
-        self.add_objects(list(objects))
+            obj = self._get_loose_object(sha)
+            if obj is not None:
+                objects.append((obj, None))
+        self.add_objects(objects)
         for obj, path in objects:
             self.delete_loose_object(obj.id)
         return len(objects)
 
-    def repack(self, exclude=None):
+    def repack(self, exclude: Optional[set] = None) -> int:
         """Repack the packs in this repository.
 
         Note that this implementation is fairly naive and currently keeps all
@@ -766,11 +799,13 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         excluded_loose_objects = set()
         for sha in self._iter_loose_objects():
             if sha not in exclude:
-                loose_objects.add(self._get_loose_object(sha))
+                obj = self._get_loose_object(sha)
+                if obj is not None:
+                    loose_objects.add(obj)
             else:
                 excluded_loose_objects.add(sha)
 
-        objects = {(obj, None) for obj in loose_objects}
+        objects: set[tuple[ShaFile, None]] = {(obj, None) for obj in loose_objects}
         old_packs = {p.name(): p for p in self.packs}
         for name, pack in old_packs.items():
             objects.update(
@@ -782,12 +817,14 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             # The name of the consolidated pack might match the name of a
             # pre-existing pack. Take care not to remove the newly created
             # consolidated pack.
-            consolidated = self.add_objects(objects)
-            old_packs.pop(consolidated.name(), None)
+            consolidated = self.add_objects(list(objects))
+            if consolidated is not None:
+                old_packs.pop(consolidated.name(), None)
 
         # Delete loose objects that were packed
         for obj in loose_objects:
-            self.delete_loose_object(obj.id)
+            if obj is not None:
+                self.delete_loose_object(obj.id)
         # Delete excluded loose objects
         for sha in excluded_loose_objects:
             self.delete_loose_object(sha)
@@ -943,9 +980,9 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 yield o
                 todo.remove(o.id)
         for oid in todo:
-            o = self._get_loose_object(oid)
-            if o is not None:
-                yield o
+            loose_obj: Optional[ShaFile] = self._get_loose_object(oid)
+            if loose_obj is not None:
+                yield loose_obj
             elif not allow_missing:
                 raise KeyError(oid)
 
@@ -993,7 +1030,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self,
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         progress: Optional[Callable[[str], None]] = None,
-    ) -> None:
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
 
         Args:
@@ -1012,6 +1049,8 @@ class DiskObjectStore(PackBasedObjectStore):
 
     path: Union[str, os.PathLike]
     pack_dir: Union[str, os.PathLike]
+    _alternates: Optional[list["DiskObjectStore"]]
+    _commit_graph: Optional["CommitGraph"]
 
     def __init__(
         self,
@@ -1244,7 +1283,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
     def _get_shafile_path(self, sha):
         # Check from object dir
-        return hex_to_filename(self.path, sha)
+        return hex_to_filename(os.fspath(self.path), sha)
 
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
@@ -1345,9 +1384,9 @@ class DiskObjectStore(PackBasedObjectStore):
         os.remove(pack.index.path)
 
     def _get_pack_basepath(self, entries):
-        suffix = iter_sha1(entry[0] for entry in entries)
+        suffix_bytes = iter_sha1(entry[0] for entry in entries)
         # TODO: Handle self.pack_dir being bytes
-        suffix = suffix.decode("ascii")
+        suffix = suffix_bytes.decode("ascii")
         return os.path.join(self.pack_dir, "pack-" + suffix)
 
     def _complete_pack(self, f, path, num_objects, indexer, progress=None):
@@ -1371,7 +1410,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
         pack_sha, extra_entries = extend_pack(
             f,
-            indexer.ext_refs(),
+            indexer.ext_refs,
             get_raw=self.get_raw,
             compression_level=self.pack_compression_level,
             progress=progress,
@@ -1466,6 +1505,7 @@ class DiskObjectStore(PackBasedObjectStore):
         def commit():
             if f.tell() > 0:
                 f.seek(0)
+
                 with PackData(path, f) as pd:
                     indexer = PackIndexer.for_pack_data(
                         pd, resolve_ext_ref=self.get_raw
@@ -1798,6 +1838,7 @@ class MemoryObjectStore(BaseObjectStore):
             size = f.tell()
             if size > 0:
                 f.seek(0)
+
                 p = PackData.from_file(f, size)
                 for obj in PackInflater.for_pack_data(p, self.get_raw):
                     self.add_object(obj)
@@ -2134,7 +2175,7 @@ class ObjectStoreGraphWalker:
     heads: set[ObjectID]
     """Revisions without descendants in the local repo."""
 
-    get_parents: Callable[[ObjectID], ObjectID]
+    get_parents: Callable[[ObjectID], list[ObjectID]]
     """Function to retrieve parents in the local repo."""
 
     shallow: set[ObjectID]
@@ -2230,7 +2271,7 @@ def commit_tree_changes(object_store, tree, changes):
     """
     # TODO(jelmer): Save up the objects and add them using .add_objects
     # rather than with individual calls to .add_object.
-    nested_changes = {}
+    nested_changes: dict[bytes, list[tuple[bytes, Optional[int], Optional[bytes]]]] = {}
     for path, new_mode, new_sha in changes:
         try:
             (dirname, subpath) = path.split(b"/", 1)
@@ -2465,8 +2506,16 @@ class BucketBasedObjectStore(PackBasedObjectStore):
         """
         # Doesn't exist..
 
-    def _remove_pack(self, name) -> None:
-        raise NotImplementedError(self._remove_pack)
+    def pack_loose_objects(self) -> int:
+        """Pack loose objects. Returns number of objects packed.
+
+        BucketBasedObjectStore doesn't support loose objects, so this is a no-op.
+        """
+        return 0
+
+    def _remove_pack_by_name(self, name: str) -> None:
+        """Remove a pack by name. Subclasses should implement this."""
+        raise NotImplementedError(self._remove_pack_by_name)
 
     def _iter_pack_names(self) -> Iterator[str]:
         raise NotImplementedError(self._iter_pack_names)
@@ -2511,6 +2560,7 @@ class BucketBasedObjectStore(PackBasedObjectStore):
                 return None
 
             pf.seek(0)
+
             p = PackData(pf.name, pf)
             entries = p.sorted_entries()
             basename = iter_sha1(entry[0] for entry in entries).decode("ascii")

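find_common_revisions() is now typed against the small GraphWalker protocol instead of an untyped parameter. Because it is a Protocol, any object with a matching __next__/ack shape satisfies it structurally; a minimal hedged sketch, not part of the commit:

from typing import Iterator, Optional

class ListGraphWalker:
    """Walks a fixed list of SHAs; shape-compatible with the GraphWalker protocol."""

    def __init__(self, shas: list[bytes]) -> None:
        self._it: Iterator[bytes] = iter(shas)
        self.acked: list[bytes] = []

    def __next__(self) -> Optional[bytes]:
        return next(self._it, None)      # None signals "no more revisions to offer"

    def ack(self, sha: bytes) -> None:
        self.acked.append(sha)           # the store calls this for SHAs it already has

walker = ListGraphWalker([b"a" * 40, b"b" * 40])
assert next(walker) == b"a" * 40
walker.ack(b"a" * 40)
assert walker.acked == [b"a" * 40]
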
+ 2 - 1
dulwich/objects.py

@@ -430,7 +430,8 @@ class ShaFile:
             self._sha = None
             self._chunked_text = self._serialize()
             self._needs_serialization = False
-        return self._chunked_text  # type: ignore
+        assert self._chunked_text is not None
+        return self._chunked_text
 
     def as_raw_string(self) -> bytes:
         """Return raw string with serialization of the object.

+ 30 - 15
dulwich/objectspec.py

@@ -24,6 +24,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 
 from .objects import Commit, ShaFile, Tag, Tree
+from .repo import BaseRepo
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
@@ -40,9 +41,9 @@ def to_bytes(text: Union[str, bytes]) -> bytes:
     Returns:
       Bytes representation of text
     """
-    if getattr(text, "encode", None) is not None:
-        text = text.encode("ascii")  # type: ignore
-    return text  # type: ignore
+    if isinstance(text, str):
+        return text.encode("ascii")
+    return text
 
 
 def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
@@ -136,7 +137,9 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
                         raise ValueError(
                             f"Commit {commit.id.decode('ascii', 'replace')} has no parents"
                         )
-                    commit = repo[commit.parents[0]]
+                    parent_obj = repo[commit.parents[0]]
+                    assert isinstance(parent_obj, Commit)
+                    commit = parent_obj
                 obj = commit
             else:  # sep == b"^"
                 # Get N-th parent (or commit itself if N=0)
@@ -157,11 +160,13 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
     return _resolve_object(repo, objectish)
 
 
-def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "Tree":
+def parse_tree(
+    repo: "BaseRepo", treeish: Union[bytes, str, Tree, Commit, Tag]
+) -> "Tree":
     """Parse a string referring to a tree.
 
     Args:
-      repo: A `Repo` object
+      repo: A repository object
       treeish: A string referring to a tree, or a Tree, Commit, or Tag object
     Returns: A Tree object
     Raises:
@@ -173,7 +178,9 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
 
     # If it's a Commit, return its tree
     if isinstance(treeish, Commit):
-        return repo[treeish.tree]
+        tree = repo[treeish.tree]
+        assert isinstance(tree, Tree)
+        return tree
 
     # For Tag objects or strings, use the existing logic
     if isinstance(treeish, Tag):
@@ -181,7 +188,7 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
     else:
         treeish = to_bytes(treeish)
     try:
-        treeish = parse_ref(repo, treeish)
+        treeish = parse_ref(repo.refs, treeish)
     except KeyError:  # treeish is commit sha
         pass
     try:
@@ -190,15 +197,21 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
         # Try parsing as commit (handles short hashes)
         try:
             commit = parse_commit(repo, treeish)
-            return repo[commit.tree]
+            assert isinstance(commit, Commit)
+            tree = repo[commit.tree]
+            assert isinstance(tree, Tree)
+            return tree
         except KeyError:
             raise KeyError(treeish)
-    if o.type_name == b"commit":
-        return repo[o.tree]
-    elif o.type_name == b"tag":
+    if isinstance(o, Commit):
+        tree = repo[o.tree]
+        assert isinstance(tree, Tree)
+        return tree
+    elif isinstance(o, Tag):
         # Tag handling - dereference and recurse
         obj_type, obj_sha = o.object
         return parse_tree(repo, obj_sha)
+    assert isinstance(o, Tree)
     return o
 
 
@@ -383,11 +396,13 @@ def scan_for_short_id(
     raise AmbiguousShortId(prefix, ret)
 
 
-def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "Commit":
+def parse_commit(
+    repo: "BaseRepo", committish: Union[str, bytes, Commit, Tag]
+) -> "Commit":
     """Parse a string referring to a single commit.
 
     Args:
-      repo: A` Repo` object
+      repo: A repository object
       committish: A string referring to a single commit, or a Commit or Tag object.
     Returns: A Commit object
     Raises:
@@ -426,7 +441,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "C
     else:
         return dereference_tag(obj)
     try:
-        obj = repo[parse_ref(repo, committish)]
+        obj = repo[parse_ref(repo.refs, committish)]
     except KeyError:
         pass
     else:

+ 243 - 169
dulwich/pack.py
File diff suppressed because it is too large


+ 8 - 9
dulwich/patch.py

@@ -30,6 +30,7 @@ import time
 from collections.abc import Generator
 from difflib import SequenceMatcher
 from typing import (
+    IO,
     TYPE_CHECKING,
     BinaryIO,
     Optional,
@@ -48,7 +49,7 @@ FIRST_FEW_BYTES = 8000
 
 
 def write_commit_patch(
-    f: BinaryIO,
+    f: IO[bytes],
     commit: "Commit",
     contents: Union[str, bytes],
     progress: tuple[int, int],
@@ -231,7 +232,7 @@ def patch_filename(p: Optional[bytes], root: bytes) -> bytes:
 
 
 def write_object_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     old_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
     new_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
@@ -264,19 +265,17 @@ def write_object_diff(
         Returns:
             Blob object
         """
-        from typing import cast
-
         if hexsha is None:
-            return cast(Blob, Blob.from_string(b""))
+            return Blob.from_string(b"")
         elif mode is not None and S_ISGITLINK(mode):
-            return cast(Blob, Blob.from_string(b"Subproject commit " + hexsha + b"\n"))
+            return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
         else:
             obj = store[hexsha]
             if isinstance(obj, Blob):
                 return obj
             else:
                 # Fallback for non-blob objects
-                return cast(Blob, Blob.from_string(obj.as_raw_string()))
+                return Blob.from_string(obj.as_raw_string())
 
     def lines(content: "Blob") -> list[bytes]:
         """Split blob content into lines.
@@ -356,7 +355,7 @@ def gen_diff_header(
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
 def write_blob_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     old_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
     new_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
 ) -> None:
@@ -403,7 +402,7 @@ def write_blob_diff(
 
 
 def write_tree_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     old_tree: Optional[bytes],
     new_tree: Optional[bytes],

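The patch writers now accept IO[bytes], so an in-memory buffer works directly. A small runnable sketch using write_blob_diff, not part of the commit (blob contents and file names are illustrative):

from io import BytesIO
from dulwich.objects import Blob
from dulwich.patch import write_blob_diff

old = Blob.from_string(b"hello\n")
new = Blob.from_string(b"hello world\n")

buf = BytesIO()                    # satisfies IO[bytes] without any cast
write_blob_diff(
    buf,
    (b"hello.txt", 0o100644, old),  # (path, mode, blob) tuples, per the signature
    (b"hello.txt", 0o100644, new),
)
print(buf.getvalue().decode("utf-8", "replace"))
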
+ 313 - 227
dulwich/porcelain.py
File diff suppressed because it is too large


+ 0 - 12
dulwich/protocol.py

@@ -258,18 +258,6 @@ def pkt_seq(*seq: Optional[bytes]) -> bytes:
     return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)
 
 
-def filter_ref_prefix(
-    refs: dict[bytes, bytes], prefixes: Iterable[bytes]
-) -> dict[bytes, bytes]:
-    """Filter refs to only include those with a given prefix.
-
-    Args:
-      refs: A list of refs.
-      prefixes: The prefixes to filter by.
-    """
-    return {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
-
-
 class Protocol:
     """Class for interacting with a remote git process over the wire.
 

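filter_ref_prefix() is removed outright. For out-of-tree callers, the dict comprehension from the deleted body is a drop-in replacement; a hedged migration sketch with illustrative refs and prefixes (not part of the commit):

refs = {
    b"refs/heads/main": b"1" * 40,
    b"refs/tags/v1.0": b"2" * 40,
}
prefixes = [b"refs/heads/"]

# same expression as the removed helper's body
filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
assert filtered == {b"refs/heads/main": b"1" * 40}
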
+ 17 - 3
dulwich/rebase.py

@@ -32,7 +32,7 @@ from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
 from dulwich.objects import Commit
 from dulwich.objectspec import parse_commit
-from dulwich.repo import Repo
+from dulwich.repo import BaseRepo, Repo
 
 
 class RebaseError(Exception):
@@ -529,7 +529,7 @@ class DiskRebaseStateManager:
 class MemoryRebaseStateManager:
     """Manages rebase state in memory for MemoryRepo."""
 
-    def __init__(self, repo: Repo) -> None:
+    def __init__(self, repo: BaseRepo) -> None:
         """Initialize MemoryRebaseStateManager.
 
         Args:
@@ -642,6 +642,8 @@ class Rebaser:
         if branch is None:
             # Use current HEAD
             head_ref, head_sha = self.repo.refs.follow(b"HEAD")
+            if head_sha is None:
+                raise ValueError("HEAD does not point to a valid commit")
             branch_commit = self.repo[head_sha]
         else:
             # Parse the branch reference
@@ -664,6 +666,7 @@ class Rebaser:
         commits = []
         current = branch_commit
         while current.id != merge_base:
+            assert isinstance(current, Commit)
             commits.append(current)
             if not current.parents:
                 break
@@ -691,6 +694,9 @@ class Rebaser:
         parent = self.repo[commit.parents[0]]
         onto_commit = self.repo[onto]
 
+        assert isinstance(parent, Commit)
+        assert isinstance(onto_commit, Commit)
+
         # Perform three-way merge
         merged_tree, conflicts = three_way_merge(
             self.object_store, parent, onto_commit, commit
@@ -798,7 +804,9 @@ class Rebaser:
 
         if new_sha:
             # Success - add to done list
-            self._done.append(self.repo[new_sha])
+            new_commit = self.repo[new_sha]
+            assert isinstance(new_commit, Commit)
+            self._done.append(new_commit)
             self._save_rebase_state()
 
             # Continue with next commit if any
@@ -822,6 +830,8 @@ class Rebaser:
             raise RebaseError("No rebase in progress")
 
         # Restore original HEAD
+        if self._original_head is None:
+            raise RebaseError("No original HEAD to restore")
         self.repo.refs[b"HEAD"] = self._original_head
 
         # Clean up rebase state
@@ -1200,12 +1210,16 @@ def _squash_commits(
     if not entry.commit_sha:
         raise RebaseError("No commit SHA for squash/fixup operation")
     commit_to_squash = repo[entry.commit_sha]
+    if not isinstance(commit_to_squash, Commit):
+        raise RebaseError(f"Expected commit, got {type(commit_to_squash).__name__}")
 
     # Get the previous commit (target of squash)
     previous_commit = rebaser._done[-1]
 
     # Cherry-pick the changes onto the previous commit
     parent = repo[commit_to_squash.parents[0]]
+    if not isinstance(parent, Commit):
+        raise RebaseError(f"Expected parent commit, got {type(parent).__name__}")
 
     # Perform three-way merge for the tree
     merged_tree, conflicts = three_way_merge(

+ 259 - 242
dulwich/refs.py

@@ -25,9 +25,19 @@
 import os
 import types
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    BinaryIO,
+    Callable,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 if TYPE_CHECKING:
     from .file import _GitFile
@@ -55,13 +65,8 @@ ANNOTATED_TAG_SUFFIX = PEELED_TAG_SUFFIX
 class SymrefLoop(Exception):
     """There is a loop between one or more symrefs."""
 
-    def __init__(self, ref, depth) -> None:
-        """Initialize a SymrefLoop exception.
-
-        Args:
-          ref: The ref that caused the loop
-          depth: Depth at which the loop was detected
-        """
+    def __init__(self, ref: bytes, depth: int) -> None:
+        """Initialize SymrefLoop exception."""
         self.ref = ref
         self.depth = depth
 
@@ -142,23 +147,35 @@ def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
 class RefsContainer:
     """A container for refs."""
 
-    def __init__(self, logger=None) -> None:
-        """Initialize a RefsContainer.
-
-        Args:
-          logger: Optional logger for reflog updates
-        """
+    def __init__(
+        self,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
+        """Initialize RefsContainer with optional logger function."""
         self._logger = logger
 
     def _log(
         self,
-        ref,
-        old_sha,
-        new_sha,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        ref: bytes,
+        old_sha: Optional[bytes],
+        new_sha: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         if self._logger is None:
             return
@@ -168,12 +185,12 @@ class RefsContainer:
 
     def set_symbolic_ref(
         self,
-        name,
-        other,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        other: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         """Make a ref point at another ref.
 
@@ -206,7 +223,7 @@ class RefsContainer:
         """
         raise NotImplementedError(self.add_packed_refs)
 
-    def get_peeled(self, name) -> Optional[ObjectID]:
+    def get_peeled(self, name: bytes) -> Optional[ObjectID]:
         """Return the cached peeled value of a ref, if available.
 
         Args:
@@ -257,12 +274,12 @@ class RefsContainer:
         for ref in to_delete:
             self.remove_if_equals(b"/".join((base, ref)), None, message=message)
 
-    def allkeys(self) -> Iterator[Ref]:
+    def allkeys(self) -> set[Ref]:
         """All refs present in this container."""
         raise NotImplementedError(self.allkeys)
 
-    def __iter__(self):
-        """Iterate over all ref names."""
+    def __iter__(self) -> Iterator[Ref]:
+        """Iterate over all reference keys."""
         return iter(self.allkeys())
 
     def keys(self, base=None):
@@ -278,7 +295,7 @@ class RefsContainer:
         else:
             return self.allkeys()
 
-    def subkeys(self, base):
+    def subkeys(self, base: bytes) -> set[bytes]:
         """Refs present in this container under a base.
 
         Args:
@@ -293,7 +310,7 @@ class RefsContainer:
                 keys.add(refname[base_len:])
         return keys
 
-    def as_dict(self, base=None) -> dict[Ref, ObjectID]:
+    def as_dict(self, base: Optional[bytes] = None) -> dict[Ref, ObjectID]:
         """Return the contents of this container as a dictionary."""
         ret = {}
         keys = self.keys(base)
@@ -309,7 +326,7 @@ class RefsContainer:
 
         return ret
 
-    def _check_refname(self, name) -> None:
+    def _check_refname(self, name: bytes) -> None:
         """Ensure a refname is valid and lives in refs or is HEAD.
 
         HEAD is not a valid refname according to git-check-ref-format, but this
@@ -328,7 +345,7 @@ class RefsContainer:
         if not name.startswith(b"refs/") or not check_ref_format(name[5:]):
             raise RefFormatError(name)
 
-    def read_ref(self, refname):
+    def read_ref(self, refname: bytes) -> Optional[bytes]:
         """Read a reference without following any references.
 
         Args:
@@ -341,7 +358,7 @@ class RefsContainer:
             contents = self.get_packed_refs().get(refname, None)
         return contents
 
-    def read_loose_ref(self, name) -> bytes:
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         """Read a loose reference and return its contents.
 
         Args:
@@ -351,16 +368,16 @@ class RefsContainer:
         """
         raise NotImplementedError(self.read_loose_ref)
 
-    def follow(self, name) -> tuple[list[bytes], bytes]:
+    def follow(self, name: bytes) -> tuple[list[bytes], Optional[bytes]]:
         """Follow a reference name.
 
         Returns: a tuple of (refnames, sha), wheres refnames are the names of
             references in the chain
         """
-        contents = SYMREF + name
+        contents: Optional[bytes] = SYMREF + name
         depth = 0
         refnames = []
-        while contents.startswith(SYMREF):
+        while contents and contents.startswith(SYMREF):
             refname = contents[len(SYMREF) :]
             refnames.append(refname)
             contents = self.read_ref(refname)
@@ -371,20 +388,13 @@ class RefsContainer:
                 raise SymrefLoop(name, depth)
         return refnames, contents
 
-    def __contains__(self, refname) -> bool:
-        """Check if a ref exists.
-
-        Args:
-          refname: Name of the ref to check
-
-        Returns:
-          True if the ref exists
-        """
+    def __contains__(self, refname: bytes) -> bool:
+        """Check if a reference exists."""
         if self.read_ref(refname):
             return True
         return False
 
-    def __getitem__(self, name) -> ObjectID:
+    def __getitem__(self, name: bytes) -> ObjectID:
         """Get the SHA1 for a reference name.
 
         This method follows all symbolic references.
@@ -396,13 +406,13 @@ class RefsContainer:
 
     def set_if_equals(
         self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref.
 
@@ -424,7 +434,13 @@ class RefsContainer:
         raise NotImplementedError(self.set_if_equals)
 
     def add_if_new(
-        self, name, ref, committer=None, timestamp=None, timezone=None, message=None
+        self,
+        name: bytes,
+        ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Add a new reference only if it does not already exist.
 
@@ -438,7 +454,7 @@ class RefsContainer:
         """
         raise NotImplementedError(self.add_if_new)
 
-    def __setitem__(self, name, ref) -> None:
+    def __setitem__(self, name: bytes, ref: bytes) -> None:
         """Set a reference name to point to the given SHA1.
 
         This method follows all symbolic references if applicable for the
@@ -458,12 +474,12 @@ class RefsContainer:
 
     def remove_if_equals(
         self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Remove a refname only if it currently equals old_ref.
 
@@ -483,7 +499,7 @@ class RefsContainer:
         """
         raise NotImplementedError(self.remove_if_equals)
 
-    def __delitem__(self, name) -> None:
+    def __delitem__(self, name: bytes) -> None:
         """Remove a refname.
 
         This method does not follow symbolic references, even if applicable for
@@ -498,7 +514,7 @@ class RefsContainer:
         """
         self.remove_if_equals(name, None)
 
-    def get_symrefs(self):
+    def get_symrefs(self) -> dict[bytes, bytes]:
         """Get a dict with all symrefs in this container.
 
         Returns: Dictionary mapping source ref to target ref
@@ -506,7 +522,9 @@ class RefsContainer:
         ret = {}
         for src in self.allkeys():
             try:
-                dst = parse_symref_value(self.read_ref(src))
+                ref_value = self.read_ref(src)
+                assert ref_value is not None
+                dst = parse_symref_value(ref_value)
             except ValueError:
                 pass
             else:
@@ -529,41 +547,43 @@ class DictRefsContainer(RefsContainer):
     threadsafe.
     """
 
-    def __init__(self, refs, logger=None) -> None:
-        """Initialize DictRefsContainer."""
+    def __init__(
+        self,
+        refs: dict[bytes, bytes],
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
+        """Initialize DictRefsContainer with refs dictionary and optional logger."""
         super().__init__(logger=logger)
         self._refs = refs
         self._peeled: dict[bytes, ObjectID] = {}
         self._watchers: set[Any] = set()
 
-    def allkeys(self):
-        """Get all ref names.
+    def allkeys(self) -> set[bytes]:
+        """Return all reference keys."""
+        return set(self._refs.keys())
 
-        Returns:
-          All ref names in the container
-        """
-        return self._refs.keys()
-
-    def read_loose_ref(self, name):
-        """Read a reference from the refs dictionary.
-
-        Args:
-          name: The ref name to read
-
-        Returns:
-          The ref value or None if not found
-        """
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
+        """Read a loose reference."""
         return self._refs.get(name, None)
 
-    def get_packed_refs(self):
-        """Get packed refs (always empty for DictRefsContainer).
-
-        Returns:
-          Empty dictionary
-        """
+    def get_packed_refs(self) -> dict[bytes, bytes]:
+        """Get packed references."""
         return {}
 
-    def _notify(self, ref, newsha) -> None:
+    def _notify(self, ref: bytes, newsha: Optional[bytes]) -> None:
         for watcher in self._watchers:
             watcher._notify((ref, newsha))
 
@@ -571,10 +591,10 @@ class DictRefsContainer(RefsContainer):
         self,
         name: Ref,
         other: Ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         """Make a ref point at another ref.
 
@@ -602,13 +622,13 @@ class DictRefsContainer(RefsContainer):
 
     def set_if_equals(
         self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref.
 
@@ -650,9 +670,9 @@ class DictRefsContainer(RefsContainer):
         self,
         name: Ref,
         ref: ObjectID,
-        committer=None,
-        timestamp=None,
-        timezone=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
         message: Optional[bytes] = None,
     ) -> bool:
         """Add a new reference only if it does not already exist.
@@ -685,12 +705,12 @@ class DictRefsContainer(RefsContainer):
 
     def remove_if_equals(
         self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Remove a refname only if it currently equals old_ref.
 
@@ -728,25 +748,18 @@ class DictRefsContainer(RefsContainer):
             )
         return True
 
-    def get_peeled(self, name):
-        """Get the peeled value of a ref.
-
-        Args:
-          name: Ref name to get peeled value for
-
-        Returns:
-          The peeled SHA or None if not available
-        """
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
+        """Get peeled version of a reference."""
         return self._peeled.get(name)
 
-    def _update(self, refs) -> None:
+    def _update(self, refs: dict[bytes, bytes]) -> None:
         """Update multiple refs; intended only for testing."""
         # TODO(dborowitz): replace this with a public function that uses
         # set_if_equal.
         for ref, sha in refs.items():
             self.set_if_equals(ref, None, sha)
 
-    def _update_peeled(self, peeled) -> None:
+    def _update_peeled(self, peeled: dict[bytes, bytes]) -> None:
         """Update cached peeled refs; intended only for testing."""
         self._peeled.update(peeled)
 
@@ -754,56 +767,27 @@ class DictRefsContainer(RefsContainer):
 class InfoRefsContainer(RefsContainer):
     """Refs container that reads refs from a info/refs file."""
 
-    def __init__(self, f) -> None:
-        """Initialize an InfoRefsContainer.
-
-        Args:
-          f: File-like object containing info/refs data
-        """
-        self._refs = {}
-        self._peeled = {}
+    def __init__(self, f: BinaryIO) -> None:
+        """Initialize InfoRefsContainer from info/refs file."""
+        self._refs: dict[bytes, bytes] = {}
+        self._peeled: dict[bytes, bytes] = {}
         refs = read_info_refs(f)
         (self._refs, self._peeled) = split_peeled_refs(refs)
 
-    def allkeys(self):
-        """Get all ref names.
+    def allkeys(self) -> set[bytes]:
+        """Return all reference keys."""
+        return set(self._refs.keys())
 
-        Returns:
-          All ref names in the info/refs file
-        """
-        return self._refs.keys()
-
-    def read_loose_ref(self, name):
-        """Read a reference from the parsed info/refs.
-
-        Args:
-          name: The ref name to read
-
-        Returns:
-          The ref value or None if not found
-        """
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
+        """Read a loose reference."""
         return self._refs.get(name, None)
 
-    def get_packed_refs(self):
-        """Get packed refs (always empty for InfoRefsContainer).
-
-        Returns:
-          Empty dictionary
-        """
+    def get_packed_refs(self) -> dict[bytes, bytes]:
+        """Get packed references."""
         return {}
 
-    def get_peeled(self, name):
-        """Get the peeled value of a ref.
-
-        Args:
-          name: Ref name to get peeled value for
-
-        Returns:
-          The peeled SHA if available, otherwise the ref value itself
-
-        Raises:
-          KeyError: If the ref doesn't exist
-        """
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
+        """Get peeled version of a reference."""
         try:
             return self._peeled[name]
         except KeyError:
@@ -817,7 +801,20 @@ class DiskRefsContainer(RefsContainer):
         self,
         path: Union[str, bytes, os.PathLike],
         worktree_path: Optional[Union[str, bytes, os.PathLike]] = None,
-        logger=None,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
     ) -> None:
         """Initialize DiskRefsContainer."""
         super().__init__(logger=logger)
@@ -827,22 +824,15 @@ class DiskRefsContainer(RefsContainer):
             self.worktree_path = self.path
         else:
             self.worktree_path = os.fsencode(os.fspath(worktree_path))
-        self._packed_refs = None
-        self._peeled_refs = None
+        self._packed_refs: Optional[dict[bytes, bytes]] = None
+        self._peeled_refs: Optional[dict[bytes, bytes]] = None
 
     def __repr__(self) -> str:
         """Return string representation of DiskRefsContainer."""
         return f"{self.__class__.__name__}({self.path!r})"
 
-    def subkeys(self, base):
-        """Get all ref names under a base ref.
-
-        Args:
-          base: Base ref path to search under
-
-        Returns:
-          Set of ref names under the base (without base prefix)
-        """
+    def subkeys(self, base: bytes) -> set[bytes]:
+        """Return subkeys under a given base reference path."""
         subkeys = set()
         path = self.refpath(base)
         for root, unused_dirs, files in os.walk(path):
@@ -861,12 +851,8 @@ class DiskRefsContainer(RefsContainer):
                 subkeys.add(key[len(base) :].strip(b"/"))
         return subkeys
 
-    def allkeys(self):
-        """Get all ref names from disk.
-
-        Returns:
-          Set of all ref names (both loose and packed)
-        """
+    def allkeys(self) -> set[bytes]:
+        """Return all reference keys."""
         allkeys = set()
         if os.path.exists(self.refpath(HEADREF)):
             allkeys.add(HEADREF)
@@ -883,7 +869,7 @@ class DiskRefsContainer(RefsContainer):
         allkeys.update(self.get_packed_refs())
         return allkeys
 
-    def refpath(self, name):
+    def refpath(self, name: bytes) -> bytes:
         """Return the disk path of a ref."""
         if os.path.sep != "/":
             name = name.replace(b"/", os.fsencode(os.path.sep))
@@ -894,7 +880,7 @@ class DiskRefsContainer(RefsContainer):
         else:
             return os.path.join(self.path, name)
 
-    def get_packed_refs(self):
+    def get_packed_refs(self) -> dict[bytes, bytes]:
         """Get contents of the packed-refs file.
 
         Returns: Dictionary mapping ref names to SHA1s
@@ -962,7 +948,7 @@ class DiskRefsContainer(RefsContainer):
 
             self._packed_refs = packed_refs
 
-    def get_peeled(self, name):
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
         """Return the cached peeled value of a ref, if available.
 
         Args:
@@ -972,7 +958,11 @@ class DiskRefsContainer(RefsContainer):
             to a tag, but no cached information is available, None is returned.
         """
         self.get_packed_refs()
-        if self._peeled_refs is None or name not in self._packed_refs:
+        if (
+            self._peeled_refs is None
+            or self._packed_refs is None
+            or name not in self._packed_refs
+        ):
             # No cache: no peeled refs were read, or this ref is loose
             return None
         if name in self._peeled_refs:
@@ -981,7 +971,7 @@ class DiskRefsContainer(RefsContainer):
             # Known not peelable
             return self[name]
 
-    def read_loose_ref(self, name):
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         """Read a reference file and return its contents.
 
         If the reference file is a symbolic reference, only read the first line of
@@ -1011,7 +1001,7 @@ class DiskRefsContainer(RefsContainer):
             # errors depending on the specific operating system
             return None
 
-    def _remove_packed_ref(self, name) -> None:
+    def _remove_packed_ref(self, name: bytes) -> None:
         if self._packed_refs is None:
             return
         filename = os.path.join(self.path, b"packed-refs")
@@ -1021,13 +1011,14 @@ class DiskRefsContainer(RefsContainer):
             self._packed_refs = None
             self.get_packed_refs()
 
-            if name not in self._packed_refs:
+            if self._packed_refs is None or name not in self._packed_refs:
                 f.abort()
                 return
 
             del self._packed_refs[name]
-            with suppress(KeyError):
-                del self._peeled_refs[name]
+            if self._peeled_refs is not None:
+                with suppress(KeyError):
+                    del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
         except BaseException:
@@ -1036,12 +1027,12 @@ class DiskRefsContainer(RefsContainer):
 
     def set_symbolic_ref(
         self,
-        name,
-        other,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        other: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         """Make a ref point at another ref.
 
@@ -1077,13 +1068,13 @@ class DiskRefsContainer(RefsContainer):
 
     def set_if_equals(
         self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref.
 
@@ -1163,9 +1154,9 @@ class DiskRefsContainer(RefsContainer):
         self,
         name: bytes,
         ref: bytes,
-        committer=None,
-        timestamp=None,
-        timezone=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
         message: Optional[bytes] = None,
     ) -> bool:
         """Add a new reference only if it does not already exist.
@@ -1215,12 +1206,12 @@ class DiskRefsContainer(RefsContainer):
 
     def remove_if_equals(
         self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Remove a refname only if it currently equals old_ref.
 
@@ -1321,7 +1312,7 @@ class DiskRefsContainer(RefsContainer):
             self.add_packed_refs(refs_to_pack)
 
 
-def _split_ref_line(line):
+def _split_ref_line(line: bytes) -> tuple[bytes, bytes]:
     """Split a single ref line into a tuple of SHA1 and name."""
     fields = line.rstrip(b"\n\r").split(b" ")
     if len(fields) != 2:
@@ -1334,7 +1325,7 @@ def _split_ref_line(line):
     return (sha, name)
 
 
-def read_packed_refs(f):
+def read_packed_refs(f: IO[bytes]) -> Iterator[tuple[bytes, bytes]]:
     """Read a packed refs file.
 
     Args:
@@ -1350,7 +1341,9 @@ def read_packed_refs(f):
         yield _split_ref_line(line)
 
 
-def read_packed_refs_with_peeled(f):
+def read_packed_refs_with_peeled(
+    f: IO[bytes],
+) -> Iterator[tuple[bytes, bytes, Optional[bytes]]]:
     """Read a packed refs file including peeled refs.
 
     Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
@@ -1382,7 +1375,11 @@ def read_packed_refs_with_peeled(f):
         yield (sha, name, None)
 
 
-def write_packed_refs(f, packed_refs, peeled_refs=None) -> None:
+def write_packed_refs(
+    f: IO[bytes],
+    packed_refs: dict[bytes, bytes],
+    peeled_refs: Optional[dict[bytes, bytes]] = None,
+) -> None:
     """Write a packed refs file.
 
     Args:
@@ -1400,7 +1397,7 @@ def write_packed_refs(f, packed_refs, peeled_refs=None) -> None:
             f.write(b"^" + peeled_refs[refname] + b"\n")
 
 
-def read_info_refs(f):
+def read_info_refs(f: BinaryIO) -> dict[bytes, bytes]:
     """Read info/refs file.
 
     Args:
@@ -1416,7 +1413,9 @@ def read_info_refs(f):
     return ret
 
 
-def write_info_refs(refs, store: ObjectContainer):
+def write_info_refs(
+    refs: dict[bytes, bytes], store: ObjectContainer
+) -> Iterator[bytes]:
     """Generate info refs."""
     # TODO: Avoid recursive import :(
     from .object_store import peel_sha
@@ -1436,38 +1435,38 @@ def write_info_refs(refs, store: ObjectContainer):
             yield peeled.id + b"\t" + name + PEELED_TAG_SUFFIX + b"\n"
 
 
-def is_local_branch(x):
-    """Check if a ref name refers to a local branch.
+def is_local_branch(x: bytes) -> bool:
+    """Check if a ref name is a local branch."""
+    return x.startswith(LOCAL_BRANCH_PREFIX)
 
-    Args:
-      x: Ref name to check
 
-    Returns:
-      True if ref is a local branch (refs/heads/...)
-    """
-    return x.startswith(LOCAL_BRANCH_PREFIX)
+T = TypeVar("T", dict[bytes, bytes], dict[bytes, Optional[bytes]])
 
 
-def strip_peeled_refs(refs):
+def strip_peeled_refs(refs: T) -> T:
     """Remove all peeled refs."""
     return {
         ref: sha for (ref, sha) in refs.items() if not ref.endswith(PEELED_TAG_SUFFIX)
     }
 
 
-def split_peeled_refs(refs):
+def split_peeled_refs(refs: T) -> tuple[T, dict[bytes, bytes]]:
     """Split peeled refs from regular refs."""
-    peeled = {}
-    regular = {}
+    peeled: dict[bytes, bytes] = {}
+    regular = {k: v for k, v in refs.items() if not k.endswith(PEELED_TAG_SUFFIX)}
+
     for ref, sha in refs.items():
         if ref.endswith(PEELED_TAG_SUFFIX):
-            peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
-        else:
-            regular[ref] = sha
+            # Only add to peeled dict if sha is not None
+            if sha is not None:
+                peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
+
     return regular, peeled
 
 
-def _set_origin_head(refs, origin, origin_head) -> None:
+def _set_origin_head(
+    refs: RefsContainer, origin: bytes, origin_head: Optional[bytes]
+) -> None:
     # set refs/remotes/origin/HEAD
     origin_base = b"refs/remotes/" + origin + b"/"
     if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
@@ -1511,7 +1510,9 @@ def _set_default_branch(
     return head_ref
 
 
-def _set_head(refs, head_ref, ref_message):
+def _set_head(
+    refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]
+) -> Optional[bytes]:
     if head_ref.startswith(LOCAL_TAG_PREFIX):
         # detach HEAD at specified tag
         head = refs[head_ref]
@@ -1534,7 +1535,7 @@ def _set_head(refs, head_ref, ref_message):
 def _import_remote_refs(
     refs_container: RefsContainer,
     remote_name: str,
-    refs: dict[str, str],
+    refs: dict[bytes, Optional[bytes]],
     message: Optional[bytes] = None,
     prune: bool = False,
     prune_tags: bool = False,
@@ -1543,7 +1544,7 @@ def _import_remote_refs(
     branches = {
         n[len(LOCAL_BRANCH_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_BRANCH_PREFIX)
+        if n.startswith(LOCAL_BRANCH_PREFIX) and v is not None
     }
     refs_container.import_refs(
         b"refs/remotes/" + remote_name.encode(),
@@ -1554,14 +1555,18 @@ def _import_remote_refs(
     tags = {
         n[len(LOCAL_TAG_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_TAG_PREFIX) and not n.endswith(PEELED_TAG_SUFFIX)
+        if n.startswith(LOCAL_TAG_PREFIX)
+        and not n.endswith(PEELED_TAG_SUFFIX)
+        and v is not None
     }
     refs_container.import_refs(
         LOCAL_TAG_PREFIX, tags, message=message, prune=prune_tags
     )
 
 
-def serialize_refs(store, refs):
+def serialize_refs(
+    store: ObjectContainer, refs: dict[bytes, bytes]
+) -> dict[bytes, bytes]:
     """Serialize refs with peeled refs.
 
     Args:
@@ -1658,6 +1663,7 @@ class locked_ref:
         if not self._file:
             raise RuntimeError("locked_ref not in context")
 
+        assert self._realname is not None
         current_ref = self._refs_container.read_loose_ref(self._realname)
         if current_ref is None:
             current_ref = self._refs_container.get_packed_refs().get(
@@ -1724,3 +1730,14 @@ class locked_ref:
             self._refs_container._remove_packed_ref(self._realname)
 
         self._deleted = True
+
+
+def filter_ref_prefix(refs: T, prefixes: Iterable[bytes]) -> T:
+    """Filter refs to only include those with a given prefix.
+
+    Args:
+      refs: A dictionary of refs.
+      prefixes: The prefixes to filter by.
+    """
+    filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
+    return cast(T, filtered)

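As an aside, a minimal usage sketch of the filter_ref_prefix helper introduced above (not part of the diff; the ref names and SHAs are made up for illustration):

    from dulwich.refs import filter_ref_prefix

    refs = {
        b"refs/heads/master": b"a" * 40,
        b"refs/heads/feature": b"b" * 40,
        b"refs/tags/v1.0": b"c" * 40,
    }

    # Keep only local branches; the TypeVar T preserves the dict's value type.
    branches = filter_ref_prefix(refs, [b"refs/heads/"])
    assert set(branches) == {b"refs/heads/master", b"refs/heads/feature"}
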
+ 1 - 1
dulwich/reftable.py

@@ -700,7 +700,7 @@ class ReftableReader:
         # Read magic bytes
         magic = self.f.read(4)
         if magic != REFTABLE_MAGIC:
-            raise ValueError(f"Invalid reftable magic: {magic}")
+            raise ValueError(f"Invalid reftable magic: {magic!r}")
 
         # Read version + block size (4 bytes total, big-endian network order)
         # Format: uint8(version) + uint24(block_size)

+ 87 - 74
dulwich/repo.py

@@ -34,7 +34,7 @@ import stat
 import sys
 import time
 import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
 from io import BytesIO
 from typing import (
     TYPE_CHECKING,
@@ -42,6 +42,7 @@ from typing import (
     BinaryIO,
     Callable,
     Optional,
+    TypeVar,
     Union,
 )
 
@@ -52,7 +53,11 @@ if TYPE_CHECKING:
     from .attrs import GitAttributes
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .index import Index
+    from .line_ending import BlobNormalizer
     from .notes import Notes
+    from .object_store import BaseObjectStore, GraphWalker, UnpackedObject
+    from .rebase import RebaseStateManager
+    from .walk import Walker
     from .worktree import WorkTree
 
 from . import replace_me
@@ -115,6 +120,8 @@ from .refs import (
 
 CONTROLDIR = ".git"
 OBJECTDIR = "objects"
+
+T = TypeVar("T", bound="ShaFile")
 REFSDIR = "refs"
 REFSDIR_TAGS = "tags"
 REFSDIR_HEADS = "heads"
@@ -138,12 +145,8 @@ DEFAULT_BRANCH = b"master"
 class InvalidUserIdentity(Exception):
     """User identity is not of the format 'user <email>'."""
 
-    def __init__(self, identity) -> None:
-        """Initialize InvalidUserIdentity exception.
-
-        Args:
-            identity: The invalid identity string
-        """
+    def __init__(self, identity: str) -> None:
+        """Initialize InvalidUserIdentity exception."""
         self.identity = identity
 
 
@@ -241,7 +244,7 @@ def get_user_identity(config: "StackedConfig", kind: Optional[str] = None) -> by
     return user + b" <" + email + b">"
 
 
-def check_user_identity(identity) -> None:
+def check_user_identity(identity: bytes) -> None:
     """Verify that a user identity is formatted correctly.
 
     Args:
@@ -252,11 +255,11 @@ def check_user_identity(identity) -> None:
     try:
         fst, snd = identity.split(b" <", 1)
     except ValueError as exc:
-        raise InvalidUserIdentity(identity) from exc
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace")) from exc
     if b">" not in snd:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
     if b"\0" in identity or b"\n" in identity:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
 
 
 def parse_graftpoints(
@@ -313,7 +316,7 @@ def serialize_graftpoints(graftpoints: dict[bytes, list[bytes]]) -> bytes:
     return b"\n".join(graft_lines)
 
 
-def _set_filesystem_hidden(path) -> None:
+def _set_filesystem_hidden(path: str) -> None:
     """Mark path as to be hidden if supported by platform and filesystem.
 
     On win32 uses SetFileAttributesW api:
@@ -337,15 +340,20 @@ def _set_filesystem_hidden(path) -> None:
 
 
 class ParentsProvider:
-    """Provides parents for commits, handling grafts and shallow commits."""
+    """Provider for commit parent information."""
 
-    def __init__(self, store, grafts={}, shallows=[]) -> None:
+    def __init__(
+        self,
+        store: "BaseObjectStore",
+        grafts: dict = {},
+        shallows: Iterable[bytes] = [],
+    ) -> None:
         """Initialize ParentsProvider.
 
         Args:
-            store: Object store to get commits from
-            grafts: Dictionary mapping commit ids to parent ids
-            shallows: List of shallow commit ids
+            store: Object store to use
+            grafts: Graft information
+            shallows: Shallow commit SHAs
         """
         self.store = store
         self.grafts = grafts
@@ -354,16 +362,10 @@ class ParentsProvider:
         # Get commit graph once at initialization for performance
         self.commit_graph = store.get_commit_graph()
 
-    def get_parents(self, commit_id, commit=None):
-        """Get the parents of a commit.
-
-        Args:
-          commit_id: The commit SHA to get parents for
-          commit: Optional commit object to avoid fetching
-
-        Returns:
-          List of parent commit SHAs
-        """
+    def get_parents(
+        self, commit_id: bytes, commit: Optional[Commit] = None
+    ) -> list[bytes]:
+        """Get parents for a commit using the parents provider."""
         try:
             return self.grafts[commit_id]
         except KeyError:
@@ -379,7 +381,9 @@ class ParentsProvider:
 
         # Fallback to reading the commit object
         if commit is None:
-            commit = self.store[commit_id]
+            obj = self.store[commit_id]
+            assert isinstance(obj, Commit)
+            commit = obj
         return commit.parents
 
 
@@ -494,8 +498,12 @@ class BaseRepo:
         raise NotImplementedError(self.open_index)
 
     def fetch(
-        self, target, determine_wants=None, progress=None, depth: Optional[int] = None
-    ):
+        self,
+        target: "BaseRepo",
+        determine_wants: Optional[Callable] = None,
+        progress: Optional[Callable] = None,
+        depth: Optional[int] = None,
+    ) -> dict:
         """Fetch objects into another repository.
 
         Args:
@@ -519,13 +527,13 @@ class BaseRepo:
 
     def fetch_pack_data(
         self,
-        determine_wants,
-        graph_walker,
-        progress,
+        determine_wants: Callable,
+        graph_walker: "GraphWalker",
+        progress: Optional[Callable],
         *,
-        get_tagged=None,
+        get_tagged: Optional[Callable] = None,
         depth: Optional[int] = None,
-    ):
+    ) -> tuple:
         """Fetch the pack data required for a set of revisions.
 
         Args:
@@ -554,11 +562,11 @@ class BaseRepo:
 
     def find_missing_objects(
         self,
-        determine_wants,
-        graph_walker,
-        progress,
+        determine_wants: Callable,
+        graph_walker: "GraphWalker",
+        progress: Optional[Callable],
         *,
-        get_tagged=None,
+        get_tagged: Optional[Callable] = None,
         depth: Optional[int] = None,
     ) -> Optional[MissingObjectFinder]:
         """Fetch the missing objects required for a set of revisions.
@@ -585,16 +593,17 @@ class BaseRepo:
         current_shallow = set(getattr(graph_walker, "shallow", set()))
 
         if depth not in (None, 0):
+            assert depth is not None
             shallow, not_shallow = find_shallow(self.object_store, wants, depth)
             # Only update if graph_walker has shallow attribute
             if hasattr(graph_walker, "shallow"):
                 graph_walker.shallow.update(shallow - not_shallow)
                 new_shallow = graph_walker.shallow - current_shallow
-                unshallow = graph_walker.unshallow = not_shallow & current_shallow
+                unshallow = graph_walker.unshallow = not_shallow & current_shallow  # type: ignore[attr-defined]
                 if hasattr(graph_walker, "update_shallow"):
                     graph_walker.update_shallow(new_shallow, unshallow)
         else:
-            unshallow = getattr(graph_walker, "unshallow", frozenset())
+            unshallow = getattr(graph_walker, "unshallow", set())
 
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
@@ -618,7 +627,7 @@ class BaseRepo:
                 def __len__(self) -> int:
                     return 0
 
-                def __iter__(self):
+                def __iter__(self) -> Iterator[tuple[bytes, Optional[bytes]]]:
                     yield from []
 
             return DummyMissingObjectFinder()  # type: ignore
@@ -637,7 +646,7 @@ class BaseRepo:
 
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
 
-        def get_parents(commit):
+        def get_parents(commit: Commit) -> list[bytes]:
             """Get parents for a commit using the parents provider.
 
             Args:
@@ -660,11 +669,11 @@ class BaseRepo:
 
     def generate_pack_data(
         self,
-        have: list[ObjectID],
-        want: list[ObjectID],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         progress: Optional[Callable[[str], None]] = None,
         ofs_delta: Optional[bool] = None,
-    ):
+    ) -> tuple[int, Iterator["UnpackedObject"]]:
         """Generate pack data objects for a set of wants/haves.
 
         Args:
@@ -719,18 +728,18 @@ class BaseRepo:
         # TODO: move this method to WorkTree
         return self.refs[b"HEAD"]
 
-    def _get_object(self, sha, cls):
+    def _get_object(self, sha: bytes, cls: type[T]) -> T:
         assert len(sha) in (20, 40)
         ret = self.get_object(sha)
         if not isinstance(ret, cls):
             if cls is Commit:
-                raise NotCommitError(ret)
+                raise NotCommitError(ret.id)
             elif cls is Blob:
-                raise NotBlobError(ret)
+                raise NotBlobError(ret.id)
             elif cls is Tree:
-                raise NotTreeError(ret)
+                raise NotTreeError(ret.id)
             elif cls is Tag:
-                raise NotTagError(ret)
+                raise NotTagError(ret.id)
             else:
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
         return ret
@@ -790,7 +799,7 @@ class BaseRepo:
         """
         raise NotImplementedError(self.get_description)
 
-    def set_description(self, description) -> None:
+    def set_description(self, description: bytes) -> None:
         """Set the description for this repository.
 
         Args:
@@ -798,14 +807,14 @@ class BaseRepo:
         """
         raise NotImplementedError(self.set_description)
 
-    def get_rebase_state_manager(self):
+    def get_rebase_state_manager(self) -> "RebaseStateManager":
         """Get the appropriate rebase state manager for this repository.
 
         Returns: RebaseStateManager instance
         """
         raise NotImplementedError(self.get_rebase_state_manager)
 
-    def get_blob_normalizer(self):
+    def get_blob_normalizer(self) -> "BlobNormalizer":
         """Return a BlobNormalizer object for checkin/checkout operations.
 
         Returns: BlobNormalizer instance
@@ -853,7 +862,9 @@ class BaseRepo:
         with f:
             return {line.strip() for line in f}
 
-    def update_shallow(self, new_shallow, new_unshallow) -> None:
+    def update_shallow(
+        self, new_shallow: Optional[set[bytes]], new_unshallow: Optional[set[bytes]]
+    ) -> None:
         """Update the list of shallow objects.
 
         Args:
@@ -895,7 +906,7 @@ class BaseRepo:
 
         return Notes(self.object_store, self.refs)
 
-    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs):
+    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs) -> "Walker":
         """Obtain a walker for this repository.
 
         Args:
@@ -932,7 +943,7 @@ class BaseRepo:
 
         return Walker(self.object_store, include, **kwargs)
 
-    def __getitem__(self, name: Union[ObjectID, Ref]):
+    def __getitem__(self, name: Union[ObjectID, Ref]) -> "ShaFile":
         """Retrieve a Git object by SHA1 or ref.
 
         Args:
@@ -1024,7 +1035,7 @@ class BaseRepo:
         for sha in to_remove:
             del self._graftpoints[sha]
 
-    def _read_heads(self, name):
+    def _read_heads(self, name: str) -> list[bytes]:
         f = self.get_named_file(name)
         if f is None:
             return []
@@ -1050,17 +1061,17 @@ class BaseRepo:
         message: Optional[bytes] = None,
         committer: Optional[bytes] = None,
         author: Optional[bytes] = None,
-        commit_timestamp=None,
-        commit_timezone=None,
-        author_timestamp=None,
-        author_timezone=None,
+        commit_timestamp: Optional[float] = None,
+        commit_timezone: Optional[int] = None,
+        author_timestamp: Optional[float] = None,
+        author_timezone: Optional[int] = None,
         tree: Optional[ObjectID] = None,
         encoding: Optional[bytes] = None,
         ref: Optional[Ref] = b"HEAD",
         merge_heads: Optional[list[ObjectID]] = None,
         no_verify: bool = False,
         sign: bool = False,
-    ):
+    ) -> bytes:
         """Create a new commit.
 
         If not specified, committer and author default to
@@ -1109,7 +1120,7 @@ class BaseRepo:
         )
 
 
-def read_gitfile(f):
+def read_gitfile(f: BinaryIO) -> str:
     """Read a ``.git`` file.
 
     The first line of the file should start with "gitdir: "
@@ -1119,9 +1130,9 @@ def read_gitfile(f):
     Returns: A path
     """
     cs = f.read()
-    if not cs.startswith("gitdir: "):
+    if not cs.startswith(b"gitdir: "):
         raise ValueError("Expected file to start with 'gitdir: '")
-    return cs[len("gitdir: ") :].rstrip("\n")
+    return cs[len(b"gitdir: ") :].rstrip(b"\n").decode("utf-8")
 
 
 class UnsupportedVersion(Exception):
@@ -1205,7 +1216,7 @@ class Repo(BaseRepo):
         self.bare = bare
         if bare is False:
             if os.path.isfile(hidden_path):
-                with open(hidden_path) as f:
+                with open(hidden_path, "rb") as f:
                     path = read_gitfile(f)
                 self._controldir = os.path.join(root, path)
             else:
@@ -1364,11 +1375,11 @@ class Repo(BaseRepo):
             "No git repository was found at {path}".format(**dict(path=start))
         )
 
-    def controldir(self):
+    def controldir(self) -> str:
         """Return the path of the control directory."""
         return self._controldir
 
-    def commondir(self):
+    def commondir(self) -> str:
         """Return the path of the common directory.
 
         For a main working tree, it is identical to controldir().
@@ -1378,7 +1389,7 @@ class Repo(BaseRepo):
         """
         return self._commondir
 
-    def _determine_file_mode(self):
+    def _determine_file_mode(self) -> bool:
         """Probe the file-system to determine whether permissions can be trusted.
 
         Returns: True if permissions can be trusted, False otherwise.
@@ -1401,7 +1412,7 @@ class Repo(BaseRepo):
 
         return mode_differs and st2_has_exec
 
-    def _determine_symlinks(self):
+    def _determine_symlinks(self) -> bool:
         """Probe the filesystem to determine whether symlinks can be created.
 
         Returns: True if symlinks can be created, False otherwise.
@@ -1409,7 +1420,7 @@ class Repo(BaseRepo):
         # TODO(jelmer): Actually probe disk / look at filesystem
         return sys.platform != "win32"
 
-    def _put_named_file(self, path, contents) -> None:
+    def _put_named_file(self, path: str, contents: bytes) -> None:
         """Write a file to the control dir with the given name and contents.
 
         Args:
@@ -1420,7 +1431,7 @@ class Repo(BaseRepo):
         with GitFile(os.path.join(self.controldir(), path), "wb") as f:
             f.write(contents)
 
-    def _del_named_file(self, path) -> None:
+    def _del_named_file(self, path: str) -> None:
         try:
             os.unlink(os.path.join(self.controldir(), path))
         except FileNotFoundError:
@@ -2040,6 +2051,7 @@ class Repo(BaseRepo):
                 if isinstance(head, Tag):
                     _cls, obj = head.object
                     head = self.get_object(obj)
+                assert isinstance(head, Commit)
                 tree = head.tree
             except KeyError:
                 # No HEAD, no attributes from tree
@@ -2048,6 +2060,7 @@ class Repo(BaseRepo):
         if tree is not None:
             try:
                 tree_obj = self[tree]
+                assert isinstance(tree_obj, Tree)
                 if b".gitattributes" in tree_obj:
                     _, attrs_sha = tree_obj[b".gitattributes"]
                     attrs_blob = self[attrs_sha]
@@ -2136,7 +2149,7 @@ class MemoryRepo(BaseRepo):
 
         self._reflog: list[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
-        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore[arg-type]
         self._named_files: dict[str, bytes] = {}
         self.bare = True
         self._config = ConfigFile()

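A small sketch of the tightened read_gitfile contract shown above: the file must now be opened in binary mode (as the updated Repo constructor does), and the returned path is a str. The worktree path below is hypothetical:

    from io import BytesIO

    from dulwich.repo import read_gitfile

    # read_gitfile now expects a binary file object and returns a decoded str.
    gitdir = read_gitfile(BytesIO(b"gitdir: ../.git/worktrees/example\n"))
    assert gitdir == "../.git/worktrees/example"
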
+ 16 - 3
dulwich/server.py

@@ -52,7 +52,7 @@ import time
 import zlib
 from collections.abc import Iterable, Iterator
 from functools import partial
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional
 from typing import Protocol as TypingProtocol
 
 if TYPE_CHECKING:
@@ -551,10 +551,12 @@ def _split_proto_line(line, allowed):
             COMMAND_SHALLOW,
             COMMAND_UNSHALLOW,
         ):
+            assert fields[1] is not None
             if not valid_hexsha(fields[1]):
                 raise GitProtocolError("Invalid sha")
             return tuple(fields)
         elif command == COMMAND_DEEPEN:
+            assert fields[1] is not None
             return command, int(fields[1])
     raise GitProtocolError(f"Received invalid line from client: {line!r}")
 
@@ -633,6 +635,10 @@ class AckGraphWalkerImpl:
         """
         raise NotImplementedError
 
+    def handle_done(self, done_required, done_received):
+        """Handle 'done' packet from client."""
+        raise NotImplementedError
+
 
 class _ProtocolGraphWalker:
     """A graph walker that knows the git protocol.
@@ -784,6 +790,7 @@ class _ProtocolGraphWalker:
         """
         if len(have_ref) != 40:
             raise ValueError(f"invalid sha {have_ref!r}")
+        assert self._impl is not None
         return self._impl.ack(have_ref)
 
     def reset(self) -> None:
@@ -799,7 +806,8 @@ class _ProtocolGraphWalker:
         if not self._cached:
             if not self._impl and self.stateless_rpc:
                 return None
-            return next(self._impl)
+            assert self._impl is not None
+            return next(self._impl)  # type: ignore[call-overload]
         self._cache_index += 1
         if self._cache_index > len(self._cache):
             return None
@@ -884,6 +892,7 @@ class _ProtocolGraphWalker:
         Returns: True if done handling succeeded
         """
         # Delegate this to the implementation.
+        assert self._impl is not None
         return self._impl.handle_done(done_required, done_received)
 
     def set_wants(self, wants) -> None:
@@ -1400,7 +1409,11 @@ class UploadArchiveHandler(Handler):
                 format = arguments[i].decode("ascii")
             else:
                 commit_sha = self.repo.refs[argument]
-                tree = cast(Tree, store[cast(Commit, store[commit_sha]).tree])
+                commit_obj = store[commit_sha]
+                assert isinstance(commit_obj, Commit)
+                tree_obj = store[commit_obj.tree]
+                assert isinstance(tree_obj, Tree)
+                tree = tree_obj
             i += 1
         self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(None)

+ 12 - 6
dulwich/stash.py

@@ -23,7 +23,7 @@
 
 import os
 import sys
-from typing import TYPE_CHECKING, Optional, TypedDict
+from typing import TYPE_CHECKING, Optional, TypedDict, Union
 
 from .diff_tree import tree_changes
 from .file import GitFile
@@ -162,10 +162,16 @@ class Stash:
             symlink_fn = symlink
         else:
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                mode = "w" + ("b" if isinstance(source, bytes) else "")
-                with open(target, mode) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: Optional[int] = None,
+            ) -> None:
+                mode = "w" + ("b" if isinstance(src, bytes) else "")
+                with open(dst, mode) as f:
+                    f.write(src)
 
         # Get blob normalizer for line ending conversion
         blob_normalizer = self._repo.get_blob_normalizer()
@@ -228,7 +234,7 @@ class Stash:
                     entry.mode,
                     full_path,
                     honor_filemode=honor_filemode,
-                    symlink_fn=symlink_fn,
+                    symlink_fn=symlink_fn,  # type: ignore[arg-type]
                 )
 
             # Update index if the file wasn't already staged

+ 1 - 1
dulwich/tests/test_object_store.py

@@ -444,7 +444,7 @@ class FindShallowTests(TestCase):
     def make_linear_commits(self, n, message=b""):
         """Create a linear chain of commits."""
         commits = []
-        parents = []
+        parents: list[bytes] = []
         for _ in range(n):
             commits.append(self.make_commit(parents=parents, message=message))
             parents = [commits[-1].id]

+ 51 - 28
dulwich/tests/utils.py

@@ -21,6 +21,8 @@
 
 """Utility functions common to Dulwich tests."""
 
+# ruff: noqa: ANN401
+
 import datetime
 import os
 import shutil
@@ -28,10 +30,12 @@ import tempfile
 import time
 import types
 import warnings
+from typing import Any, BinaryIO, Callable, Optional, TypeVar, Union
 from unittest import SkipTest
 
 from dulwich.index import commit_tree
-from dulwich.objects import Commit, FixedSha, Tag, object_class
+from dulwich.object_store import BaseObjectStore
+from dulwich.objects import Commit, FixedSha, ShaFile, Tag, object_class
 from dulwich.pack import (
     DELTA_TYPES,
     OFS_DELTA,
@@ -47,8 +51,10 @@ from dulwich.repo import Repo
 # Plain files are very frequently used in tests, so let the mode be very short.
 F = 0o100644  # Shorthand mode for Files.
 
+T = TypeVar("T", bound=ShaFile)
+
 
-def open_repo(name, temp_dir=None):
+def open_repo(name: str, temp_dir: Optional[str] = None) -> Repo:
     """Open a copy of a repo in a temporary directory.
 
     Use this function for accessing repos in dulwich/tests/data/repos to avoid
@@ -72,14 +78,14 @@ def open_repo(name, temp_dir=None):
     return Repo(temp_repo_dir)
 
 
-def tear_down_repo(repo) -> None:
+def tear_down_repo(repo: Repo) -> None:
     """Tear down a test repository."""
     repo.close()
     temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
     shutil.rmtree(temp_dir)
 
 
-def make_object(cls, **attrs):
+def make_object(cls: type[T], **attrs: Any) -> T:
     """Make an object for testing and assign some members.
 
     This method creates a new subclass to allow arbitrary attribute
@@ -92,7 +98,7 @@ def make_object(cls, **attrs):
     Returns: A newly initialized object of type cls.
     """
 
-    class TestObject(cls):
+    class TestObject(cls):  # type: ignore[misc,valid-type]
         """Class that inherits from the given class, but without __slots__.
 
         Note that classes with __slots__ can't have arbitrary attributes
@@ -102,7 +108,7 @@ def make_object(cls, **attrs):
 
     TestObject.__name__ = "TestObject_" + cls.__name__
 
-    obj = TestObject()
+    obj = TestObject()  # type: ignore[abstract]
     for name, value in attrs.items():
         if name == "id":
             # id property is read-only, so we overwrite sha instead.
@@ -113,7 +119,7 @@ def make_object(cls, **attrs):
     return obj
 
 
-def make_commit(**attrs):
+def make_commit(**attrs: Any) -> Commit:
     """Make a Commit object with a default set of members.
 
     Args:
@@ -136,7 +142,7 @@ def make_commit(**attrs):
     return make_object(Commit, **all_attrs)
 
 
-def make_tag(target, **attrs):
+def make_tag(target: ShaFile, **attrs: Any) -> Tag:
     """Make a Tag object with a default set of values.
 
     Args:
@@ -159,16 +165,20 @@ def make_tag(target, **attrs):
     return make_object(Tag, **all_attrs)
 
 
-def functest_builder(method, func):
+def functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given function."""
 
-    def do_test(self) -> None:
+    def do_test(self: Any) -> None:
         method(self, func)
 
     return do_test
 
 
-def ext_functest_builder(method, func):
+def ext_functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given extension function.
 
     This is intended to generate test methods that test both a pure-Python
@@ -190,7 +200,7 @@ def ext_functest_builder(method, func):
       func: The function implementation to pass to method.
     """
 
-    def do_test(self) -> None:
+    def do_test(self: Any) -> None:
         if not isinstance(func, types.BuiltinFunctionType):
             raise SkipTest(f"{func} extension not found")
         method(self, func)
@@ -198,7 +208,11 @@ def ext_functest_builder(method, func):
     return do_test
 
 
-def build_pack(f, objects_spec, store=None):
+def build_pack(
+    f: BinaryIO,
+    objects_spec: list[tuple[int, Any]],
+    store: Optional[BaseObjectStore] = None,
+) -> list[tuple[int, int, bytes, bytes, int]]:
     """Write test pack data from a concise spec.
 
     Args:
@@ -221,14 +235,14 @@ def build_pack(f, objects_spec, store=None):
     num_objects = len(objects_spec)
     write_pack_header(sf.write, num_objects)
 
-    full_objects = {}
-    offsets = {}
-    crc32s = {}
+    full_objects: dict[int, tuple[int, bytes, bytes]] = {}
+    offsets: dict[int, int] = {}
+    crc32s: dict[int, int] = {}
 
     while len(full_objects) < num_objects:
         for i, (type_num, data) in enumerate(objects_spec):
             if type_num not in DELTA_TYPES:
-                full_objects[i] = (type_num, data, obj_sha(type_num, [data]))
+                full_objects[i] = (type_num, data, obj_sha(type_num, [data]))  # type: ignore[no-untyped-call]
                 continue
             base, data = data
             if isinstance(base, int):
@@ -236,11 +250,12 @@ def build_pack(f, objects_spec, store=None):
                     continue
                 base_type_num, _, _ = full_objects[base]
             else:
+                assert store is not None
                 base_type_num, _ = store.get_raw(base)
             full_objects[i] = (
                 base_type_num,
                 data,
-                obj_sha(base_type_num, [data]),
+                obj_sha(base_type_num, [data]),  # type: ignore[no-untyped-call]
             )
 
     for i, (type_num, obj) in enumerate(objects_spec):
@@ -249,17 +264,18 @@ def build_pack(f, objects_spec, store=None):
             base_index, data = obj
             base = offset - offsets[base_index]
             _, base_data, _ = full_objects[base_index]
-            obj = (base, list(create_delta(base_data, data)))
+            obj = (base, list(create_delta(base_data, data)))  # type: ignore[no-untyped-call]
         elif type_num == REF_DELTA:
             base_ref, data = obj
             if isinstance(base_ref, int):
                 _, base_data, base = full_objects[base_ref]
             else:
+                assert store is not None
                 base_type_num, base_data = store.get_raw(base_ref)
-                base = obj_sha(base_type_num, base_data)
-            obj = (base, list(create_delta(base_data, data)))
+                base = obj_sha(base_type_num, base_data)  # type: ignore[no-untyped-call]
+            obj = (base, list(create_delta(base_data, data)))  # type: ignore[no-untyped-call]
 
-        crc32 = write_pack_object(sf.write, type_num, obj)
+        crc32 = write_pack_object(sf.write, type_num, obj)  # type: ignore[no-untyped-call]
         offsets[i] = offset
         crc32s[i] = crc32
 
@@ -269,12 +285,19 @@ def build_pack(f, objects_spec, store=None):
         assert len(sha) == 20
         expected.append((offsets[i], type_num, data, sha, crc32s[i]))
 
-    sf.write_sha()
+    sf.write_sha()  # type: ignore[no-untyped-call]
     f.seek(0)
     return expected
 
 
-def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
+def build_commit_graph(
+    object_store: BaseObjectStore,
+    commit_spec: list[list[int]],
+    trees: Optional[
+        dict[int, list[Union[tuple[bytes, ShaFile], tuple[bytes, ShaFile, int]]]]
+    ] = None,
+    attrs: Optional[dict[int, dict[str, Any]]] = None,
+) -> list[Commit]:
     """Build a commit graph from a concise specification.
 
     Sample usage:
@@ -311,7 +334,7 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
     if attrs is None:
         attrs = {}
     commit_time = 0
-    nums = {}
+    nums: dict[int, bytes] = {}
     commits = []
 
     for commit in commit_spec:
@@ -343,7 +366,7 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
 
         # By default, increment the time by a lot. Out-of-order commits should
         # be closer together than this because their main cause is clock skew.
-        commit_time = commit_attrs["commit_time"] + 100
+        commit_time = commit_attrs["commit_time"] + 100  # type: ignore[operator]
         nums[commit_num] = commit_obj.id
         object_store.add_object(commit_obj)
         commits.append(commit_obj)
@@ -351,12 +374,12 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
     return commits
 
 
-def setup_warning_catcher():
+def setup_warning_catcher() -> tuple[list[Warning], Callable[[], None]]:
     """Wrap warnings.showwarning with code that records warnings."""
     caught_warnings = []
     original_showwarning = warnings.showwarning
 
-    def custom_showwarning(*args, **kwargs) -> None:
+    def custom_showwarning(*args: Any, **kwargs: Any) -> None:
         caught_warnings.append(args[0])
 
     warnings.showwarning = custom_showwarning

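For context, a hedged sketch of how the now-generic make_object helper is used; with T bound to ShaFile, type checkers infer the concrete return type instead of Any (the blob contents are made up):

    from dulwich.objects import Blob
    from dulwich.tests.utils import make_object

    # Inferred as Blob rather than Any, thanks to the TypeVar bound.
    blob = make_object(Blob, data=b"example contents")
    assert blob.data == b"example contents"
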
+ 12 - 5
dulwich/walk.py

@@ -87,7 +87,9 @@ class WalkEntry:
                 parent = None
             elif len(self._get_parents(commit)) == 1:
                 changes_func = tree_changes
-                parent = cast(Commit, self._store[self._get_parents(commit)[0]]).tree
+                parent_commit = self._store[self._get_parents(commit)[0]]
+                assert isinstance(parent_commit, Commit)
+                parent = parent_commit.tree
                 if path_prefix:
                     mode, subtree_sha = parent.lookup_path(
                         self._store.__getitem__,
@@ -96,9 +98,12 @@ class WalkEntry:
                     parent = self._store[subtree_sha]
             else:
                 # For merge commits, we need to handle multiple parents differently
-                parent = [
-                    cast(Commit, self._store[p]).tree for p in self._get_parents(commit)
-                ]
+                parent_trees = []
+                for p in self._get_parents(commit):
+                    parent_commit = self._store[p]
+                    assert isinstance(parent_commit, Commit)
+                    parent_trees.append(parent_commit.tree)
+                parent = parent_trees
                 # Use a lambda to adapt the signature
                 changes_func = cast(
                     Any,
@@ -200,7 +205,9 @@ class _CommitTimeQueue:
                     # some caching (which DiskObjectStore currently does not).
                     # We could either add caching in this class or pass around
                     # parsed queue entry objects instead of commits.
-                    todo.append(cast(Commit, self._store[parent]))
+                    parent_commit = self._store[parent]
+                    assert isinstance(parent_commit, Commit)
+                    todo.append(parent_commit)
                 excluded.add(parent)
 
     def next(self) -> Optional[WalkEntry]:

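The walk.py hunks above (like similar ones in repo.py and server.py) swap typing.cast for a runtime isinstance assertion; a minimal sketch of that pattern as a standalone helper, under the assumption that the lookup should always yield a commit:

    from dulwich.object_store import BaseObjectStore
    from dulwich.objects import Commit

    def lookup_commit(store: BaseObjectStore, sha: bytes) -> Commit:
        # The assertion narrows the type for mypy and fails fast if the SHA
        # resolves to something other than a commit.
        obj = store[sha]
        assert isinstance(obj, Commit)
        return obj
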
+ 141 - 64
dulwich/web.py

@@ -26,9 +26,10 @@ import os
 import re
 import sys
 import time
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from io import BytesIO
-from typing import BinaryIO, Callable, ClassVar, Optional, cast
+from types import TracebackType
+from typing import Any, BinaryIO, Callable, ClassVar, Optional, Union, cast
 from urllib.parse import parse_qs
 from wsgiref.simple_server import (
     ServerHandler,
@@ -37,6 +38,45 @@ from wsgiref.simple_server import (
     make_server,
 )
 
+# wsgiref.types was added in Python 3.11
+if sys.version_info >= (3, 11):
+    from wsgiref.types import StartResponse, WSGIApplication, WSGIEnvironment
+else:
+    # Fallback type definitions for Python < 3.11
+    from typing import TYPE_CHECKING
+
+    if TYPE_CHECKING:
+        # For type checking, use the _typeshed types if available
+        try:
+            from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment
+        except ImportError:
+            # Define our own protocol types for type checking
+            from typing import Protocol
+
+            class StartResponse(Protocol):  # type: ignore[no-redef]
+                """WSGI start_response callable protocol."""
+
+                def __call__(
+                    self,
+                    status: str,
+                    response_headers: list[tuple[str, str]],
+                    exc_info: Optional[
+                        tuple[type, BaseException, TracebackType]
+                    ] = None,
+                ) -> Callable[[bytes], None]:
+                    """Start the response with status and headers."""
+                    ...
+
+            WSGIEnvironment = dict[str, Any]  # type: ignore[misc]
+            WSGIApplication = Callable[  # type: ignore[misc]
+                [WSGIEnvironment, StartResponse], Iterable[bytes]
+            ]
+    else:
+        # At runtime, just use type aliases since these are only for type hints
+        StartResponse = Any
+        WSGIEnvironment = dict[str, Any]
+        WSGIApplication = Callable
+
 from dulwich import log_utils
 
 from .protocol import ReceivableProtocol
@@ -45,6 +85,7 @@ from .server import (
     DEFAULT_HANDLERS,
     Backend,
     DictBackend,
+    Handler,
     generate_info_refs,
     generate_objects_info_packs,
 )
@@ -292,13 +333,21 @@ def get_info_refs(
         yield req.not_found(str(e))
         return
     if service and not req.dumb:
+        if req.handlers is None:
+            yield req.forbidden("No handlers configured")
+            return
         handler_cls = req.handlers.get(service.encode("ascii"), None)
         if handler_cls is None:
             yield req.forbidden("Unsupported service")
             return
         req.nocache()
         write = req.respond(HTTP_OK, f"application/x-{service}-advertisement")
-        proto = ReceivableProtocol(BytesIO().read, write)
+
+        def write_fn(data: bytes) -> Optional[int]:
+            result = write(data)
+            return len(data) if result is not None else None
+
+        proto = ReceivableProtocol(BytesIO().read, write_fn)
         handler = handler_cls(
             backend,
             [url_prefix(mat)],
@@ -425,6 +474,9 @@ def handle_service_request(
     """
     service = mat.group().lstrip("/")
     logger.info("Handling service request for %s", service)
+    if req.handlers is None:
+        yield req.forbidden("No handlers configured")
+        return
     handler_cls = req.handlers.get(service.encode("ascii"), None)
     if handler_cls is None:
         yield req.forbidden("Unsupported service")
@@ -436,11 +488,16 @@ def handle_service_request(
         return
     req.nocache()
     write = req.respond(HTTP_OK, f"application/x-{service}-result")
+
+    def write_fn(data: bytes) -> Optional[int]:
+        result = write(data)
+        return len(data) if result is not None else None
+
     if req.environ.get("HTTP_TRANSFER_ENCODING") == "chunked":
         read = ChunkReader(req.environ["wsgi.input"]).read
     else:
         read = req.environ["wsgi.input"].read
-    proto = ReceivableProtocol(read, write)
+    proto = ReceivableProtocol(read, write_fn)
     # TODO(jelmer): Find a way to pass in repo, rather than having handler_cls
     # reopen.
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
@@ -455,7 +512,11 @@ class HTTPGitRequest:
     """
 
     def __init__(
-        self, environ, start_response, dumb: bool = False, handlers=None
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
     ) -> None:
         """Initialize HTTPGitRequest.
 
@@ -472,7 +533,7 @@ class HTTPGitRequest:
         self._cache_headers: list[tuple[str, str]] = []
         self._headers: list[tuple[str, str]] = []
 
-    def add_header(self, name, value) -> None:
+    def add_header(self, name: str, value: str) -> None:
         """Add a header to the response."""
         self._headers.append((name, value))
 
@@ -481,7 +542,7 @@ class HTTPGitRequest:
         status: str = HTTP_OK,
         content_type: Optional[str] = None,
         headers: Optional[list[tuple[str, str]]] = None,
-    ):
+    ) -> Callable[[bytes], object]:
         """Begin a response with the given status and other headers."""
         if headers:
             self._headers.extend(headers)
@@ -556,7 +617,11 @@ class HTTPGitApplication:
     }
 
     def __init__(
-        self, backend, dumb: bool = False, handlers=None, fallback_app=None
+        self,
+        backend: Backend,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
+        fallback_app: Optional[WSGIApplication] = None,
     ) -> None:
         """Initialize HTTPGitApplication.
 
@@ -568,12 +633,18 @@ class HTTPGitApplication:
         """
         self.backend = backend
         self.dumb = dumb
-        self.handlers = dict(DEFAULT_HANDLERS)
+        self.handlers: dict[bytes, Union[type[Handler], Callable[..., Any]]] = dict(
+            DEFAULT_HANDLERS
+        )
         self.fallback_app = fallback_app
         if handlers is not None:
             self.handlers.update(handlers)
 
-    def __call__(self, environ, start_response):
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         """Handle WSGI request."""
         path = environ["PATH_INFO"]
         method = environ["REQUEST_METHOD"]
@@ -582,6 +653,7 @@ class HTTPGitApplication:
         )
         # environ['QUERY_STRING'] has qs args
         handler = None
+        mat = None
         for smethod, spath in self.services.keys():
             if smethod != method:
                 continue
@@ -590,7 +662,7 @@ class HTTPGitApplication:
                 handler = self.services[smethod, spath]
                 break
 
-        if handler is None:
+        if handler is None or mat is None:
             if self.fallback_app is not None:
                 return self.fallback_app(environ, start_response)
             else:
@@ -602,12 +674,16 @@ class HTTPGitApplication:
 class GunzipFilter:
     """WSGI middleware that unzips gzip-encoded requests before passing on to the underlying application."""
 
-    def __init__(self, application) -> None:
-        """Initialize GunzipFilter."""
+    def __init__(self, application: WSGIApplication) -> None:
+        """Initialize GunzipFilter with WSGI application."""
         self.app = application
 
-    def __call__(self, environ, start_response):
-        """Handle WSGI request."""
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
+        """Handle WSGI request with gzip decompression."""
         import gzip
 
         if environ.get("HTTP_CONTENT_ENCODING", "") == "gzip":
@@ -615,8 +691,7 @@ class GunzipFilter:
                 filename=None, fileobj=environ["wsgi.input"], mode="rb"
             )
             del environ["HTTP_CONTENT_ENCODING"]
-            if "CONTENT_LENGTH" in environ:
-                del environ["CONTENT_LENGTH"]
+            environ.pop("CONTENT_LENGTH", None)
 
         return self.app(environ, start_response)
 
@@ -624,12 +699,16 @@ class GunzipFilter:
 class LimitedInputFilter:
     """WSGI middleware that limits the input length of a request to that specified in Content-Length."""
 
-    def __init__(self, application) -> None:
-        """Initialize LimitedInputFilter."""
+    def __init__(self, application: WSGIApplication) -> None:
+        """Initialize LimitedInputFilter with WSGI application."""
         self.app = application
 
-    def __call__(self, environ, start_response):
-        """Handle WSGI request."""
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
+        """Handle WSGI request with input length limiting."""
         # This is not necessary if this app is run from a conforming WSGI
         # server. Unfortunately, there's no way to tell that at this point.
        # TODO: git may use HTTP/1.1 chunked encoding instead of specifying
@@ -642,9 +721,19 @@ class LimitedInputFilter:
         return self.app(environ, start_response)
 
 
-def make_wsgi_chain(*args, **kwargs):
-    """Factory function to create an instance of HTTPGitApplication, correctly wrapped with needed middleware."""
-    app = HTTPGitApplication(*args, **kwargs)
+def make_wsgi_chain(
+    backend: Backend,
+    dumb: bool = False,
+    handlers: Optional[dict[bytes, Callable[..., Any]]] = None,
+    fallback_app: Optional[WSGIApplication] = None,
+) -> WSGIApplication:
+    """Factory function to create an instance of HTTPGitApplication.
+
+    Correctly wrapped with needed middleware.
+    """
+    app = HTTPGitApplication(
+        backend, dumb=dumb, handlers=handlers, fallback_app=fallback_app
+    )
     wrapped_app = LimitedInputFilter(GunzipFilter(app))
     return wrapped_app
 
@@ -652,64 +741,52 @@ def make_wsgi_chain(*args, **kwargs):
 class ServerHandlerLogger(ServerHandler):
     """ServerHandler that uses dulwich's logger for logging exceptions."""
 
-    def log_exception(self, exc_info) -> None:
-        """Log an exception using dulwich's logger.
-
-        Args:
-          exc_info: Exception information tuple
-        """
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
+        """Log exception using dulwich logger."""
         logger.exception(
             "Exception happened during processing of request",
             exc_info=exc_info,
         )
 
-    def log_message(self, format, *args) -> None:
-        """Log a message using dulwich's logger.
-
-        Args:
-          format: Format string for the message
-          *args: Arguments for the format string
-        """
+    def log_message(self, format: str, *args: object) -> None:
+        """Log message using dulwich logger."""
         logger.info(format, *args)
 
-    def log_error(self, *args) -> None:
-        """Log an error using dulwich's logger.
-
-        Args:
-          *args: Error message components
-        """
+    def log_error(self, *args: object) -> None:
+        """Log error using dulwich logger."""
         logger.error(*args)
 
 
 class WSGIRequestHandlerLogger(WSGIRequestHandler):
     """WSGIRequestHandler that uses dulwich's logger for logging exceptions."""
 
-    def log_exception(self, exc_info) -> None:
-        """Log an exception using dulwich's logger.
-
-        Args:
-          exc_info: Exception information tuple
-        """
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
+        """Log exception using dulwich logger."""
         logger.exception(
             "Exception happened during processing of request",
             exc_info=exc_info,
         )
 
-    def log_message(self, format, *args) -> None:
-        """Log a message using dulwich's logger.
-
-        Args:
-          format: Format string for the message
-          *args: Arguments for the format string
-        """
+    def log_message(self, format: str, *args: object) -> None:
+        """Log message using dulwich logger."""
         logger.info(format, *args)
 
-    def log_error(self, *args) -> None:
-        """Log an error using dulwich's logger.
-
-        Args:
-          *args: Error message components
-        """
+    def log_error(self, *args: object) -> None:
+        """Log error using dulwich logger."""
         logger.error(*args)
 
     def handle(self) -> None:
@@ -731,14 +808,14 @@ class WSGIRequestHandlerLogger(WSGIRequestHandler):
 class WSGIServerLogger(WSGIServer):
     """WSGIServer that uses dulwich's logger for error handling."""
 
-    def handle_error(self, request, client_address) -> None:
+    def handle_error(self, request: object, client_address: tuple[str, int]) -> None:
         """Handle an error."""
         logger.exception(
             f"Exception happened during processing of request from {client_address!s}"
         )
 
 
-def main(argv=sys.argv) -> None:
+def main(argv: list[str] = sys.argv) -> None:
     """Entry point for starting an HTTP git server."""
     import optparse
 

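For context on the web.py signatures above, a hedged sketch of serving a repository through the now fully typed factory (host, port and repository path are assumptions):

    from wsgiref.simple_server import make_server

    from dulwich.repo import Repo
    from dulwich.server import DictBackend
    from dulwich.web import make_wsgi_chain

    backend = DictBackend({"/": Repo(".")})       # assumed local repository
    app = make_wsgi_chain(backend)                # LimitedInputFilter(GunzipFilter(HTTPGitApplication))
    server = make_server("127.0.0.1", 8000, app)  # assumed bind address
    # server.serve_forever()
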
+ 41 - 16
dulwich/worktree.py

@@ -34,9 +34,10 @@ import warnings
 from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from pathlib import Path
+from typing import Any, Callable, Union
 
 from .errors import CommitError, HookError
-from .objects import Commit, ObjectID, Tag, Tree
+from .objects import Blob, Commit, ObjectID, Tag, Tree
 from .refs import SYMREF, Ref
 from .repo import (
     GITDIR,
@@ -335,7 +336,7 @@ class WorkTree:
 
         index = self._repo.open_index()
         try:
-            tree_id = self._repo[b"HEAD"].tree
+            commit = self._repo[b"HEAD"]
         except KeyError:
             # no head means no commit in the repo
             for fs_path in fs_paths:
@@ -343,6 +344,9 @@ class WorkTree:
                 del index[tree_path]
             index.write()
             return
+        else:
+            assert isinstance(commit, Commit), "HEAD must be a commit"
+            tree_id = commit.tree
 
         for fs_path in fs_paths:
             tree_path = _fs_to_tree_path(fs_path)
@@ -367,15 +371,19 @@ class WorkTree:
             except FileNotFoundError:
                 pass
 
+            blob_obj = self._repo[tree_entry[1]]
+            assert isinstance(blob_obj, Blob)
+            blob_size = len(blob_obj.data)
+
             index_entry = IndexEntry(
-                ctime=(self._repo[b"HEAD"].commit_time, 0),
-                mtime=(self._repo[b"HEAD"].commit_time, 0),
+                ctime=(commit.commit_time, 0),
+                mtime=(commit.commit_time, 0),
                 dev=st.st_dev if st else 0,
                 ino=st.st_ino if st else 0,
                 mode=tree_entry[0],
                 uid=st.st_uid if st else 0,
                 gid=st.st_gid if st else 0,
-                size=len(self._repo[tree_entry[1]].data),
+                size=blob_size,
                 sha=tree_entry[1],
                 flags=0,
                 extended_flags=0,
@@ -386,7 +394,7 @@ class WorkTree:
 
     def commit(
         self,
-        message: bytes | None = None,
+        message: Union[str, bytes, Callable[[Any, Commit], bytes], None] = None,
         committer: bytes | None = None,
         author: bytes | None = None,
         commit_timestamp: float | None = None,
@@ -541,13 +549,18 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.set_if_equals(
                     ref,
                     old_head,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             except KeyError:
@@ -555,12 +568,17 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.add_if_new(
                     ref,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             if not ok:
@@ -603,6 +621,9 @@ class WorkTree:
             if isinstance(head, Tag):
                 _cls, obj = head.object
                 head = self._repo.get_object(obj)
+            from .objects import Commit
+
+            assert isinstance(head, Commit)
             tree = head.tree
         config = self._repo.get_config()
         honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
@@ -616,11 +637,15 @@ class WorkTree:
             symlink_fn = symlink
         else:
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                with open(
-                    target, "w" + ("b" if isinstance(source, bytes) else "")
-                ) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: int | None = None,
+            ) -> None:
+                with open(dst, "w" + ("b" if isinstance(src, bytes) else "")) as f:
+                    f.write(src)
 
         blob_normalizer = self._repo.get_blob_normalizer()
         return build_index_from_tree(
@@ -630,7 +655,7 @@ class WorkTree:
             tree,
             honor_filemode=honor_filemode,
             validate_path_element=validate_path_element,
-            symlink_fn=symlink_fn,
+            symlink_fn=symlink_fn,  # type: ignore[arg-type]
             blob_normalizer=blob_normalizer,
         )
 

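The commit() hunks above normalize str messages to bytes before composing the reflog message; the same normalization as a standalone sketch:

    from typing import Union

    def normalize_message(message: Union[str, bytes]) -> bytes:
        # Mirrors the hunk: encode str, pass bytes through unchanged.
        return message.encode() if isinstance(message, str) else message

    assert normalize_message("typed commit") == b"typed commit"
    assert normalize_message(b"typed commit") == b"typed commit"
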
+ 2 - 1
pyproject.toml

@@ -25,7 +25,7 @@ classifiers = [
 requires-python = ">=3.9"
 dependencies = [
     "urllib3>=1.25",
-    'typing_extensions >=4.0 ; python_version < "3.11"',
+    'typing_extensions >=4.0 ; python_version < "3.12"',
 ]
 dynamic = ["version"]
 license-files = ["COPYING"]
@@ -99,6 +99,7 @@ ignore = [
     "ANN205",
     "ANN206",
     "E501",  # line too long
+    "UP007",  # Use X | Y for type annotations (Python 3.10+ syntax, but we support 3.9+)
 ]
 
 [tool.ruff.lint.pydocstyle]

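The new UP007 ignore exists because the project still supports Python 3.9, where the X | Y union syntax cannot be evaluated at runtime; a small illustration (the function is hypothetical):

    from typing import Optional

    def lookup(sha: Optional[bytes]) -> Optional[bytes]:
        # Accepted under the UP007 ignore; ruff would otherwise suggest
        # "sha: bytes | None", which fails at runtime on 3.9 without
        # "from __future__ import annotations".
        return sha
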
+ 18 - 18
tests/__init__.py

@@ -38,7 +38,7 @@ import tempfile
 
 # If Python itself provides an exception, use that
 import unittest
-from typing import ClassVar
+from typing import ClassVar, Optional
 from unittest import SkipTest, expectedFailure, skipIf
 from unittest import TestCase as _TestCase
 
@@ -49,7 +49,7 @@ class TestCase(_TestCase):
         self.overrideEnv("HOME", "/nonexistent")
         self.overrideEnv("GIT_CONFIG_NOSYSTEM", "1")
 
-    def overrideEnv(self, name, value) -> None:
+    def overrideEnv(self, name: str, value: Optional[str]) -> None:
         def restore() -> None:
             if oldval is not None:
                 os.environ[name] = oldval
@@ -74,7 +74,7 @@ class BlackboxTestCase(TestCase):
         "/usr/local/bin",
     ]
 
-    def bin_path(self, name):
+    def bin_path(self, name: str) -> str:
         """Determine the full path of a binary.
 
         Args:
@@ -88,7 +88,7 @@ class BlackboxTestCase(TestCase):
         else:
             raise SkipTest(f"Unable to find binary {name}")
 
-    def run_command(self, name, args):
+    def run_command(self, name: str, args: list[str]) -> subprocess.Popen[bytes]:
         """Run a Dulwich command.
 
         Args:
@@ -113,7 +113,7 @@ class BlackboxTestCase(TestCase):
         )
 
 
-def self_test_suite():
+def self_test_suite() -> unittest.TestSuite:
     names = [
         "annotate",
         "archive",
@@ -181,7 +181,7 @@ def self_test_suite():
     return loader.loadTestsFromNames(module_names)
 
 
-def tutorial_test_suite():
+def tutorial_test_suite() -> unittest.TestSuite:
     tutorial = [
         "introduction",
         "file-format",
@@ -194,7 +194,7 @@ def tutorial_test_suite():
 
     to_restore = []
 
-    def overrideEnv(name, value) -> None:
+    def overrideEnv(name: str, value: Optional[str]) -> None:
         oldval = os.environ.get(name)
         if value is not None:
             os.environ[name] = value
@@ -202,17 +202,17 @@ def tutorial_test_suite():
             del os.environ[name]
         to_restore.append((name, oldval))
 
-    def setup(test) -> None:
-        test.__old_cwd = os.getcwd()
-        test.tempdir = tempfile.mkdtemp()
-        test.globs.update({"tempdir": test.tempdir})
-        os.chdir(test.tempdir)
+    def setup(test: doctest.DocTest) -> None:
+        test.__old_cwd = os.getcwd()  # type: ignore[attr-defined]
+        test.tempdir = tempfile.mkdtemp()  # type: ignore[attr-defined]
+        test.globs.update({"tempdir": test.tempdir})  # type: ignore[attr-defined]
+        os.chdir(test.tempdir)  # type: ignore[attr-defined]
         overrideEnv("HOME", "/nonexistent")
         overrideEnv("GIT_CONFIG_NOSYSTEM", "1")
 
-    def teardown(test) -> None:
-        os.chdir(test.__old_cwd)
-        shutil.rmtree(test.tempdir)
+    def teardown(test: doctest.DocTest) -> None:
+        os.chdir(test.__old_cwd)  # type: ignore[attr-defined]
+        shutil.rmtree(test.tempdir)  # type: ignore[attr-defined]
         for name, oldval in to_restore:
             if oldval is not None:
                 os.environ[name] = oldval
@@ -229,7 +229,7 @@ def tutorial_test_suite():
     )
 
 
-def nocompat_test_suite():
+def nocompat_test_suite() -> unittest.TestSuite:
     result = unittest.TestSuite()
     result.addTests(self_test_suite())
     result.addTests(tutorial_test_suite())
@@ -239,7 +239,7 @@ def nocompat_test_suite():
     return result
 
 
-def compat_test_suite():
+def compat_test_suite() -> unittest.TestSuite:
     result = unittest.TestSuite()
     from .compat import test_suite as compat_test_suite
 
@@ -247,7 +247,7 @@ def compat_test_suite():
     return result
 
 
-def test_suite():
+def test_suite() -> unittest.TestSuite:
     result = unittest.TestSuite()
     result.addTests(self_test_suite())
     if sys.platform != "win32":

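The Optional[str] on overrideEnv captures that None means "unset the variable"; a minimal standalone sketch of that contract (the helper name here is hypothetical):

    import os
    from typing import Optional

    def set_or_unset(name: str, value: Optional[str]) -> None:
        # None removes the variable; any str sets it.
        if value is None:
            os.environ.pop(name, None)
        else:
            os.environ[name] = value
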
+ 1 - 1
tests/compat/__init__.py

@@ -24,7 +24,7 @@
 import unittest
 
 
-def test_suite():
+def test_suite() -> unittest.TestSuite:
     names = [
         "bundle",
         "check_ignore",

+ 2 - 2
tests/contrib/__init__.py

@@ -19,10 +19,10 @@
 # License, Version 2.0.
 #
 
+import unittest
 
-def test_suite():
-    import unittest
 
+def test_suite() -> unittest.TestSuite:
     names = [
         "diffstat",
         "paramiko_vendor",

+ 51 - 34
tests/test_annotate.py

@@ -21,6 +21,7 @@
 import os
 import tempfile
 import unittest
+from typing import Any, Optional
 from unittest import TestCase
 
 from dulwich.annotate import annotate_lines, update_lines
@@ -32,28 +33,28 @@ from dulwich.repo import Repo
 class UpdateLinesTestCase(TestCase):
     """Tests for update_lines function."""
 
-    def test_update_lines_equal(self):
+    def test_update_lines_equal(self) -> None:
         """Test update_lines when all lines are equal."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line2"),
         ]
         new_blob = b"line1\nline2"
         new_history_data = ("commit3", "entry3")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         self.assertEqual(old_lines, result)
 
-    def test_update_lines_insert(self):
+    def test_update_lines_insert(self) -> None:
         """Test update_lines when new lines are inserted."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line3"),
         ]
         new_blob = b"line1\nline2\nline3"
         new_history_data = ("commit3", "entry3")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit3", "entry3"), b"line2"),
@@ -61,9 +62,9 @@ class UpdateLinesTestCase(TestCase):
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_delete(self):
+    def test_update_lines_delete(self) -> None:
         """Test update_lines when lines are deleted."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line2"),
             (("commit3", "entry3"), b"line3"),
@@ -71,66 +72,70 @@ class UpdateLinesTestCase(TestCase):
         new_blob = b"line1\nline3"
         new_history_data = ("commit4", "entry4")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit3", "entry3"), b"line3"),
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_replace(self):
+    def test_update_lines_replace(self) -> None:
         """Test update_lines when lines are replaced."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line2"),
         ]
         new_blob = b"line1\nline2_modified"
         new_history_data = ("commit3", "entry3")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit3", "entry3"), b"line2_modified"),
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_empty_old(self):
+    def test_update_lines_empty_old(self) -> None:
         """Test update_lines with empty old lines."""
-        old_lines = []
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = []
         new_blob = b"line1\nline2"
         new_history_data = ("commit1", "entry1")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit1", "entry1"), b"line2"),
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_empty_new(self):
+    def test_update_lines_empty_new(self) -> None:
         """Test update_lines with empty new blob."""
-        old_lines = [(("commit1", "entry1"), b"line1")]
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
+            (("commit1", "entry1"), b"line1")
+        ]
         new_blob = b""
         new_history_data = ("commit2", "entry2")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         self.assertEqual([], result)
 
 
 class AnnotateLinesTestCase(TestCase):
     """Tests for annotate_lines function."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.temp_dir = tempfile.mkdtemp()
         self.repo = Repo.init(self.temp_dir)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.repo.close()
         import shutil
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit(self, blob_content, message, parent=None):
+    def _make_commit(
+        self, blob_content: bytes, message: str, parent: Optional[bytes] = None
+    ) -> bytes:
         """Helper to create a commit with a single file."""
         # Create blob
         blob = Blob()
@@ -159,7 +164,7 @@ class AnnotateLinesTestCase(TestCase):
         self.repo.object_store.add_object(commit)
         return commit.id
 
-    def test_annotate_lines_single_commit(self):
+    def test_annotate_lines_single_commit(self) -> None:
         """Test annotating a file with a single commit."""
         commit_id = self._make_commit(b"line1\nline2\nline3\n", "Initial commit")
 
@@ -170,7 +175,7 @@ class AnnotateLinesTestCase(TestCase):
             self.assertEqual(commit_id, commit.id)
             self.assertIn(line, [b"line1\n", b"line2\n", b"line3\n"])
 
-    def test_annotate_lines_multiple_commits(self):
+    def test_annotate_lines_multiple_commits(self) -> None:
         """Test annotating a file with multiple commits."""
         # First commit
         commit1_id = self._make_commit(b"line1\nline2\n", "Initial commit")
@@ -200,7 +205,7 @@ class AnnotateLinesTestCase(TestCase):
         self.assertEqual(commit1_id, result[2][0][0].id)
         self.assertEqual(b"line2\n", result[2][1])
 
-    def test_annotate_lines_nonexistent_path(self):
+    def test_annotate_lines_nonexistent_path(self) -> None:
         """Test annotating a nonexistent file."""
         commit_id = self._make_commit(b"content\n", "Initial commit")
 
@@ -211,17 +216,23 @@ class AnnotateLinesTestCase(TestCase):
 class PorcelainAnnotateTestCase(TestCase):
     """Tests for the porcelain annotate function."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.temp_dir = tempfile.mkdtemp()
         self.repo = Repo.init(self.temp_dir)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.repo.close()
         import shutil
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit_with_file(self, filename, content, message, parent=None):
+    def _make_commit_with_file(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
         """Helper to create a commit with a file."""
         # Create blob
         blob = Blob()
@@ -254,7 +265,7 @@ class PorcelainAnnotateTestCase(TestCase):
 
         return commit.id
 
-    def test_porcelain_annotate(self):
+    def test_porcelain_annotate(self) -> None:
         """Test the porcelain annotate function."""
         # Create commits
         commit1_id = self._make_commit_with_file(
@@ -274,7 +285,7 @@ class PorcelainAnnotateTestCase(TestCase):
             self.assertIsNotNone(entry)
             self.assertIn(line, [b"line1\n", b"line2\n", b"line3\n"])
 
-    def test_porcelain_annotate_with_committish(self):
+    def test_porcelain_annotate_with_committish(self) -> None:
         """Test porcelain annotate with specific commit."""
         # Create commits
         commit1_id = self._make_commit_with_file(
@@ -296,7 +307,7 @@ class PorcelainAnnotateTestCase(TestCase):
         self.assertEqual(1, len(result))
         self.assertEqual(b"modified\n", result[0][1])
 
-    def test_blame_alias(self):
+    def test_blame_alias(self) -> None:
         """Test that blame is an alias for annotate."""
         self.assertIs(blame, annotate)
 
@@ -304,17 +315,23 @@ class PorcelainAnnotateTestCase(TestCase):
 class IntegrationTestCase(TestCase):
     """Integration tests with more complex scenarios."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.temp_dir = tempfile.mkdtemp()
         self.repo = Repo.init(self.temp_dir)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.repo.close()
         import shutil
 
         shutil.rmtree(self.temp_dir)
 
-    def _create_file_commit(self, filename, content, message, parent=None):
+    def _create_file_commit(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
         """Helper to create a commit with file content."""
         # Write file to working directory
         filepath = os.path.join(self.temp_dir, filename)
@@ -337,7 +354,7 @@ class IntegrationTestCase(TestCase):
 
         return commit_id
 
-    def test_complex_file_history(self):
+    def test_complex_file_history(self) -> None:
         """Test annotating a file with complex history."""
         # Initial commit with 3 lines
         self._create_file_commit(

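A hedged sketch of the porcelain API these tests exercise; the repository path, file name and exact argument types are assumptions:

    from dulwich import porcelain

    for (commit, entry), line in porcelain.annotate(".", "test_file.txt"):
        print(commit.id.decode("ascii")[:8], line.decode("utf-8", "replace"), end="")
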
+ 7 - 10
tests/test_archive.py

@@ -24,7 +24,8 @@
 import struct
 import tarfile
 from io import BytesIO
-from unittest import skipUnless
+from typing import Optional
+from unittest.mock import patch
 
 from dulwich.archive import tar_stream
 from dulwich.object_store import MemoryObjectStore
@@ -33,11 +34,6 @@ from dulwich.tests.utils import build_commit_graph
 
 from . import TestCase
 
-try:
-    from unittest.mock import patch
-except ImportError:
-    patch = None
-
 
 class ArchiveTests(TestCase):
     def test_empty(self) -> None:
@@ -50,14 +46,16 @@ class ArchiveTests(TestCase):
         self.addCleanup(tf.close)
         self.assertEqual([], tf.getnames())
 
-    def _get_example_tar_stream(self, *tar_stream_args, **tar_stream_kwargs):
+    def _get_example_tar_stream(
+        self, mtime: int, prefix: bytes = b"", format: str = ""
+    ) -> BytesIO:
         store = MemoryObjectStore()
         b1 = Blob.from_string(b"somedata")
         store.add_object(b1)
         t1 = Tree()
         t1.add(b"somename", 0o100644, b1.id)
         store.add_object(t1)
-        stream = b"".join(tar_stream(store, t1, *tar_stream_args, **tar_stream_kwargs))
+        stream = b"".join(tar_stream(store, t1, mtime, prefix, format))
         return BytesIO(stream)
 
     def test_simple(self) -> None:
@@ -89,9 +87,8 @@ class ArchiveTests(TestCase):
         expected_mtime = struct.pack("<L", 1234)
         self.assertEqual(stream.getvalue()[4:8], expected_mtime)
 
-    @skipUnless(patch, "Required mock.patch")
     def test_same_file(self) -> None:
-        contents = [None, None]
+        contents: list[Optional[bytes]] = [None, None]
         for format in ["", "gz", "bz2"]:
             for i in [0, 1]:
                 with patch("time.time", return_value=i):

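A standalone sketch mirroring the typed test helper above: build a one-file tree in memory and archive it with tar_stream (the mtime value is arbitrary):

    from io import BytesIO

    from dulwich.archive import tar_stream
    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob, Tree

    store = MemoryObjectStore()
    blob = Blob.from_string(b"somedata")
    store.add_object(blob)
    tree = Tree()
    tree.add(b"somename", 0o100644, blob.id)
    store.add_object(tree)
    archive = BytesIO(b"".join(tar_stream(store, tree, 1234)))  # uncompressed tar bytes
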
+ 16 - 16
tests/test_bisect.py

@@ -36,19 +36,19 @@ from . import TestCase
 class BisectStateTests(TestCase):
     """Tests for BisectState class."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.test_dir = tempfile.mkdtemp()
         self.repo = porcelain.init(self.test_dir)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         shutil.rmtree(self.test_dir)
 
-    def test_is_active_false(self):
+    def test_is_active_false(self) -> None:
         """Test is_active when no bisect session is active."""
         state = BisectState(self.repo)
         self.assertFalse(state.is_active)
 
-    def test_start_bisect(self):
+    def test_start_bisect(self) -> None:
         """Test starting a bisect session."""
         # Create at least one commit so HEAD exists
         c1 = make_commit(id=b"1" * 40, message=b"initial commit")
@@ -73,7 +73,7 @@ class BisectStateTests(TestCase):
             os.path.exists(os.path.join(self.repo.controldir(), "BISECT_LOG"))
         )
 
-    def test_start_bisect_no_head(self):
+    def test_start_bisect_no_head(self) -> None:
         """Test starting a bisect session when repository has no HEAD."""
         state = BisectState(self.repo)
 
@@ -81,7 +81,7 @@ class BisectStateTests(TestCase):
             state.start()
         self.assertIn("Cannot start bisect: repository has no HEAD", str(cm.exception))
 
-    def test_start_bisect_already_active(self):
+    def test_start_bisect_already_active(self) -> None:
         """Test starting a bisect session when one is already active."""
         # Create at least one commit so HEAD exists
         c1 = make_commit(id=b"1" * 40, message=b"initial commit")
@@ -94,28 +94,28 @@ class BisectStateTests(TestCase):
         with self.assertRaises(ValueError):
             state.start()
 
-    def test_mark_bad_no_session(self):
+    def test_mark_bad_no_session(self) -> None:
         """Test marking bad commit when no session is active."""
         state = BisectState(self.repo)
 
         with self.assertRaises(ValueError):
             state.mark_bad()
 
-    def test_mark_good_no_session(self):
+    def test_mark_good_no_session(self) -> None:
         """Test marking good commit when no session is active."""
         state = BisectState(self.repo)
 
         with self.assertRaises(ValueError):
             state.mark_good()
 
-    def test_reset_no_session(self):
+    def test_reset_no_session(self) -> None:
         """Test resetting when no session is active."""
         state = BisectState(self.repo)
 
         with self.assertRaises(ValueError):
             state.reset()
 
-    def test_bisect_workflow(self):
+    def test_bisect_workflow(self) -> None:
         """Test a complete bisect workflow."""
         # Create some commits
         c1 = make_commit(id=b"1" * 40, message=b"good commit 1")
@@ -163,7 +163,7 @@ class BisectStateTests(TestCase):
 class BisectPorcelainTests(TestCase):
     """Tests for porcelain bisect functions."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.test_dir = tempfile.mkdtemp()
         self.repo = porcelain.init(self.test_dir)
 
@@ -191,10 +191,10 @@ class BisectPorcelainTests(TestCase):
         self.repo.refs[b"HEAD"] = self.c4.id
         self.repo.refs[b"refs/heads/master"] = self.c4.id
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         shutil.rmtree(self.test_dir)
 
-    def test_bisect_start(self):
+    def test_bisect_start(self) -> None:
         """Test bisect_start porcelain function."""
         porcelain.bisect_start(self.test_dir)
 
@@ -203,7 +203,7 @@ class BisectPorcelainTests(TestCase):
             os.path.exists(os.path.join(self.repo.controldir(), "BISECT_START"))
         )
 
-    def test_bisect_bad_good(self):
+    def test_bisect_bad_good(self) -> None:
         """Test marking commits as bad and good."""
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_bad(self.test_dir, self.c4.id.decode("ascii"))
@@ -226,7 +226,7 @@ class BisectPorcelainTests(TestCase):
             )
         )
 
-    def test_bisect_log(self):
+    def test_bisect_log(self) -> None:
         """Test getting bisect log."""
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_bad(self.test_dir, self.c4.id.decode("ascii"))
@@ -238,7 +238,7 @@ class BisectPorcelainTests(TestCase):
         self.assertIn("git bisect bad", log)
         self.assertIn("git bisect good", log)
 
-    def test_bisect_reset(self):
+    def test_bisect_reset(self) -> None:
         """Test resetting bisect state."""
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_bad(self.test_dir)

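A hedged sketch of the porcelain bisect workflow these tests walk through; the repository path and the good commit id are assumptions:

    from dulwich import porcelain

    porcelain.bisect_start(".")
    porcelain.bisect_bad(".")  # mark the current HEAD as bad
    good_sha = "0123456789abcdef0123456789abcdef01234567"  # hypothetical known-good commit
    porcelain.bisect_good(".", good_sha)
    print(porcelain.bisect_log("."))
    porcelain.bisect_reset(".")
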
+ 2 - 2
tests/test_cloud_gcs.py

@@ -39,8 +39,8 @@ class GcsObjectStoreTests(unittest.TestCase):
         self.assertIn("git", repr(self.store))
 
     def test_remove_pack(self):
-        """Test _remove_pack method."""
-        self.store._remove_pack("pack-1234")
+        """Test _remove_pack_by_name method."""
+        self.store._remove_pack_by_name("pack-1234")
         self.mock_bucket.delete_blobs.assert_called_once()
         args = self.mock_bucket.delete_blobs.call_args[0][0]
         self.assertEqual(

+ 51 - 49
tests/test_commit_graph.py

@@ -35,7 +35,7 @@ from dulwich.commit_graph import (
 class CommitGraphEntryTests(unittest.TestCase):
     """Tests for CommitGraphEntry."""
 
-    def test_init(self):
+    def test_init(self) -> None:
         commit_id = b"a" * 40
         tree_id = b"b" * 40
         parents = [b"c" * 40, b"d" * 40]
@@ -50,7 +50,7 @@ class CommitGraphEntryTests(unittest.TestCase):
         self.assertEqual(entry.generation, generation)
         self.assertEqual(entry.commit_time, commit_time)
 
-    def test_repr(self):
+    def test_repr(self) -> None:
         entry = CommitGraphEntry(b"a" * 40, b"b" * 40, [], 1, 1000)
         repr_str = repr(entry)
         self.assertIn("CommitGraphEntry", repr_str)
@@ -60,12 +60,12 @@ class CommitGraphEntryTests(unittest.TestCase):
 class CommitGraphChunkTests(unittest.TestCase):
     """Tests for CommitGraphChunk."""
 
-    def test_init(self):
+    def test_init(self) -> None:
         chunk = CommitGraphChunk(b"TEST", b"test data")
         self.assertEqual(chunk.chunk_id, b"TEST")
         self.assertEqual(chunk.data, b"test data")
 
-    def test_repr(self):
+    def test_repr(self) -> None:
         chunk = CommitGraphChunk(b"TEST", b"x" * 100)
         repr_str = repr(chunk)
         self.assertIn("CommitGraphChunk", repr_str)
@@ -75,13 +75,13 @@ class CommitGraphChunkTests(unittest.TestCase):
 class CommitGraphTests(unittest.TestCase):
     """Tests for CommitGraph."""
 
-    def test_init(self):
+    def test_init(self) -> None:
         graph = CommitGraph()
         self.assertEqual(graph.hash_version, HASH_VERSION_SHA1)
         self.assertEqual(len(graph.entries), 0)
         self.assertEqual(len(graph.chunks), 0)
 
-    def test_len(self):
+    def test_len(self) -> None:
         graph = CommitGraph()
         self.assertEqual(len(graph), 0)
 
@@ -90,7 +90,7 @@ class CommitGraphTests(unittest.TestCase):
         graph.entries.append(entry)
         self.assertEqual(len(graph), 1)
 
-    def test_iter(self):
+    def test_iter(self) -> None:
         graph = CommitGraph()
         entry1 = CommitGraphEntry(b"a" * 40, b"b" * 40, [], 1, 1000)
         entry2 = CommitGraphEntry(b"c" * 40, b"d" * 40, [], 2, 2000)
@@ -101,22 +101,22 @@ class CommitGraphTests(unittest.TestCase):
         self.assertEqual(entries[0], entry1)
         self.assertEqual(entries[1], entry2)
 
-    def test_get_entry_by_oid_missing(self):
+    def test_get_entry_by_oid_missing(self) -> None:
         graph = CommitGraph()
         result = graph.get_entry_by_oid(b"f" * 40)
         self.assertIsNone(result)
 
-    def test_get_generation_number_missing(self):
+    def test_get_generation_number_missing(self) -> None:
         graph = CommitGraph()
         result = graph.get_generation_number(b"f" * 40)
         self.assertIsNone(result)
 
-    def test_get_parents_missing(self):
+    def test_get_parents_missing(self) -> None:
         graph = CommitGraph()
         result = graph.get_parents(b"f" * 40)
         self.assertIsNone(result)
 
-    def test_from_invalid_signature(self):
+    def test_from_invalid_signature(self) -> None:
         data = b"XXXX" + b"\\x00" * 100
         f = io.BytesIO(data)
 
@@ -124,7 +124,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
         self.assertIn("Invalid commit graph signature", str(cm.exception))
 
-    def test_from_invalid_version(self):
+    def test_from_invalid_version(self) -> None:
         data = COMMIT_GRAPH_SIGNATURE + struct.pack(">B", 99) + b"\x00" * 100
         f = io.BytesIO(data)
 
@@ -132,7 +132,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
         self.assertIn("Unsupported commit graph version", str(cm.exception))
 
-    def test_from_invalid_hash_version(self):
+    def test_from_invalid_hash_version(self) -> None:
         data = (
             COMMIT_GRAPH_SIGNATURE
             + struct.pack(">B", COMMIT_GRAPH_VERSION)
@@ -145,7 +145,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
         self.assertIn("Unsupported hash version", str(cm.exception))
 
-    def create_minimal_commit_graph_data(self):
+    def create_minimal_commit_graph_data(self) -> bytes:
         """Create minimal valid commit graph data for testing."""
         # Create the data in order and calculate offsets properly
 
@@ -209,7 +209,7 @@ class CommitGraphTests(unittest.TestCase):
 
         return header + toc + fanout + oid_lookup + commit_data
 
-    def test_from_minimal_valid_file(self):
+    def test_from_minimal_valid_file(self) -> None:
         """Test parsing a minimal but valid commit graph file."""
         data = self.create_minimal_commit_graph_data()
         f = io.BytesIO(data)
@@ -233,7 +233,7 @@ class CommitGraphTests(unittest.TestCase):
         self.assertEqual(graph.get_parents(commit_oid), [])
         self.assertIsNotNone(graph.get_entry_by_oid(commit_oid))
 
-    def test_missing_required_chunks(self):
+    def test_missing_required_chunks(self) -> None:
         """Test error handling for missing required chunks."""
         # Create data with header but no chunks
         header = (
@@ -254,7 +254,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
         self.assertIn("Missing required OID lookup chunk", str(cm.exception))
 
-    def test_write_empty_graph_raises(self):
+    def test_write_empty_graph_raises(self) -> None:
         """Test that writing empty graph raises ValueError."""
         graph = CommitGraph()
         f = io.BytesIO()
@@ -262,7 +262,7 @@ class CommitGraphTests(unittest.TestCase):
         with self.assertRaises(ValueError):
             graph.write_to_file(f)
 
-    def test_write_and_read_round_trip(self):
+    def test_write_and_read_round_trip(self) -> None:
         """Test writing and reading a commit graph."""
         # Create a simple commit graph
         graph = CommitGraph()
@@ -297,21 +297,21 @@ class CommitGraphTests(unittest.TestCase):
 class CommitGraphFileOperationsTests(unittest.TestCase):
     """Tests for commit graph file operations."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.tempdir = tempfile.mkdtemp()
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         import shutil
 
         shutil.rmtree(self.tempdir, ignore_errors=True)
 
-    def test_read_commit_graph_missing_file(self):
+    def test_read_commit_graph_missing_file(self) -> None:
         """Test reading from non-existent file."""
         missing_path = os.path.join(self.tempdir, "missing.graph")
         result = read_commit_graph(missing_path)
         self.assertIsNone(result)
 
-    def test_read_commit_graph_invalid_file(self):
+    def test_read_commit_graph_invalid_file(self) -> None:
         """Test reading from invalid file."""
         invalid_path = os.path.join(self.tempdir, "invalid.graph")
         with open(invalid_path, "wb") as f:
@@ -320,12 +320,12 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
         with self.assertRaises(ValueError):
             read_commit_graph(invalid_path)
 
-    def test_find_commit_graph_file_missing(self):
+    def test_find_commit_graph_file_missing(self) -> None:
         """Test finding commit graph file when it doesn't exist."""
         result = find_commit_graph_file(self.tempdir)
         self.assertIsNone(result)
 
-    def test_find_commit_graph_file_standard_location(self):
+    def test_find_commit_graph_file_standard_location(self) -> None:
         """Test finding commit graph file in standard location."""
         # Create .git/objects/info/commit-graph
         objects_dir = os.path.join(self.tempdir, "objects")
@@ -339,7 +339,7 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
         result = find_commit_graph_file(self.tempdir)
         self.assertEqual(result, graph_path.encode())
 
-    def test_find_commit_graph_file_chain_location(self):
+    def test_find_commit_graph_file_chain_location(self) -> None:
         """Test finding commit graph file in chain location."""
         # Create .git/objects/info/commit-graphs/graph-{hash}.graph
         objects_dir = os.path.join(self.tempdir, "objects")
@@ -354,7 +354,7 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
         result = find_commit_graph_file(self.tempdir)
         self.assertEqual(result, graph_path.encode())
 
-    def test_find_commit_graph_file_prefers_standard(self):
+    def test_find_commit_graph_file_prefers_standard(self) -> None:
         """Test that standard location is preferred over chain location."""
         # Create both locations
         objects_dir = os.path.join(self.tempdir, "objects")
@@ -380,15 +380,15 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
 class CommitGraphGenerationTests(unittest.TestCase):
     """Tests for commit graph generation functionality."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.tempdir = tempfile.mkdtemp()
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         import shutil
 
         shutil.rmtree(self.tempdir, ignore_errors=True)
 
-    def test_generate_commit_graph_empty(self):
+    def test_generate_commit_graph_empty(self) -> None:
         """Test generating commit graph with no commits."""
         from dulwich.object_store import MemoryObjectStore
 
@@ -397,7 +397,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
 
         self.assertEqual(len(graph), 0)
 
-    def test_generate_commit_graph_single_commit(self):
+    def test_generate_commit_graph_single_commit(self) -> None:
         """Test generating commit graph with single commit."""
         from dulwich.object_store import MemoryObjectStore
         from dulwich.objects import Commit, Tree
@@ -428,7 +428,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertEqual(entry.generation, 1)
         self.assertEqual(entry.commit_time, 1234567890)
 
-    def test_get_reachable_commits(self):
+    def test_get_reachable_commits(self) -> None:
         """Test getting reachable commits."""
         from dulwich.object_store import MemoryObjectStore
         from dulwich.objects import Commit, Tree
@@ -465,7 +465,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertIn(commit1.id, reachable)
         self.assertIn(commit2.id, reachable)
 
-    def test_write_commit_graph_to_file(self):
+    def test_write_commit_graph_to_file(self) -> None:
         """Test writing commit graph to file."""
         from dulwich.object_store import DiskObjectStore
         from dulwich.objects import Commit, Tree
@@ -498,13 +498,14 @@ class CommitGraphGenerationTests(unittest.TestCase):
         # Read back and verify
         graph = read_commit_graph(graph_path)
         self.assertIsNotNone(graph)
+        assert graph is not None  # For mypy
         self.assertEqual(len(graph), 1)
 
         entry = graph.entries[0]
         self.assertEqual(entry.commit_id, commit.id)
         self.assertEqual(entry.tree_id, commit.tree)
 
-    def test_object_store_commit_graph_methods(self):
+    def test_object_store_commit_graph_methods(self) -> None:
         """Test ObjectStore commit graph methods."""
         from dulwich.object_store import DiskObjectStore
         from dulwich.objects import Commit, Tree
@@ -515,7 +516,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store = DiskObjectStore(object_store_path)
 
         # Initially no commit graph
-        self.assertIsNone(object_store.get_commit_graph())
+        self.assertIsNone(object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
         # Create a tree and commit
         tree = Tree()
@@ -534,13 +535,13 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store.write_commit_graph([commit.id], reachable=False)
 
         # Now should have commit graph
-        self.assertIsNotNone(object_store.get_commit_graph())
+        self.assertIsNotNone(object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
         # Test update (should still have commit graph)
         object_store.write_commit_graph()
-        self.assertIsNot(None, object_store.get_commit_graph())
+        self.assertIsNot(None, object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
-    def test_parents_provider_commit_graph_integration(self):
+    def test_parents_provider_commit_graph_integration(self) -> None:
         """Test that ParentsProvider uses commit graph when available."""
         from dulwich.object_store import DiskObjectStore
         from dulwich.objects import Commit, Tree
@@ -584,10 +585,10 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertIsNotNone(provider.commit_graph)
 
         # Test parent lookups
-        parents1 = provider.get_parents(commit1.id)
+        parents1 = provider.get_parents(commit1.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents1, [])
 
-        parents2 = provider.get_parents(commit2.id)
+        parents2 = provider.get_parents(commit2.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents2, [commit1.id])
 
         # Test fallback behavior by creating provider without commit graph
@@ -602,13 +603,13 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertIsNone(provider_no_graph.commit_graph)
 
         # Should still work via commit object fallback
-        parents1_fallback = provider_no_graph.get_parents(commit1.id)
+        parents1_fallback = provider_no_graph.get_parents(commit1.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents1_fallback, [])
 
-        parents2_fallback = provider_no_graph.get_parents(commit2.id)
+        parents2_fallback = provider_no_graph.get_parents(commit2.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents2_fallback, [commit1.id])
 
-    def test_graph_operations_use_commit_graph(self):
+    def test_graph_operations_use_commit_graph(self) -> None:
         """Test that graph operations use commit graph when available."""
         from dulwich.graph import can_fast_forward, find_merge_base
         from dulwich.object_store import DiskObjectStore
@@ -695,7 +696,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         repo2.object_store = object_store
 
         # Verify commit graph is available
-        commit_graph = repo2.object_store.get_commit_graph()
+        commit_graph = repo2.object_store.get_commit_graph()  # type: ignore[no-untyped-call]
         self.assertIsNotNone(commit_graph)
 
         # Test graph operations WITH commit graph
@@ -732,13 +733,14 @@ class CommitGraphGenerationTests(unittest.TestCase):
         )
 
         # Verify parent lookups work through the provider
-        self.assertEqual(parents_provider.get_parents(commit1.id), [])
-        self.assertEqual(parents_provider.get_parents(commit2.id), [commit1.id])
+        self.assertEqual(parents_provider.get_parents(commit1.id), [])  # type: ignore[no-untyped-call]
+        self.assertEqual(parents_provider.get_parents(commit2.id), [commit1.id])  # type: ignore[no-untyped-call]
         self.assertEqual(
-            parents_provider.get_parents(commit5.id), [commit3.id, commit4.id]
+            parents_provider.get_parents(commit5.id),
+            [commit3.id, commit4.id],  # type: ignore[no-untyped-call]
         )
 
-    def test_performance_with_commit_graph(self):
+    def test_performance_with_commit_graph(self) -> None:
         """Test that using commit graph provides performance benefits."""
         from dulwich.graph import find_merge_base
         from dulwich.object_store import DiskObjectStore
@@ -754,7 +756,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store.add_object(tree)
 
         # Create a chain of 20 commits
-        commits = []
+        commits: list[Commit] = []
         for i in range(20):
             commit = Commit()
             commit.tree = tree.id
@@ -785,7 +787,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         repo2.object_store = object_store
 
         # Verify commit graph is loaded
-        self.assertIsNotNone(repo2.object_store.get_commit_graph())
+        self.assertIsNotNone(repo2.object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
         # Time operations with commit graph
         for _ in range(10):  # Run multiple times for better measurement

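A hedged sketch of reading a commit graph back with the helpers tested above; the on-disk path is an assumption, and read_commit_graph returns None when the file is missing:

    from dulwich.commit_graph import read_commit_graph

    graph = read_commit_graph(".git/objects/info/commit-graph")  # assumed path
    if graph is not None:
        for entry in graph:
            print(entry.commit_id.decode("ascii"), entry.generation)
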
+ 46 - 35
tests/test_dumb.py

@@ -22,45 +22,53 @@
 """Tests for dumb HTTP git repositories."""
 
 import zlib
+from typing import Callable, Optional, Union
 from unittest import TestCase
 from unittest.mock import Mock
 
 from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteHTTPRepo
 from dulwich.errors import NotGitRepository
-from dulwich.objects import Blob, Commit, Tag, Tree, sha_to_hex
+from dulwich.objects import Blob, Commit, ShaFile, Tag, Tree, sha_to_hex
 
 
 class MockResponse:
-    def __init__(self, status=200, content=b"", headers=None):
+    def __init__(
+        self,
+        status: int = 200,
+        content: bytes = b"",
+        headers: Optional[dict[str, str]] = None,
+    ) -> None:
         self.status = status
         self.content = content
         self.headers = headers or {}
         self.closed = False
 
-    def close(self):
+    def close(self) -> None:
         self.closed = True
 
 
 class DumbHTTPObjectStoreTests(TestCase):
     """Tests for DumbHTTPObjectStore."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                     return b""
                 if size is None:
@@ -76,12 +84,12 @@ class DumbHTTPObjectStoreTests(TestCase):
             resp = MockResponse(404)
             return resp, lambda size: b""
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
 
-    def _make_object(self, obj):
+    def _make_object(self, obj: ShaFile) -> bytes:
         """Create compressed git object data."""
         type_name = {
             Blob.type_num: b"blob",
@@ -94,7 +102,7 @@ class DumbHTTPObjectStoreTests(TestCase):
         header = type_name + b" " + str(len(content)).encode() + b"\x00"
         return zlib.compress(header + content)
 
-    def test_fetch_loose_object_blob(self):
+    def test_fetch_loose_object_blob(self) -> None:
         # Create a blob object
         blob = Blob()
         blob.data = b"Hello, world!"
@@ -109,32 +117,33 @@ class DumbHTTPObjectStoreTests(TestCase):
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"Hello, world!", content)
 
-    def test_fetch_loose_object_not_found(self):
+    def test_fetch_loose_object_not_found(self) -> None:
         hex_sha = b"1" * 40
         self.assertRaises(KeyError, self.store._fetch_loose_object, hex_sha)
 
-    def test_fetch_loose_object_invalid_format(self):
+    def test_fetch_loose_object_invalid_format(self) -> None:
         sha = b"1" * 20
         hex_sha = sha_to_hex(sha)
-        path = f"objects/{hex_sha[:2]}/{hex_sha[2:]}"
+        path = f"objects/{hex_sha[:2].decode('ascii')}/{hex_sha[2:].decode('ascii')}"
 
         # Add invalid compressed data
         self._add_response(path, b"invalid data")
 
         self.assertRaises(Exception, self.store._fetch_loose_object, sha)
 
-    def test_load_packs_empty(self):
+    def test_load_packs_empty(self) -> None:
         # No packs file
         self.store._load_packs()
         self.assertEqual([], self.store._packs)
 
-    def test_load_packs_with_entries(self):
+    def test_load_packs_with_entries(self) -> None:
         packs_content = b"""P pack-1234567890abcdef1234567890abcdef12345678.pack
 P pack-abcdef1234567890abcdef1234567890abcdef12.pack
 """
         self._add_response("objects/info/packs", packs_content)
 
         self.store._load_packs()
+        assert self.store._packs is not None
         self.assertEqual(2, len(self.store._packs))
         self.assertEqual(
             "pack-1234567890abcdef1234567890abcdef12345678", self.store._packs[0][0]
@@ -143,7 +152,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
             "pack-abcdef1234567890abcdef1234567890abcdef12", self.store._packs[1][0]
         )
 
-    def test_get_raw_from_cache(self):
+    def test_get_raw_from_cache(self) -> None:
         sha = b"1" * 40
         self.store._cached_objects[sha] = (Blob.type_num, b"cached content")
 
@@ -151,7 +160,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"cached content", content)
 
-    def test_contains_loose(self):
+    def test_contains_loose(self) -> None:
         # Create a blob object
         blob = Blob()
         blob.data = b"Test blob"
@@ -164,35 +173,37 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertTrue(self.store.contains_loose(hex_sha))
         self.assertFalse(self.store.contains_loose(b"0" * 40))
 
-    def test_add_object_not_implemented(self):
+    def test_add_object_not_implemented(self) -> None:
         blob = Blob()
         blob.data = b"test"
         self.assertRaises(NotImplementedError, self.store.add_object, blob)
 
-    def test_add_objects_not_implemented(self):
+    def test_add_objects_not_implemented(self) -> None:
         self.assertRaises(NotImplementedError, self.store.add_objects, [])
 
 
 class DumbRemoteHTTPRepoTests(TestCase):
     """Tests for DumbRemoteHTTPRepo."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                     return b""
                 if size is None:
@@ -208,12 +219,12 @@ class DumbRemoteHTTPRepoTests(TestCase):
             resp = MockResponse(404)
             return resp, lambda size: b""
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
 
-    def test_get_refs(self):
+    def test_get_refs(self) -> None:
         refs_content = b"""0123456789abcdef0123456789abcdef01234567\trefs/heads/master
 abcdef0123456789abcdef0123456789abcdef01\trefs/heads/develop
 fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
@@ -235,10 +246,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
             refs[b"refs/tags/v1.0"],
         )
 
-    def test_get_refs_not_found(self):
+    def test_get_refs_not_found(self) -> None:
         self.assertRaises(NotGitRepository, self.repo.get_refs)
 
-    def test_get_peeled(self):
+    def test_get_peeled(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
 
@@ -246,19 +257,19 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         peeled = self.repo.get_peeled(b"refs/heads/master")
         self.assertEqual(b"0123456789abcdef0123456789abcdef01234567", peeled)
 
-    def test_fetch_pack_data_no_wants(self):
+    def test_fetch_pack_data_no_wants(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
 
         graph_walker = Mock()
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return []
 
         result = list(self.repo.fetch_pack_data(graph_walker, determine_wants))
         self.assertEqual([], result)
 
-    def test_fetch_pack_data_with_blob(self):
+    def test_fetch_pack_data_with_blob(self) -> None:
         # Set up refs
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
@@ -277,10 +288,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         graph_walker = Mock()
         graph_walker.ack.return_value = []  # No existing objects
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return [blob_sha]
 
-        def progress(msg):
+        def progress(msg: bytes) -> None:
             assert isinstance(msg, bytes)
 
         result = list(
@@ -290,6 +301,6 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         self.assertEqual(Blob.type_num, result[0].pack_type_num)
         self.assertEqual([blob.as_raw_string()], result[0].obj_chunks)
 
-    def test_object_store_property(self):
+    def test_object_store_property(self) -> None:
         self.assertIsInstance(self.repo.object_store, DumbHTTPObjectStore)
         self.assertEqual(self.base_url, self.repo.object_store.base_url)

+ 5 - 5
tests/test_pack.py

@@ -1465,7 +1465,7 @@ class DeltaChainIteratorTests(TestCase):
         entries = build_pack(f, [(REF_DELTA, (blob.id, b"blob1"))], store=self.store)
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_ext_ref_chain(self) -> None:
         (blob,) = self.store_blobs([b"blob"])
@@ -1480,7 +1480,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([1, 0], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_ext_ref_chain_degenerate(self) -> None:
         # Test a degenerate case where the sender is sending a REF_DELTA
@@ -1500,7 +1500,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0, 1], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_ext_ref_multiple_times(self) -> None:
         (blob,) = self.store_blobs([b"blob"])
@@ -1515,7 +1515,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0, 1], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_multiple_ext_refs(self) -> None:
         b1, b2 = self.store_blobs([b"foo", b"bar"])
@@ -1530,7 +1530,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0, 1], entries, pack_iter)
-        self.assertEqual([hex_to_sha(b1.id), hex_to_sha(b2.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(b1.id), hex_to_sha(b2.id)], pack_iter.ext_refs)
 
     def test_bad_ext_ref_non_thin_pack(self) -> None:
         (blob,) = self.store_blobs([b"blob"])

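The test_pack.py changes swap pack_iter.ext_refs() calls for plain attribute access, which suggests DeltaChainIterator now exposes ext_refs as a property rather than a method. A minimal sketch of that pattern under this assumption (class and field names are illustrative):

class IteratorSketch:
    def __init__(self) -> None:
        self._ext_refs: list[bytes] = []  # refs resolved outside the pack

    @property
    def ext_refs(self) -> list[bytes]:
        # Read-only attribute access replaces the former ext_refs() call.
        return self._ext_refs
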
+ 1 - 1
tests/test_protocol.py

@@ -36,10 +36,10 @@ from dulwich.protocol import (
     ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
-    filter_ref_prefix,
     pkt_line,
     pkt_seq,
 )
+from dulwich.refs import filter_ref_prefix
 
 from . import TestCase
 
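The only change to tests/test_protocol.py is the import location: filter_ref_prefix is now imported from dulwich.refs instead of dulwich.protocol. Assuming the function keeps its existing behaviour of narrowing a refs mapping to entries under the given byte-string prefixes, callers only need to update the import line, roughly:

from dulwich.refs import filter_ref_prefix  # previously imported from dulwich.protocol

refs = {
    b"refs/heads/master": b"0" * 40,
    b"refs/tags/v1.0": b"1" * 40,
}
# Assumed call shape: a refs mapping plus an iterable of byte prefixes.
heads_only = filter_ref_prefix(refs, [b"refs/heads/"])
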

Some files were not shown because too many files changed in this diff