Browse Source

Add more typing (#1750)

Jelmer Vernooij 5 months ago
parent
commit
7db9043418

+ 6 - 5
dulwich/annotate.py

@@ -27,8 +27,9 @@ Python's difflib.
 """

 import difflib
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional

+from dulwich.objects import Blob
 from dulwich.walk import (
     ORDER_DATE,
     Walker,
@@ -37,7 +38,7 @@ from dulwich.walk import (
 if TYPE_CHECKING:
     from dulwich.diff_tree import TreeChange, TreeEntry
     from dulwich.object_store import BaseObjectStore
-    from dulwich.objects import Blob, Commit
+    from dulwich.objects import Commit

 # Walk over ancestry graph breadth-first
 # When checking each revision, find lines that according to difflib.Differ()
@@ -108,7 +109,7 @@ def annotate_lines(

     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
-        lines_annotated = update_lines(
-            lines_annotated, (commit, entry), cast("Blob", store[entry.sha])
-        )
+        blob_obj = store[entry.sha]
+        assert isinstance(blob_obj, Blob)
+        lines_annotated = update_lines(lines_annotated, (commit, entry), blob_obj)
     return lines_annotated
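
A note on the pattern: typing.cast is a no-op at runtime, while an isinstance() assertion both narrows the type for mypy and fails fast if the object store returns something other than a blob. A minimal sketch of the difference (require_blob is a hypothetical helper, not part of dulwich):

    from dulwich.objects import Blob, ShaFile

    def require_blob(obj: ShaFile) -> Blob:
        # cast("Blob", obj) would only silence the type checker;
        # isinstance() narrows the type statically *and* verifies it at runtime.
        assert isinstance(obj, Blob), f"expected blob, got {obj.type_name.decode()}"
        return obj

    blob = Blob.from_string(b"hello\n")
    print(require_blob(blob).data)  # b'hello\n'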

+ 4 - 2
dulwich/bisect.py

@@ -49,7 +49,7 @@ class BisectState:
         self,
         bad: Optional[bytes] = None,
         good: Optional[list[bytes]] = None,
-        paths: Optional[list[str]] = None,
+        paths: Optional[list[bytes]] = None,
         no_checkout: bool = False,
         term_bad: str = "bad",
         term_good: str = "good",
@@ -103,7 +103,9 @@ class BisectState:
         names_file = os.path.join(self.repo.controldir(), "BISECT_NAMES")
         with open(names_file, "w") as f:
             if paths:
-                f.write("\n".join(paths) + "\n")
+                f.write(
+                    "\n".join(path.decode("utf-8", "replace") for path in paths) + "\n"
+                )
             else:
                 f.write("\n")

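Git reports paths as bytes, and the signature now says so; decoding happens only at the boundary where BISECT_NAMES is written out as text. A small sketch of that boundary (the file name and paths here are illustrative):

    paths = [b"src/main.py", b"caf\xc3\xa9.txt"]  # bytes, as git hands them over
    # errors="replace" keeps un-decodable bytes from raising at write time
    with open("/tmp/BISECT_NAMES", "w") as f:
        f.write("\n".join(p.decode("utf-8", "replace") for p in paths) + "\n")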

+ 32 - 7
dulwich/bundle.py

@@ -22,9 +22,30 @@
 """Bundle format support."""

 from collections.abc import Iterator
-from typing import TYPE_CHECKING, Any, BinaryIO, Callable, Optional
+from typing import (
+    TYPE_CHECKING,
+    BinaryIO,
+    Callable,
+    Optional,
+    Protocol,
+    runtime_checkable,
+)
+
+from .pack import PackData, UnpackedObject, write_pack_data
+
+
+@runtime_checkable
+class PackDataLike(Protocol):
+    """Protocol for objects that behave like PackData."""
+
+    def __len__(self) -> int:
+        """Return the number of objects in the pack."""
+        ...
+
+    def iter_unpacked(self) -> Iterator[UnpackedObject]:
+        """Iterate over unpacked objects in the pack."""
+        ...

-from .pack import PackData, write_pack_data

 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
@@ -39,7 +60,7 @@ class Bundle:
     capabilities: dict[str, Optional[str]]
     prerequisites: list[tuple[bytes, bytes]]
     references: dict[bytes, bytes]
-    pack_data: PackData
+    pack_data: Optional[PackDataLike]

     def __repr__(self) -> str:
         """Return string representation of Bundle."""
@@ -79,10 +100,12 @@ class Bundle:
         """
         from .objects import ShaFile

+        if self.pack_data is None:
+            raise ValueError("pack_data is not loaded")
         count = 0
         for unpacked in self.pack_data.iter_unpacked():
             # Convert the unpacked object to a proper git object
-            if unpacked.decomp_chunks:
+            if unpacked.decomp_chunks and unpacked.obj_type_num is not None:
                 git_obj = ShaFile.from_raw_chunks(
                     unpacked.obj_type_num, unpacked.decomp_chunks
                 )
@@ -187,6 +210,8 @@ def write_bundle(f: BinaryIO, bundle: Bundle) -> None:
     for ref, obj_id in bundle.references.items():
         f.write(obj_id + b" " + ref + b"\n")
     f.write(b"\n")
+    if bundle.pack_data is None:
+        raise ValueError("bundle.pack_data is not loaded")
     write_pack_data(
         f.write,
         num_records=len(bundle.pack_data),
@@ -283,14 +308,14 @@ def create_bundle_from_repo(
     # Store the pack objects directly, we'll write them when saving the bundle
     # For now, create a simple wrapper to hold the data
     class _BundlePackData:
-        def __init__(self, count: int, objects: Iterator[Any]) -> None:
+        def __init__(self, count: int, objects: Iterator[UnpackedObject]) -> None:
             self._count = count
             self._objects = list(objects)  # Materialize the iterator

         def __len__(self) -> int:
             return self._count

-        def iter_unpacked(self) -> Iterator[Any]:
+        def iter_unpacked(self) -> Iterator[UnpackedObject]:
             return iter(self._objects)

     pack_data = _BundlePackData(pack_count, pack_objects)
@@ -301,6 +326,6 @@ def create_bundle_from_repo(
     bundle.capabilities = capabilities
     bundle.prerequisites = bundle_prerequisites
     bundle.references = bundle_refs
-    bundle.pack_data = pack_data  # type: ignore[assignment]
+    bundle.pack_data = pack_data

     return bundle
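
PackDataLike is a structural (Protocol) type: anything with a compatible __len__ and iter_unpacked satisfies it, which is why the local _BundlePackData wrapper now type-checks without inheriting from PackData and the type: ignore could be dropped. A minimal sketch of the mechanics (InMemoryPackData is a hypothetical stand-in):

    from collections.abc import Iterator
    from typing import Protocol, runtime_checkable

    @runtime_checkable
    class PackDataLike(Protocol):
        def __len__(self) -> int: ...
        def iter_unpacked(self) -> Iterator[object]: ...

    class InMemoryPackData:  # note: no PackDataLike base class required
        def __init__(self, objects: list[object]) -> None:
            self._objects = objects

        def __len__(self) -> int:
            return len(self._objects)

        def iter_unpacked(self) -> Iterator[object]:
            return iter(self._objects)

    # @runtime_checkable enables isinstance(), but it only checks that the
    # methods exist; signatures and return types are not verified at runtime.
    print(isinstance(InMemoryPackData([]), PackDataLike))  # True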

+ 118 - 65
dulwich/cli.py

@@ -37,7 +37,7 @@ import subprocess
 import sys
 import tempfile
 from pathlib import Path
-from typing import Callable, ClassVar, Optional, Union
+from typing import BinaryIO, Callable, ClassVar, Optional, Union

 from dulwich import porcelain

@@ -45,17 +45,31 @@ from .bundle import create_bundle_from_repo, read_bundle, write_bundle
 from .client import GitProtocolError, get_transport_and_path
 from .errors import ApplyDeltaError
 from .index import Index
-from .objects import valid_hexsha
+from .objects import Commit, valid_hexsha
 from .objectspec import parse_commit_range
 from .pack import Pack, sha_to_hex
 from .repo import Repo


+def to_display_str(value: Union[bytes, str]) -> str:
+    """Convert a bytes or string value to a display string.
+
+    Args:
+        value: The value to convert (bytes or str)
+
+    Returns:
+        A string suitable for display
+    """
+    if isinstance(value, bytes):
+        return value.decode("utf-8", "replace")
+    return value
+
+
 class CommitMessageError(Exception):
     """Raised when there's an issue with the commit message."""


-def signal_int(signal, frame) -> None:
+def signal_int(signal: int, frame) -> None:
     """Handle interrupt signal by exiting.

     Args:
@@ -65,7 +79,7 @@ def signal_int(signal, frame) -> None:
     sys.exit(1)


-def signal_quit(signal, frame) -> None:
+def signal_quit(signal: int, frame) -> None:
     """Handle quit signal by entering debugger.

     Args:
@@ -77,7 +91,7 @@ def signal_quit(signal, frame) -> None:
     pdb.set_trace()


-def parse_relative_time(time_str):
+def parse_relative_time(time_str: str) -> int:
     """Parse a relative time string like '2 weeks ago' into seconds.

     Args:
@@ -126,7 +140,7 @@ def parse_relative_time(time_str):
         raise


-def format_bytes(bytes):
+def format_bytes(bytes: float) -> str:
     """Format bytes as human-readable string.

     Args:
@@ -142,7 +156,7 @@ def format_bytes(bytes):
     return f"{bytes:.1f} TB"


-def launch_editor(template_content=b""):
+def launch_editor(template_content: bytes = b"") -> bytes:
     """Launch an editor for the user to enter text.

     Args:
@@ -176,7 +190,7 @@ def launch_editor(template_content=b""):
 class PagerBuffer:
     """Binary buffer wrapper for Pager to mimic sys.stdout.buffer."""

-    def __init__(self, pager):
+    def __init__(self, pager: "Pager") -> None:
         """Initialize PagerBuffer.

         Args:
@@ -184,40 +198,40 @@ class PagerBuffer:
         """
         self.pager = pager

-    def write(self, data: bytes):
+    def write(self, data: bytes) -> int:
         """Write bytes to pager."""
         if isinstance(data, bytes):
             text = data.decode("utf-8", errors="replace")
             return self.pager.write(text)
         return self.pager.write(data)

-    def flush(self):
+    def flush(self) -> None:
         """Flush the pager."""
         return self.pager.flush()

-    def writelines(self, lines):
+    def writelines(self, lines) -> None:
         """Write multiple lines to pager."""
         for line in lines:
             self.write(line)

-    def readable(self):
+    def readable(self) -> bool:
         """Return whether the buffer is readable (it's not)."""
         return False

-    def writable(self):
+    def writable(self) -> bool:
         """Return whether the buffer is writable."""
         return not self.pager._closed

-    def seekable(self):
+    def seekable(self) -> bool:
         """Return whether the buffer is seekable (it's not)."""
         return False

-    def close(self):
+    def close(self) -> None:
         """Close the pager."""
         return self.pager.close()

     @property
-    def closed(self):
+    def closed(self) -> bool:
         """Return whether the buffer is closed."""
         return self.pager.closed

@@ -225,13 +239,13 @@ class PagerBuffer:
 class Pager:
     """File-like object that pages output through external pager programs."""

-    def __init__(self, pager_cmd="cat"):
+    def __init__(self, pager_cmd: str = "cat") -> None:
         """Initialize Pager.

         Args:
             pager_cmd: Command to use for paging (default: "cat")
         """
-        self.pager_process = None
+        self.pager_process: Optional[subprocess.Popen] = None
         self.buffer = PagerBuffer(self)
         self._closed = False
         self.pager_cmd = pager_cmd
@@ -241,7 +255,7 @@ class Pager:
         """Get the pager command to use."""
         return self.pager_cmd

-    def _ensure_pager_started(self):
+    def _ensure_pager_started(self) -> None:
         """Start the pager process if not already started."""
         if self.pager_process is None and not self._closed:
             try:
@@ -280,7 +294,7 @@ class Pager:
             # No pager available, write directly to stdout
             return sys.stdout.write(text)

-    def flush(self):
+    def flush(self) -> None:
         """Flush the pager."""
         if self._closed or self._pager_died:
             return
@@ -293,7 +307,7 @@ class Pager:
         else:
             sys.stdout.flush()

-    def close(self):
+    def close(self) -> None:
         """Close the pager."""
         if self._closed:
             return
@@ -308,16 +322,16 @@ class Pager:
                 pass
             self.pager_process = None

-    def __enter__(self):
+    def __enter__(self) -> "Pager":
         """Context manager entry."""
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         """Context manager exit."""
         self.close()

     # Additional file-like methods for compatibility
-    def writelines(self, lines):
+    def writelines(self, lines) -> None:
         """Write a list of lines to the pager."""
         if self._pager_died:
             return
@@ -325,19 +339,19 @@ class Pager:
             self.write(line)

     @property
-    def closed(self):
+    def closed(self) -> bool:
         """Return whether the pager is closed."""
         return self._closed

-    def readable(self):
+    def readable(self) -> bool:
         """Return whether the pager is readable (it's not)."""
         return False

-    def writable(self):
+    def writable(self) -> bool:
         """Return whether the pager is writable."""
         return not self._closed

-    def seekable(self):
+    def seekable(self) -> bool:
         """Return whether the pager is seekable (it's not)."""
         return False

@@ -345,7 +359,7 @@ class Pager:
 class _StreamContextAdapter:
     """Adapter to make streams work with context manager protocol."""

-    def __init__(self, stream):
+    def __init__(self, stream) -> None:
         self.stream = stream
         # Expose buffer if it exists
         if hasattr(stream, "buffer"):
@@ -356,15 +370,15 @@ class _StreamContextAdapter:
     def __enter__(self):
         return self.stream

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         # For stdout/stderr, we don't close them
         pass

-    def __getattr__(self, name):
+    def __getattr__(self, name: str):
         return getattr(self.stream, name)


-def get_pager(config=None, cmd_name=None):
+def get_pager(config=None, cmd_name: Optional[str] = None):
     """Get a pager instance if paging should be used, otherwise return sys.stdout.

     Args:
@@ -447,14 +461,14 @@ def get_pager(config=None, cmd_name=None):
     return Pager(pager_cmd)


-def disable_pager():
+def disable_pager() -> None:
     """Disable pager for this session."""
-    get_pager._disabled = True
+    get_pager._disabled = True  # type: ignore[attr-defined]


-def enable_pager():
+def enable_pager() -> None:
     """Enable pager for this session."""
-    get_pager._disabled = False
+    get_pager._disabled = False  # type: ignore[attr-defined]


 class Command:
@@ -491,10 +505,14 @@ class cmd_archive(Command):
                 write_error=sys.stderr.write,
             )
         else:
-            # Use buffer if available (for binary output), otherwise use stdout
-            outstream = getattr(sys.stdout, "buffer", sys.stdout)
+            # Use binary buffer for archive output
+            outstream: BinaryIO = sys.stdout.buffer
+            errstream: BinaryIO = sys.stderr.buffer
             porcelain.archive(
-                ".", args.committish, outstream=outstream, errstream=sys.stderr
+                ".",
+                args.committish,
+                outstream=outstream,
+                errstream=errstream,
             )

@@ -642,10 +660,11 @@ class cmd_fetch(Command):
         def progress(msg: bytes) -> None:
             sys.stdout.buffer.write(msg)

-        refs = client.fetch(path, r, progress=progress)
+        result = client.fetch(path, r, progress=progress)
         print("Remote refs:")
-        for item in refs.items():
-            print("{} -> {}".format(*item))
+        for ref, sha in result.refs.items():
+            if sha is not None:
+                print(f"{ref.decode()} -> {sha.decode()}")


 class cmd_for_each_ref(Command):
@@ -676,7 +695,7 @@ class cmd_fsck(Command):
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
         for obj, msg in porcelain.fsck("."):
-            print(f"{obj}: {msg}")
+            print(f"{obj.decode() if isinstance(obj, bytes) else obj}: {msg}")


 class cmd_log(Command):
@@ -838,7 +857,7 @@ class cmd_dump_pack(Command):

         basename, _ = os.path.splitext(args.filename)
         x = Pack(basename)
-        print(f"Object names checksum: {x.name()}")
+        print(f"Object names checksum: {x.name().decode('ascii', 'replace')}")
         print(f"Checksum: {sha_to_hex(x.get_stored_checksum())!r}")
         x.check()
         print(f"Length: {len(x)}")
@@ -846,9 +865,13 @@ class cmd_dump_pack(Command):
             try:
                 print(f"\t{x[name]}")
             except KeyError as k:
-                print(f"\t{name}: Unable to resolve base {k}")
+                print(
+                    f"\t{name.decode('ascii', 'replace')}: Unable to resolve base {k!r}"
+                )
             except ApplyDeltaError as e:
-                print(f"\t{name}: Unable to apply delta: {e!r}")
+                print(
+                    f"\t{name.decode('ascii', 'replace')}: Unable to apply delta: {e!r}"
+                )


 class cmd_dump_index(Command):
@@ -1302,9 +1325,17 @@ class cmd_reflog(Command):

                     for i, entry in enumerate(porcelain.reflog(repo, ref)):
                         # Format similar to git reflog
+                        from dulwich.reflog import Entry
+
+                        assert isinstance(entry, Entry)
                         short_new = entry.new_sha[:8].decode("ascii")
+                        message = (
+                            entry.message.decode("utf-8", "replace")
+                            if entry.message
+                            else ""
+                        )
                         outstream.write(
-                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {entry.message.decode('utf-8', 'replace')}\n"
+                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {message}\n"
                         )

@@ -1538,11 +1569,14 @@ class cmd_ls_remote(Command):
         if args.symref:
             # Show symrefs first, like git does
             for ref, target in sorted(result.symrefs.items()):
-                sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")
+                if target:
+                    sys.stdout.write(f"ref: {target.decode()}\t{ref.decode()}\n")

         # Show regular refs
         for ref in sorted(result.refs):
-            sys.stdout.write(f"{result.refs[ref].decode()}\t{ref.decode()}\n")
+            sha = result.refs[ref]
+            if sha is not None:
+                sys.stdout.write(f"{sha.decode()}\t{ref.decode()}\n")


 class cmd_ls_tree(Command):
@@ -1601,12 +1635,13 @@ class cmd_pack_objects(Command):
         if not args.stdout and not args.basename:
             parser.error("basename required when not using --stdout")

-        object_ids = [line.strip() for line in sys.stdin.readlines()]
+        object_ids = [line.strip().encode() for line in sys.stdin.readlines()]
         deltify = args.deltify
         reuse_deltas = not args.no_reuse_deltas

         if args.stdout:
             packf = getattr(sys.stdout, "buffer", sys.stdout)
+            assert isinstance(packf, BinaryIO)
             idxf = None
             close = []
         else:
@@ -2022,8 +2057,17 @@ class cmd_stash_list(Command):
         """
         parser = argparse.ArgumentParser()
         parser.parse_args(args)
-        for i, entry in porcelain.stash_list("."):
-            print("stash@{{{}}}: {}".format(i, entry.message.rstrip("\n")))
+        from .repo import Repo
+        from .stash import Stash
+
+        with Repo(".") as r:
+            stash = Stash.from_repo(r)
+            for i, entry in enumerate(stash.stashes()):
+                print(
+                    "stash@{{{}}}: {}".format(
+                        i, entry.message.decode("utf-8", "replace").rstrip("\n")
+                    )
+                )


 class cmd_stash_push(Command):
@@ -2145,6 +2189,7 @@ class cmd_bisect(SuperCommand):
                         with open(bad_ref, "rb") as f:
                             bad_sha = f.read().strip()
                         commit = r.object_store[bad_sha]
+                        assert isinstance(commit, Commit)
                         message = commit.message.decode(
                             "utf-8", errors="replace"
                         ).split("\n")[0]
@@ -2173,7 +2218,7 @@ class cmd_bisect(SuperCommand):
                 print(log, end="")

             elif parsed_args.subcommand == "replay":
-                porcelain.bisect_replay(log_file=parsed_args.logfile)
+                porcelain.bisect_replay(".", log_file=parsed_args.logfile)
                 print(f"Replayed bisect log from {parsed_args.logfile}")

             elif parsed_args.subcommand == "help":
@@ -2270,6 +2315,7 @@ class cmd_merge(Command):
             elif args.no_commit:
                 print("Automatic merge successful; not committing as requested.")
             else:
+                assert merge_commit_id is not None
                 print(
                     f"Merge successful. Created merge commit {merge_commit_id.decode()}"
                 )
@@ -3129,12 +3175,14 @@ class cmd_lfs(Command):
             tracked = porcelain.lfs_untrack(patterns=args.patterns)
             print("Remaining tracked patterns:")
             for pattern in tracked:
-                print(f"  {pattern}")
+                print(f"  {to_display_str(pattern)}")

         elif args.subcommand == "ls-files":
             files = porcelain.lfs_ls_files(ref=args.ref)
             for path, oid, size in files:
-                print(f"{oid[:12]} * {path} ({format_bytes(size)})")
+                print(
+                    f"{to_display_str(oid[:12])} * {to_display_str(path)} ({format_bytes(size)})"
+                )

         elif args.subcommand == "migrate":
             count = porcelain.lfs_migrate(
@@ -3145,13 +3193,13 @@ class cmd_lfs(Command):
         elif args.subcommand == "pointer":
             if args.paths is not None:
                 results = porcelain.lfs_pointer_check(paths=args.paths or None)
-                for path, pointer in results.items():
+                for file_path, pointer in results.items():
                     if pointer:
                         print(
-                            f"{path}: LFS pointer (oid: {pointer.oid[:12]}, size: {format_bytes(pointer.size)})"
+                            f"{to_display_str(file_path)}: LFS pointer (oid: {to_display_str(pointer.oid[:12])}, size: {format_bytes(pointer.size)})"
                         )
                     else:
-                        print(f"{path}: Not an LFS pointer")
+                        print(f"{to_display_str(file_path)}: Not an LFS pointer")

         elif args.subcommand == "clean":
             pointer = porcelain.lfs_clean(path=args.path)
@@ -3188,13 +3236,13 @@ class cmd_lfs(Command):

             if status["missing"]:
                 print("\nMissing LFS objects:")
-                for path in status["missing"]:
-                    print(f"  {path}")
+                for file_path in status["missing"]:
+                    print(f"  {to_display_str(file_path)}")

             if status["not_staged"]:
                 print("\nModified LFS files not staged:")
-                for path in status["not_staged"]:
-                    print(f"  {path}")
+                for file_path in status["not_staged"]:
+                    print(f"  {to_display_str(file_path)}")

             if not any(status.values()):
                 print("No LFS files found.")
@@ -3273,14 +3321,19 @@ class cmd_format_patch(Command):
         args = parser.parse_args(args)

         # Parse committish using the new function
-        committish = None
+        committish: Optional[Union[bytes, tuple[bytes, bytes]]] = None
         if args.committish:
             with Repo(".") as r:
                 range_result = parse_commit_range(r, args.committish)
                 if range_result:
-                    committish = range_result
+                    # Convert Commit objects to their SHAs
+                    committish = (range_result[0].id, range_result[1].id)
                 else:
-                    committish = args.committish
+                    committish = (
+                        args.committish.encode()
+                        if isinstance(args.committish, str)
+                        else args.committish
+                    )

         filenames = porcelain.format_patch(
             ".",

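Much of the cli.py churn routes bytes values through the new to_display_str helper so decoding happens exactly once, at the printing layer, with errors="replace" guarding against non-UTF-8 bytes. A quick standalone illustration of that behavior (re-stating the helper above):

    def to_display_str(value) -> str:
        if isinstance(value, bytes):
            return value.decode("utf-8", "replace")
        return value

    print(to_display_str(b"caf\xc3\xa9.txt"))  # café.txt
    print(to_display_str(b"caf\xe9.txt"))      # caf�.txt (invalid byte becomes U+FFFD)
    print(to_display_str("already text"))      # already text
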
+ 144 - 123
dulwich/client.py

@@ -69,11 +69,11 @@ import dulwich

 from .config import Config, apply_instead_of, get_xdg_config_home_path
 from .errors import GitProtocolError, NotGitRepository, SendPackError
+from .object_store import GraphWalker
 from .pack import (
     PACK_SPOOL_FILE_MAX_SIZE,
     PackChunkGenerator,
     PackData,
-    UnpackedObject,
     write_pack_from_container,
 )
 from .protocol import (
@@ -116,7 +116,6 @@ from .protocol import (
     capability_agent,
     extract_capabilities,
     extract_capability_names,
-    filter_ref_prefix,
     parse_capability,
     pkt_line,
     pkt_seq,
@@ -129,6 +128,7 @@ from .refs import (
     _set_default_branch,
     _set_head,
     _set_origin_head,
+    filter_ref_prefix,
     read_info_refs,
     split_peeled_refs,
 )
@@ -149,7 +149,7 @@ logger = logging.getLogger(__name__)
 class InvalidWants(Exception):
     """Invalid wants."""

-    def __init__(self, wants) -> None:
+    def __init__(self, wants: set[bytes]) -> None:
         """Initialize InvalidWants exception.

         Args:
@@ -163,7 +163,7 @@ class InvalidWants(Exception):
 class HTTPUnauthorized(Exception):
     """Raised when authentication fails."""

-    def __init__(self, www_authenticate, url) -> None:
+    def __init__(self, www_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPUnauthorized exception.

         Args:
@@ -178,7 +178,7 @@ class HTTPUnauthorized(Exception):
 class HTTPProxyUnauthorized(Exception):
     """Raised when proxy authentication fails."""

-    def __init__(self, proxy_authenticate, url) -> None:
+    def __init__(self, proxy_authenticate: Optional[str], url: str) -> None:
         """Initialize HTTPProxyUnauthorized exception.

         Args:
@@ -190,22 +190,28 @@ class HTTPProxyUnauthorized(Exception):
         self.url = url


-def _fileno_can_read(fileno):
+def _fileno_can_read(fileno: int) -> bool:
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0


-def _win32_peek_avail(handle):
+def _win32_peek_avail(handle: int) -> int:
     """Wrapper around PeekNamedPipe to check how many bytes are available."""
-    from ctypes import byref, windll, wintypes
+    from ctypes import (  # type: ignore[attr-defined]
+        byref,
+        windll,  # type: ignore[attr-defined]
+        wintypes,
+    )

     c_avail = wintypes.DWORD()
     c_message = wintypes.DWORD()
-    success = windll.kernel32.PeekNamedPipe(
+    success = windll.kernel32.PeekNamedPipe(  # type: ignore[attr-defined]
         handle, None, 0, None, byref(c_avail), byref(c_message)
     )
     if not success:
-        raise OSError(wintypes.GetLastError())
+        from ctypes import GetLastError  # type: ignore[attr-defined]
+
+        raise OSError(GetLastError())
     return c_avail.value

@@ -230,10 +236,10 @@ class ReportStatusParser:
     def __init__(self) -> None:
         """Initialize ReportStatusParser."""
         self._done = False
-        self._pack_status = None
+        self._pack_status: Optional[bytes] = None
         self._ref_statuses: list[bytes] = []

-    def check(self):
+    def check(self) -> Iterator[tuple[bytes, Optional[str]]]:
         """Check if there were any errors and, if so, raise exceptions.

         Raises:
@@ -257,7 +263,7 @@ class ReportStatusParser:
             else:
                 raise GitProtocolError(f"invalid ref status {status!r}")

-    def handle_packet(self, pkt) -> None:
+    def handle_packet(self, pkt: Optional[bytes]) -> None:
         """Handle a packet.

         Raises:
@@ -276,13 +282,8 @@ class ReportStatusParser:
             self._ref_statuses.append(ref_status)


-def negotiate_protocol_version(proto) -> int:
-    """Negotiate the protocol version to use.
-
-    Args:
-      proto: Protocol instance to negotiate with
-    Returns: Protocol version (0, 1, or 2)
-    """
+def negotiate_protocol_version(proto: Protocol) -> int:
+    """Negotiate protocol version with the server."""
     pkt = proto.read_pkt_line()
     if pkt is not None and pkt.strip() == b"version 2":
         return 2
@@ -290,13 +291,8 @@ def negotiate_protocol_version(proto) -> int:
     return 0


-def read_server_capabilities(pkt_seq):
-    """Read server capabilities from a packet sequence.
-
-    Args:
-      pkt_seq: Sequence of packets from server
-    Returns: Set of server capabilities
-    """
+def read_server_capabilities(pkt_seq: Iterable[bytes]) -> set[bytes]:
+    """Read server capabilities from packet sequence."""
     server_capabilities = []
     for pkt in pkt_seq:
         server_capabilities.append(pkt)
@@ -304,21 +300,16 @@


 def read_pkt_refs_v2(
-    pkt_seq,
-) -> tuple[dict[bytes, bytes], dict[bytes, bytes], dict[bytes, bytes]]:
-    """Read packet references in protocol v2 format.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (refs dict, symrefs dict, peeled dict)
-    """
-    refs = {}
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], dict[bytes, bytes], dict[bytes, bytes]]:
+    """Read references using protocol version 2."""
+    refs: dict[bytes, Optional[bytes]] = {}
     symrefs = {}
     peeled = {}
     # Receive refs from server
     for pkt in pkt_seq:
         parts = pkt.rstrip(b"\n").split(b" ")
-        sha = parts[0]
+        sha: Optional[bytes] = parts[0]
         if sha == b"unborn":
             sha = None
         ref = parts[1]
@@ -334,15 +325,12 @@ def read_pkt_refs_v2(
     return refs, symrefs, peeled


-def read_pkt_refs_v1(pkt_seq) -> tuple[dict[bytes, bytes], set[bytes]]:
-    """Read packet references in protocol v1 format.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (refs dict, server capabilities set)
-    """
+def read_pkt_refs_v1(
+    pkt_seq: Iterable[bytes],
+) -> tuple[dict[bytes, Optional[bytes]], set[bytes]]:
+    """Read references using protocol version 1."""
     server_capabilities = None
-    refs = {}
+    refs: dict[bytes, Optional[bytes]] = {}
     # Receive refs from server
     for pkt in pkt_seq:
         (sha, ref) = pkt.rstrip(b"\n").split(None, 1)
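
The refs dictionaries become dict[bytes, Optional[bytes]] because protocol v2 can report an unborn HEAD (for example on an empty remote) with the literal token "unborn", which the parser stores as None; callers must therefore skip None before decoding, as the cli.py changes above now do. A small sketch (the object id is a placeholder):

    from typing import Optional

    refs: dict[bytes, Optional[bytes]] = {
        b"HEAD": None,                    # unborn HEAD on an empty remote
        b"refs/heads/master": b"f" * 40,  # placeholder object id
    }
    for ref, sha in sorted(refs.items()):
        if sha is not None:  # narrows Optional[bytes] to bytes for mypy
            print(f"{sha.decode()}\t{ref.decode()}")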
@@ -363,6 +351,8 @@ def read_pkt_refs_v1(pkt_seq) -> tuple[dict[bytes, bytes], set[bytes]]:
 class _DeprecatedDictProxy:
     """Base class for result objects that provide deprecated dict-like interface."""

+    refs: dict[bytes, Optional[bytes]]  # To be overridden by subclasses
+
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
         "clear",
         "copy",
@@ -389,11 +379,11 @@ class _DeprecatedDictProxy:
             stacklevel=3,
         )

-    def __contains__(self, name) -> bool:
+    def __contains__(self, name: bytes) -> bool:
         self._warn_deprecated()
         return name in self.refs

-    def __getitem__(self, name):
+    def __getitem__(self, name: bytes) -> Optional[bytes]:
         self._warn_deprecated()
         return self.refs[name]

@@ -401,11 +391,11 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return len(self.refs)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
         self._warn_deprecated()
         return iter(self.refs)

-    def __getattribute__(self, name):
+    def __getattribute__(self, name: str) -> object:
         # Avoid infinite recursion by checking against class variable directly
         if name != "_FORWARDED_ATTRS" and name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
@@ -424,8 +414,16 @@ class FetchPackResult(_DeprecatedDictProxy):
       agent: User agent string
     """

+    symrefs: dict[bytes, bytes]
+    agent: Optional[bytes]
+
     def __init__(
-        self, refs, symrefs, agent, new_shallow=None, new_unshallow=None
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        symrefs: dict[bytes, bytes],
+        agent: Optional[bytes],
+        new_shallow: Optional[set[bytes]] = None,
+        new_unshallow: Optional[set[bytes]] = None,
     ) -> None:
         """Initialize FetchPackResult.

@@ -442,11 +440,13 @@ class FetchPackResult(_DeprecatedDictProxy):
         self.new_shallow = new_shallow
         self.new_unshallow = new_unshallow

-    def __eq__(self, other):
-        """Check equality with another FetchPackResult."""
+    def __eq__(self, other: object) -> bool:
+        """Check equality with another object."""
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, FetchPackResult):
+            return False
         return (
             self.refs == other.refs
             and self.symrefs == other.symrefs
@@ -466,7 +466,11 @@ class LsRemoteResult(_DeprecatedDictProxy):
       symrefs: Dictionary with remote symrefs
     """

-    def __init__(self, refs, symrefs) -> None:
+    symrefs: dict[bytes, bytes]
+
+    def __init__(
+        self, refs: dict[bytes, Optional[bytes]], symrefs: dict[bytes, bytes]
+    ) -> None:
         """Initialize LsRemoteResult.

         Args:
@@ -486,11 +490,13 @@ class LsRemoteResult(_DeprecatedDictProxy):
             stacklevel=3,
         )

-    def __eq__(self, other):
-        """Check equality with another LsRemoteResult."""
+    def __eq__(self, other: object) -> bool:
+        """Check equality with another object."""
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, LsRemoteResult):
+            return False
         return self.refs == other.refs and self.symrefs == other.symrefs

     def __repr__(self) -> str:
@@ -508,7 +514,12 @@ class SendPackResult(_DeprecatedDictProxy):
         failed to update), or None if it was updated successfully
     """

-    def __init__(self, refs, agent=None, ref_status=None) -> None:
+    def __init__(
+        self,
+        refs: dict[bytes, Optional[bytes]],
+        agent: Optional[bytes] = None,
+        ref_status: Optional[dict[bytes, Optional[str]]] = None,
+    ) -> None:
         """Initialize SendPackResult.

         Args:
@@ -520,11 +531,13 @@ class SendPackResult(_DeprecatedDictProxy):
         self.agent = agent
         self.ref_status = ref_status

-    def __eq__(self, other):
-        """Check equality with another SendPackResult."""
+    def __eq__(self, other: object) -> bool:
+        """Check equality with another object."""
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
+        if not isinstance(other, SendPackResult):
+            return False
         return self.refs == other.refs and self.agent == other.agent

     def __repr__(self) -> str:
@@ -532,13 +545,7 @@ class SendPackResult(_DeprecatedDictProxy):
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"


-def _read_shallow_updates(pkt_seq):
-    """Read shallow/unshallow updates from a packet sequence.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (new_shallow set, new_unshallow set)
-    """
+def _read_shallow_updates(pkt_seq: Iterable[bytes]) -> tuple[set[bytes], set[bytes]]:
     new_shallow = set()
     new_unshallow = set()
     for pkt in pkt_seq:
@@ -547,30 +554,29 @@ def _read_shallow_updates(pkt_seq):
         try:
             cmd, sha = pkt.split(b" ", 1)
         except ValueError:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
         if cmd == COMMAND_SHALLOW:
             new_shallow.add(sha.strip())
         elif cmd == COMMAND_UNSHALLOW:
             new_unshallow.add(sha.strip())
         else:
-            raise GitProtocolError(f"unknown command {pkt}")
+            raise GitProtocolError(f"unknown command {pkt!r}")
     return (new_shallow, new_unshallow)


 class _v1ReceivePackHeader:
-    """Handler for v1 receive-pack header."""
-
-    def __init__(self, capabilities, old_refs, new_refs) -> None:
-        self.want: list[bytes] = []
-        self.have: list[bytes] = []
+    def __init__(self, capabilities: list, old_refs: dict, new_refs: dict) -> None:
+        self.want: set[bytes] = set()
+        self.have: set[bytes] = set()
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False

-    def __iter__(self):
-        """Iterate over the receive-pack header lines."""
+    def __iter__(self) -> Iterator[Optional[bytes]]:
        return self._it

-    def _handle_receive_pack_head(self, capabilities, old_refs, new_refs):
+    def _handle_receive_pack_head(
+        self, capabilities: list, old_refs: dict, new_refs: dict
+    ) -> Iterator[Optional[bytes]]:
         """Handle the head of a 'git-receive-pack' request.

         Args:
@@ -581,7 +587,7 @@ class _v1ReceivePackHeader:
         Returns:
           (have, want) tuple
         """
-        self.have = [x for x in old_refs.values() if not x == ZERO_SHA]
+        self.have = {x for x in old_refs.values() if not x == ZERO_SHA}

         for refname in new_refs:
             if not isinstance(refname, bytes):
@@ -615,7 +621,7 @@ class _v1ReceivePackHeader:
                     )
                     self.sent_capabilities = True
             if new_sha1 not in self.have and new_sha1 != ZERO_SHA:
-                self.want.append(new_sha1)
+                self.want.add(new_sha1)
         yield None


@@ -632,33 +638,29 @@ def _read_side_band64k_data(pkt_seq: Iterable[bytes]) -> Iterator[tuple[int, byt
         yield channel, pkt[1:]


-def find_capability(capabilities, key, value):
-    """Find a capability in the list of capabilities.
-
-    Args:
-      capabilities: List of capabilities
-      key: Capability key to search for
-      value: Optional specific value to match
-    Returns: The matching capability or None
-    """
+def find_capability(
+    capabilities: list, key: bytes, value: Optional[bytes]
+) -> Optional[bytes]:
+    """Find a capability with a specific key and value."""
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k != key:
             continue
-        if value and value not in v.split(b" "):
+        if value and v and value not in v.split(b" "):
             continue
         return capability
+    return None


 def _handle_upload_pack_head(
-    proto,
-    capabilities,
-    graph_walker,
-    wants,
-    can_read,
+    proto: Protocol,
+    capabilities: list,
+    graph_walker: GraphWalker,
+    wants: list,
+    can_read: Optional[Callable],
     depth: Optional[int],
-    protocol_version,
-):
+    protocol_version: Optional[int],
+) -> tuple[Optional[set[bytes]], Optional[set[bytes]]]:
     """Handle the head of a 'git-upload-pack' request.

     Args:
@@ -671,6 +673,8 @@ def _handle_upload_pack_head(
       depth: Depth for request
       protocol_version: Negotiated Git protocol version.
     """
+    new_shallow: Optional[set[bytes]]
+    new_unshallow: Optional[set[bytes]]
     assert isinstance(wants, list) and isinstance(wants[0], bytes)
     wantcmd = COMMAND_WANT + b" " + wants[0]
     if protocol_version is None:
@@ -681,7 +685,9 @@ def _handle_upload_pack_head(
     proto.write_pkt_line(wantcmd)
     for want in wants[1:]:
         proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n")
-    if depth not in (0, None) or graph_walker.shallow:
+    if depth not in (0, None) or (
+        hasattr(graph_walker, "shallow") and graph_walker.shallow
+    ):
         if protocol_version == 2:
             if not find_capability(capabilities, CAPABILITY_FETCH, CAPABILITY_SHALLOW):
                 raise GitProtocolError(
@@ -691,8 +697,9 @@ def _handle_upload_pack_head(
             raise GitProtocolError(
                 "server does not support shallow capability required for depth"
             )
-        for sha in graph_walker.shallow:
-            proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
+        if hasattr(graph_walker, "shallow"):
+            for sha in graph_walker.shallow:
+                proto.write_pkt_line(COMMAND_SHALLOW + b" " + sha + b"\n")
         if depth is not None:
             proto.write_pkt_line(
                 COMMAND_DEEPEN + b" " + str(depth).encode("ascii") + b"\n"
@@ -705,6 +712,7 @@ def _handle_upload_pack_head(
         proto.write_pkt_line(COMMAND_HAVE + b" " + have + b"\n")
         if can_read is not None and can_read():
             pkt = proto.read_pkt_line()
+            assert pkt is not None
             parts = pkt.rstrip(b"\n").split(b" ")
             if parts[0] == b"ACK":
                 graph_walker.ack(parts[1])
@@ -714,7 +722,7 @@ def _handle_upload_pack_head(
                     break
                 else:
                     raise AssertionError(
-                        f"{parts[2]} not in ('continue', 'ready', 'common)"
+                        f"{parts[2]!r} not in ('continue', 'ready', 'common)"
                     )
         have = next(graph_walker)
     proto.write_pkt_line(COMMAND_DONE + b"\n")
@@ -725,7 +733,8 @@ def _handle_upload_pack_head(
         if can_read is not None:
             (new_shallow, new_unshallow) = _read_shallow_updates(proto.read_pkt_seq())
         else:
-            new_shallow = new_unshallow = None
+            new_shallow = None
+            new_unshallow = None
     else:
         new_shallow = new_unshallow = set()

@@ -773,7 +782,7 @@ def _handle_upload_pack_tail(
         if progress is None:
             # Just ignore progress data

-            def progress(x) -> None:
+            def progress(x: bytes) -> None:
                 pass

         for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
@@ -804,6 +813,7 @@ def _extract_symrefs_and_agent(capabilities):
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k == CAPABILITY_SYMREF:
+            assert v is not None
             (src, dst) = v.split(b":", 1)
             symrefs[src] = dst
         if k == CAPABILITY_AGENT:
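
For reference, the symref capability value has the form HEAD:refs/heads/main, and the added assert is what lets the split call type-check. A self-contained sketch of the parsing, with a hypothetical capability list:

capabilities = [
    b"symref=HEAD:refs/heads/main",
    b"agent=git/2.43.0",
    b"thin-pack",
]
symrefs = {}
agent = None
for capability in capabilities:
    key, _, value = capability.partition(b"=")
    if key == b"symref" and value:  # value present, so split is safe
        src, dst = value.split(b":", 1)
        symrefs[src] = dst
    elif key == b"agent":
        agent = value
assert symrefs == {b"HEAD": b"refs/heads/main"}
assert agent == b"git/2.43.0"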
@@ -879,9 +889,7 @@ class GitClient:
         self,
         path: str,
         update_refs,
-        generate_pack_data: Callable[
-            [set[bytes], set[bytes], bool], tuple[int, Iterator[UnpackedObject]]
-        ],
+        generate_pack_data,
         progress=None,
     ) -> SendPackResult:
         """Upload a pack to a remote repository.
@@ -972,8 +980,11 @@ class GitClient:
             origin_sha = result.refs.get(b"HEAD")
             origin_sha = result.refs.get(b"HEAD")
             if origin is None or (origin_sha and not origin_head):
             if origin is None or (origin_sha and not origin_head):
                 # set detached HEAD
                 # set detached HEAD
-                target.refs[b"HEAD"] = origin_sha
-                head = origin_sha
+                if origin_sha is not None:
+                    target.refs[b"HEAD"] = origin_sha
+                    head = origin_sha
+                else:
+                    head = None
             else:
             else:
                 _set_origin_head(target.refs, origin.encode("utf-8"), origin_head)
                 _set_origin_head(target.refs, origin.encode("utf-8"), origin_head)
                 head_ref = _set_default_branch(
                 head_ref = _set_default_branch(
@@ -1203,10 +1214,11 @@ class GitClient:
             if self.protocol_version == 2 and k == CAPABILITY_FETCH:
                 fetch_capa = CAPABILITY_FETCH
                 fetch_features = []
-                v = v.strip().split(b" ")
-                if b"shallow" in v:
+                assert v is not None
+                v_list = v.strip().split(b" ")
+                if b"shallow" in v_list:
                     fetch_features.append(CAPABILITY_SHALLOW)
-                if b"filter" in v:
+                if b"filter" in v_list:
                     fetch_features.append(CAPABILITY_FILTER)
                 for i in range(len(fetch_features)):
                     if i == 0:
@@ -1357,10 +1369,10 @@ class TraditionalGitClient(GitClient):
                 for ref, sha in orig_new_refs.items():
                     if sha == ZERO_SHA:
                         if CAPABILITY_REPORT_STATUS in negotiated_capabilities:
+                            assert report_status_parser is not None
                             report_status_parser._ref_statuses.append(
                                 b"ng " + ref + b" remote does not support deleting refs"
                             )
-                            report_status_parser._ref_status_ok = False
                         del new_refs[ref]

             if new_refs is None:
@@ -1767,7 +1779,7 @@ class TCPGitClient(TraditionalGitClient):
         proto.send_cmd(
             b"git-" + cmd, path, b"host=" + self._host.encode("ascii") + version_str
         )
-        return proto, lambda: _fileno_can_read(s), None
+        return proto, lambda: _fileno_can_read(s.fileno()), None


 class SubprocessWrapper:
@@ -1997,7 +2009,7 @@ class LocalGitClient(GitClient):
                 *generate_pack_data(have, want, ofs_delta=True)
             )

-            ref_status = {}
+            ref_status: dict[bytes, Optional[str]] = {}

             for refname, new_sha1 in new_refs.items():
                 old_sha1 = old_refs.get(refname, ZERO_SHA)
@@ -2236,7 +2248,7 @@ class BundleClient(GitClient):
 
 
             while line.startswith(b"-"):
             while line.startswith(b"-"):
                 (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
                 (obj_id, comment) = line[1:].rstrip(b"\n").split(b" ", 1)
-                prerequisites.append((obj_id, comment.decode("utf-8")))
+                prerequisites.append((obj_id, comment))
                 line = f.readline()
                 line = f.readline()
 
 
             while line != b"\n":
             while line != b"\n":
@@ -2977,7 +2989,11 @@ class AbstractHttpGitClient(GitClient):
         protocol_version: Optional[int] = None,
         ref_prefix: Optional[list[Ref]] = None,
     ) -> tuple[
-        dict[Ref, ObjectID], set[bytes], str, dict[Ref, Ref], dict[Ref, ObjectID]
+        dict[Ref, Optional[ObjectID]],
+        set[bytes],
+        str,
+        dict[Ref, Ref],
+        dict[Ref, ObjectID],
     ]:
         if (
             protocol_version is not None
@@ -3040,10 +3056,10 @@ class AbstractHttpGitClient(GitClient):
                     resp, read = self._smart_request(
                         service.decode("ascii"), base_url, body
                     )
-                    proto = Protocol(read, None)
+                    proto = Protocol(read, lambda data: None)
                     return server_capabilities, resp, read, proto

-                proto = Protocol(read, None)  # type: ignore
+                proto = Protocol(read, lambda data: None)
                 server_protocol_version = negotiate_protocol_version(proto)
                 if server_protocol_version not in GIT_PROTOCOL_VERSIONS:
                     raise ValueError(
@@ -3108,7 +3124,12 @@ class AbstractHttpGitClient(GitClient):
                     if not chunk:
                         break
                     data += chunk
-                (refs, peeled) = split_peeled_refs(read_info_refs(BytesIO(data)))
+                from typing import Optional, cast
+
+                info_refs = read_info_refs(BytesIO(data))
+                (refs, peeled) = split_peeled_refs(
+                    cast(dict[bytes, Optional[bytes]], info_refs)
+                )
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
                 return refs, set(), base_url, {}, peeled
@@ -3196,7 +3217,7 @@ class AbstractHttpGitClient(GitClient):
 
 
         resp, read = self._smart_request("git-receive-pack", url, data=body_generator())
         resp, read = self._smart_request("git-receive-pack", url, data=body_generator())
         try:
         try:
-            resp_proto = Protocol(read, None)
+            resp_proto = Protocol(read, lambda data: None)
             ref_status = self._handle_receive_pack_tail(
             ref_status = self._handle_receive_pack_tail(
                 resp_proto, negotiated_capabilities, progress
                 resp_proto, negotiated_capabilities, progress
             )
             )
@@ -3441,7 +3462,7 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
     def _http_request(self, url, headers=None, data=None, raise_for_status=True):
         import urllib3.exceptions

-        req_headers = self.pool_manager.headers.copy()
+        req_headers = dict(self.pool_manager.headers)
         if headers is not None:
             req_headers.update(headers)
         req_headers["Pragma"] = "no-cache"
@@ -3455,10 +3476,10 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
                 request_kwargs["timeout"] = self._timeout
                 request_kwargs["timeout"] = self._timeout
 
 
             if data is None:
             if data is None:
-                resp = self.pool_manager.request("GET", url, **request_kwargs)
+                resp = self.pool_manager.request("GET", url, **request_kwargs)  # type: ignore[arg-type]
             else:
             else:
                 request_kwargs["body"] = data
                 request_kwargs["body"] = data
-                resp = self.pool_manager.request("POST", url, **request_kwargs)
+                resp = self.pool_manager.request("POST", url, **request_kwargs)  # type: ignore[arg-type]
         except urllib3.exceptions.HTTPError as e:
         except urllib3.exceptions.HTTPError as e:
             raise GitProtocolError(str(e)) from e
             raise GitProtocolError(str(e)) from e
 
 
@@ -3472,15 +3493,15 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
             if resp.status != 200:
                 raise GitProtocolError(f"unexpected http resp {resp.status} for {url}")

-        resp.content_type = resp.headers.get("Content-Type")
+        resp.content_type = resp.headers.get("Content-Type")  # type: ignore[attr-defined]
         # Check if geturl() is available (urllib3 version >= 1.23)
         try:
             resp_url = resp.geturl()
         except AttributeError:
             # get_redirect_location() is available for urllib3 >= 1.1
-            resp.redirect_location = resp.get_redirect_location()
+            resp.redirect_location = resp.get_redirect_location()  # type: ignore[attr-defined]
         else:
-            resp.redirect_location = resp_url if resp_url != url else ""
+            resp.redirect_location = resp_url if resp_url != url else ""  # type: ignore[attr-defined]
         return resp, _wrap_urllib3_exceptions(resp.read)
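
The Protocol(read, lambda data: None) changes above all follow one pattern: pass a no-op write callback instead of None, so the attribute is always callable. A sketch of that contract, using a stand-in class rather than dulwich.protocol.Protocol itself:

from typing import Callable

class ProtocolSketch:
    """Stand-in illustrating the constructor contract assumed above."""

    def __init__(
        self,
        read: Callable[[int], bytes],
        write: Callable[[bytes], None],
    ) -> None:
        self.read = read
        self.write = write

# A no-op callable keeps the attribute uniformly typed; None would force
# an Optional type and a runtime check before every write.
proto = ProtocolSketch(lambda n: b"", lambda data: None)
proto.write(b"0000")  # safe without a None check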
 
 
 
 

+ 32 - 15
dulwich/cloud/gcs.py

@@ -24,22 +24,33 @@
 
 
 import posixpath
 import tempfile
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, BinaryIO

 from ..object_store import BucketBasedObjectStore
-from ..pack import PACK_SPOOL_FILE_MAX_SIZE, Pack, PackData, load_pack_index_file
+from ..pack import (
+    PACK_SPOOL_FILE_MAX_SIZE,
+    Pack,
+    PackData,
+    PackIndex,
+    load_pack_index_file,
+)
+
+if TYPE_CHECKING:
+    from google.cloud.storage import Bucket

 # TODO(jelmer): For performance, read ranges?


 class GcsObjectStore(BucketBasedObjectStore):
-    """Object store implementation for Google Cloud Storage."""
+    """Object store implementation using Google Cloud Storage."""

-    def __init__(self, bucket, subpath="") -> None:
-        """Initialize GcsObjectStore.
+    def __init__(self, bucket: "Bucket", subpath: str = "") -> None:
+        """Initialize GCS object store.

         Args:
-          bucket: GCS bucket instance
-          subpath: Subpath within the bucket
+            bucket: GCS bucket instance
+            subpath: Optional subpath within the bucket
         """
         super().__init__()
         self.bucket = bucket
@@ -49,13 +60,13 @@ class GcsObjectStore(BucketBasedObjectStore):
         """Return string representation of GcsObjectStore."""
         """Return string representation of GcsObjectStore."""
         return f"{type(self).__name__}({self.bucket!r}, subpath={self.subpath!r})"
         return f"{type(self).__name__}({self.bucket!r}, subpath={self.subpath!r})"
 
 
-    def _remove_pack(self, name) -> None:
+    def _remove_pack_by_name(self, name: str) -> None:
         self.bucket.delete_blobs(
         self.bucket.delete_blobs(
             [posixpath.join(self.subpath, name) + "." + ext for ext in ["pack", "idx"]]
             [posixpath.join(self.subpath, name) + "." + ext for ext in ["pack", "idx"]]
         )
         )
 
 
-    def _iter_pack_names(self):
-        packs = {}
+    def _iter_pack_names(self) -> Iterator[str]:
+        packs: dict[str, set[str]] = {}
         for blob in self.bucket.list_blobs(prefix=self.subpath):
         for blob in self.bucket.list_blobs(prefix=self.subpath):
             name, ext = posixpath.splitext(posixpath.basename(blob.name))
             name, ext = posixpath.splitext(posixpath.basename(blob.name))
             packs.setdefault(name, set()).add(ext)
             packs.setdefault(name, set()).add(ext)
@@ -63,26 +74,32 @@ class GcsObjectStore(BucketBasedObjectStore):
             if exts == {".pack", ".idx"}:
             if exts == {".pack", ".idx"}:
                 yield name
                 yield name
 
 
-    def _load_pack_data(self, name):
+    def _load_pack_data(self, name: str) -> PackData:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".pack"))
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".pack"))
+        from typing import cast
+
+        from ..file import _GitFile
+
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         b.download_to_file(f)
         b.download_to_file(f)
         f.seek(0)
         f.seek(0)
-        return PackData(name + ".pack", f)
+        return PackData(name + ".pack", cast(_GitFile, f))
 
 
-    def _load_pack_index(self, name):
+    def _load_pack_index(self, name: str) -> PackIndex:
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".idx"))
         b = self.bucket.blob(posixpath.join(self.subpath, name + ".idx"))
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         f = tempfile.SpooledTemporaryFile(max_size=PACK_SPOOL_FILE_MAX_SIZE)
         b.download_to_file(f)
         b.download_to_file(f)
         f.seek(0)
         f.seek(0)
         return load_pack_index_file(name + ".idx", f)
         return load_pack_index_file(name + ".idx", f)
 
 
-    def _get_pack(self, name):
-        return Pack.from_lazy_objects(
+    def _get_pack(self, name: str) -> Pack:
+        return Pack.from_lazy_objects(  # type: ignore[no-untyped-call]
             lambda: self._load_pack_data(name), lambda: self._load_pack_index(name)
             lambda: self._load_pack_data(name), lambda: self._load_pack_index(name)
         )
         )
 
 
-    def _upload_pack(self, basename, pack_file, index_file) -> None:
+    def _upload_pack(
+        self, basename: str, pack_file: BinaryIO, index_file: BinaryIO
+    ) -> None:
         idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + ".idx"))
         idxblob = self.bucket.blob(posixpath.join(self.subpath, basename + ".idx"))
         datablob = self.bucket.blob(posixpath.join(self.subpath, basename + ".pack"))
         datablob = self.bucket.blob(posixpath.join(self.subpath, basename + ".pack"))
         idxblob.upload_from_file(index_file)
         idxblob.upload_from_file(index_file)
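
A hypothetical usage of the store above, assuming the google-cloud-storage package and a bucket you control:

from google.cloud import storage

from dulwich.cloud.gcs import GcsObjectStore

client = storage.Client()
bucket = client.bucket("my-git-objects")  # hypothetical bucket name
store = GcsObjectStore(bucket, subpath="repos/example.git")
print(store)  # GcsObjectStore(<Bucket: my-git-objects>, subpath='repos/example.git')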

+ 1 - 3
dulwich/commit_graph.py

@@ -567,9 +567,7 @@ def write_commit_graph(
 
 
     graph_path = os.path.join(info_dir, b"commit-graph")
     graph_path = os.path.join(info_dir, b"commit-graph")
     with GitFile(graph_path, "wb") as f:
     with GitFile(graph_path, "wb") as f:
-        from typing import BinaryIO, cast
-
-        graph.write_to_file(cast(BinaryIO, f))
+        graph.write_to_file(f)
 
 
 
 
 def get_reachable_commits(
 def get_reachable_commits(

+ 39 - 12
dulwich/config.py

@@ -41,14 +41,11 @@ from contextlib import suppress
 from pathlib import Path
 from typing import (
     IO,
-    Any,
-    BinaryIO,
     Callable,
     Generic,
     Optional,
     TypeVar,
     Union,
-    cast,
     overload,
 )

@@ -60,7 +57,7 @@ ConfigValue = Union[str, bytes, bool, int]
 logger = logging.getLogger(__name__)

 # Type for file opener callback
-FileOpener = Callable[[Union[str, os.PathLike]], BinaryIO]
+FileOpener = Callable[[Union[str, os.PathLike]], IO[bytes]]

 # Type for includeIf condition matcher
 # Takes the condition value (e.g., "main" for onbranch:main) and returns bool
@@ -194,7 +191,7 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
           default_factory: Optional factory function for default values
         """
         self._real: list[tuple[K, V]] = []
-        self._keyed: dict[Any, V] = {}
+        self._keyed: dict[ConfigKey, V] = {}
         self._default_factory = default_factory

     @classmethod
@@ -239,7 +236,31 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
 
     def keys(self) -> KeysView[K]:
         """Return a view of the dictionary's keys."""
-        return self._keyed.keys()  # type: ignore[return-value]
+        # Return a view of the original keys (not lowercased)
+        # We need to deduplicate since _real can have duplicates
+        seen = set()
+        unique_keys = []
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                unique_keys.append(k)
+        from collections.abc import KeysView as ABCKeysView
+
+        class UniqueKeysView(ABCKeysView[K]):
+            def __init__(self, keys: list[K]):
+                self._keys = keys
+
+            def __contains__(self, key: object) -> bool:
+                return key in self._keys
+
+            def __iter__(self):
+                return iter(self._keys)
+
+            def __len__(self) -> int:
+                return len(self._keys)
+
+        return UniqueKeysView(unique_keys)

     def items(self) -> ItemsView[K, V]:
         """Return a view of the dictionary's (key, value) pairs in insertion order."""
@@ -267,7 +288,13 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
 
 
     def __iter__(self) -> Iterator[K]:
         """Iterate over the dictionary's keys."""
-        return iter(self._keyed)
+        # Return iterator over original keys (not lowercased), deduplicated
+        seen = set()
+        for k, _ in self._real:
+            lower = lower_key(k)
+            if lower not in seen:
+                seen.add(lower)
+                yield k

     def values(self) -> ValuesView[V]:
         """Return a view of the dictionary's values."""
@@ -898,7 +925,7 @@ class ConfigFile(ConfigDict):
     @classmethod
     def from_file(
         cls,
-        f: BinaryIO,
+        f: IO[bytes],
         *,
         config_dir: Optional[str] = None,
         included_paths: Optional[set[str]] = None,
@@ -1075,8 +1102,8 @@ class ConfigFile(ConfigDict):
             opener: FileOpener
             if file_opener is None:

-                def opener(path: Union[str, os.PathLike]) -> BinaryIO:
-                    return cast(BinaryIO, GitFile(path, "rb"))
+                def opener(path: Union[str, os.PathLike]) -> IO[bytes]:
+                    return GitFile(path, "rb")
             else:
                 opener = file_opener

@@ -1236,8 +1263,8 @@ class ConfigFile(ConfigDict):
         opener: FileOpener
         if file_opener is None:

-            def opener(p: Union[str, os.PathLike]) -> BinaryIO:
-                return cast(BinaryIO, GitFile(p, "rb"))
+            def opener(p: Union[str, os.PathLike]) -> IO[bytes]:
+                return GitFile(p, "rb")
         else:
             opener = file_opener

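The new keys() and __iter__ share one idea: first occurrence wins, compared case-insensitively, with insertion order kept. A standalone sketch with plain str keys (the class itself works on ConfigKey values):

def unique_keys(pairs):
    """First occurrence wins, case-insensitively, insertion order kept."""
    seen = set()
    out = []
    for key, _value in pairs:
        lowered = key.lower()
        if lowered not in seen:
            seen.add(lowered)
            out.append(key)
    return out

assert unique_keys([("Core", "a"), ("core", "b"), ("remote", "c")]) == [
    "Core",
    "remote",
]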

+ 14 - 8
dulwich/contrib/swift.py

@@ -43,6 +43,7 @@ from typing import BinaryIO, Callable, Optional, Union, cast
 
 
 from geventhttpclient import HTTPClient

+from ..file import _GitFile
 from ..greenthreads import GreenThreadsMissingObjectFinder
 from ..lru_cache import LRUSizeCache
 from ..object_store import INFODIR, PACKDIR, ObjectContainer, PackBasedObjectStore
@@ -823,7 +824,11 @@ class SwiftObjectStore(PackBasedObjectStore):
               The created SwiftPack or None if empty
             """
             f.seek(0)
-            pack = PackData(file=f, filename="")
+            from typing import cast
+
+            from ..file import _GitFile
+
+            pack = PackData(file=cast(_GitFile, f), filename="")
             entries = pack.sorted_entries()
             if entries:
                 basename = posixpath.join(
@@ -875,7 +880,8 @@ class SwiftObjectStore(PackBasedObjectStore):
         fd, path = tempfile.mkstemp(prefix="tmp_pack_")
         fd, path = tempfile.mkstemp(prefix="tmp_pack_")
         f = os.fdopen(fd, "w+b")
         f = os.fdopen(fd, "w+b")
         try:
         try:
-            indexer = PackIndexer(f, resolve_ext_ref=None)
+            pack_data = PackData(file=cast(_GitFile, f), filename=path)
+            indexer = PackIndexer(cast(BinaryIO, pack_data._file), resolve_ext_ref=None)
             copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier = PackStreamCopier(read_all, read_some, f, delta_iter=indexer)
             copier.verify()
             copier.verify()
             return self._complete_thin_pack(f, path, copier, indexer)
             return self._complete_thin_pack(f, path, copier, indexer)
@@ -890,7 +896,7 @@ class SwiftObjectStore(PackBasedObjectStore):
 
 
         # Update the header with the new number of objects.
         f.seek(0)
-        write_pack_header(f, len(entries) + len(indexer.ext_refs()))  # type: ignore
+        write_pack_header(f, len(entries) + len(indexer.ext_refs))  # type: ignore

         # Must flush before reading (http://bugs.python.org/issue3207)
         f.flush()
@@ -902,7 +908,7 @@ class SwiftObjectStore(PackBasedObjectStore):
         f.seek(0, os.SEEK_CUR)

         # Complete the pack.
-        for ext_sha in indexer.ext_refs():  # type: ignore
+        for ext_sha in indexer.ext_refs:  # type: ignore
             assert len(ext_sha) == 20
             type_num, data = self.get_raw(ext_sha)
             offset = f.tell()
@@ -928,7 +934,7 @@ class SwiftObjectStore(PackBasedObjectStore):
 
 
         # Write pack info.
         f.seek(0)
-        pack_data = PackData(filename="", file=f)
+        pack_data = PackData(filename="", file=cast(_GitFile, f))
         index_file.seek(0)
         pack_index = load_pack_index_file("", index_file)
         serialized_pack_info = pack_info_create(pack_data, pack_index)
@@ -1030,17 +1036,17 @@ class SwiftInfoRefsContainer(InfoRefsContainer):
         del self._refs[name]
         return True

-    def allkeys(self) -> Iterator[bytes]:
+    def allkeys(self) -> set[bytes]:
         """Get all reference names.

         Returns:
-          Iterator of reference names as bytes
+          Set of reference names as bytes
         """
         try:
             self._refs[b"HEAD"] = self._refs[b"refs/heads/master"]
         except KeyError:
             pass
-        return iter(self._refs.keys())
+        return set(self._refs.keys())


 class SwiftRepo(BaseRepo):

+ 12 - 6
dulwich/diff.py

@@ -47,11 +47,11 @@ Example usage:
 import logging
 import os
 import stat
-from typing import BinaryIO, Optional, cast
+from typing import BinaryIO, Optional

 from .index import ConflictedIndexEntry, commit_index
 from .object_store import iter_tree_contents
-from .objects import S_ISGITLINK, Blob
+from .objects import S_ISGITLINK, Blob, Commit
 from .patch import write_blob_diff, write_object_diff
 from .repo import Repo

@@ -90,12 +90,16 @@ def diff_index_to_tree(
     if commit_sha is None:
         try:
             commit_sha = repo.refs[b"HEAD"]
-            old_tree = repo[commit_sha].tree
+            old_commit = repo[commit_sha]
+            assert isinstance(old_commit, Commit)
+            old_tree = old_commit.tree
         except KeyError:
             # No HEAD means no commits yet
             old_tree = None
     else:
-        old_tree = repo[commit_sha].tree
+        old_commit = repo[commit_sha]
+        assert isinstance(old_commit, Commit)
+        old_tree = old_commit.tree

     # Get tree from index
     index = repo.open_index()
@@ -125,7 +129,9 @@ def diff_working_tree_to_tree(
         commit_sha: SHA of commit to compare against
         paths: Optional list of paths to filter (as bytes)
     """
-    tree = repo[commit_sha].tree
+    commit = repo[commit_sha]
+    assert isinstance(commit, Commit)
+    tree = commit.tree
     normalizer = repo.get_blob_normalizer()
     filter_callback = normalizer.checkin_normalize

@@ -382,7 +388,7 @@ def diff_working_tree_to_index(
         old_obj = repo.object_store[old_sha]
         # Type check and cast to Blob
         if isinstance(old_obj, Blob):
-            old_blob = cast(Blob, old_obj)
+            old_blob = old_obj
         else:
             old_blob = None
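
The pattern here and in the hunks above: isinstance() already narrows the type for the checker, so the cast() was redundant. A minimal illustration:

from dulwich.objects import Blob

obj = Blob.from_string(b"hello\n")  # statically a ShaFile in general
if isinstance(obj, Blob):
    old_blob = obj  # the checker now knows this is a Blob; no cast needed
    assert old_blob.data == b"hello\n"
else:
    old_blob = None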
 
 

+ 4 - 4
dulwich/dumb.py

@@ -24,7 +24,7 @@
 import os
 import tempfile
 import zlib
-from collections.abc import Iterator
+from collections.abc import Iterator, Sequence
 from io import BytesIO
 from typing import Any, Callable, Optional
 from urllib.parse import urljoin
@@ -338,9 +338,9 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
 
     def add_objects(
         self,
-        objects: Iterator[ShaFile],
-        progress: Optional[Callable[[int], None]] = None,
-    ) -> None:
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable[[str], None]] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store."""
         raise NotImplementedError("Cannot add objects to dumb HTTP repository")
 
 

+ 57 - 9
dulwich/file.py

@@ -24,10 +24,15 @@
 import os
 import sys
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from types import TracebackType
 from typing import IO, Any, ClassVar, Literal, Optional, Union, overload

+if sys.version_info >= (3, 12):
+    from collections.abc import Buffer
+else:
+    Buffer = Union[bytes, bytearray, memoryview]
+

 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
     """Ensure a directory exists, creating if necessary."""
@@ -136,7 +141,7 @@ class FileLocked(Exception):
         super().__init__(filename, lockfilename)


-class _GitFile:
+class _GitFile(IO[bytes]):
     """File that follows the git locking protocol for writes.

     All writes to a file foo will be written into foo.lock in the same
@@ -148,7 +153,6 @@ class _GitFile:
     """
     """
 
 
     PROXY_PROPERTIES: ClassVar[set[str]] = {
     PROXY_PROPERTIES: ClassVar[set[str]] = {
-        "closed",
         "encoding",
         "encoding",
         "errors",
         "errors",
         "mode",
         "mode",
@@ -158,15 +162,19 @@ class _GitFile:
     }
     PROXY_METHODS: ClassVar[set[str]] = {
         "__iter__",
+        "__next__",
         "flush",
         "fileno",
         "isatty",
         "read",
+        "readable",
         "readline",
         "readlines",
         "seek",
+        "seekable",
         "tell",
         "truncate",
+        "writable",
         "write",
         "writelines",
     }
@@ -195,9 +203,6 @@ class _GitFile:
         self._file = os.fdopen(fd, mode, bufsize)
         self._closed = False

-        for method in self.PROXY_METHODS:
-            setattr(self, method, getattr(self._file, method))
-
     def __iter__(self) -> Iterator[bytes]:
         """Iterate over lines in the file."""
         return iter(self._file)
@@ -267,20 +272,63 @@ class _GitFile:
         else:
             self.close()

+    @property
+    def closed(self) -> bool:
+        """Return whether the file is closed."""
+        return self._closed
+
     def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)
         raise AttributeError(name)

+    # Implement IO[bytes] methods by delegating to the underlying file
+    def read(self, size: int = -1) -> bytes:
+        return self._file.read(size)
+
+    # TODO: Remove type: ignore when Python 3.10 support is dropped (Oct 2026)
+    # Python 3.9/3.10 have issues with IO[bytes] overload signatures
+    def write(self, data: Buffer, /) -> int:  # type: ignore[override]
+        return self._file.write(data)
+
+    def readline(self, size: int = -1) -> bytes:
+        return self._file.readline(size)
+
+    def readlines(self, hint: int = -1) -> list[bytes]:
+        return self._file.readlines(hint)
+
+    # TODO: Remove type: ignore when Python 3.10 support is dropped (Oct 2026)
+    # Python 3.9/3.10 have issues with IO[bytes] overload signatures
+    def writelines(self, lines: Iterable[Buffer], /) -> None:  # type: ignore[override]
+        return self._file.writelines(lines)
+
+    def seek(self, offset: int, whence: int = 0) -> int:
+        return self._file.seek(offset, whence)
+
+    def tell(self) -> int:
+        return self._file.tell()
+
+    def flush(self) -> None:
+        return self._file.flush()
+
+    def truncate(self, size: Optional[int] = None) -> int:
+        return self._file.truncate(size)
+
+    def fileno(self) -> int:
+        return self._file.fileno()
+
+    def isatty(self) -> bool:
+        return self._file.isatty()
+
     def readable(self) -> bool:
-        """Return whether the file is readable."""
         return self._file.readable()

     def writable(self) -> bool:
-        """Return whether the file is writable."""
         return self._file.writable()

     def seekable(self) -> bool:
-        """Return whether the file is seekable."""
         return self._file.seekable()
+
+    def __next__(self) -> bytes:
+        return next(iter(self._file))
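
A usage sketch for the locking protocol _GitFile implements; GitFile is the public entry point. Writes land in a .lock file that is renamed over the target on close, so readers never observe a partial file:

from dulwich.file import GitFile

with GitFile("example.conf", "wb") as f:  # writes go to example.conf.lock
    f.write(b"[core]\n\tbare = false\n")
# on clean exit the lock file replaces example.conf atomically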

+ 15 - 7
dulwich/filters.py

@@ -30,7 +30,7 @@ from .objects import Blob
 
 
 if TYPE_CHECKING:
     from .config import StackedConfig
-    from .repo import Repo
+    from .repo import BaseRepo


 class FilterError(Exception):
@@ -128,7 +128,9 @@ class FilterRegistry:
     """Registry for filter drivers."""
     """Registry for filter drivers."""
 
 
     def __init__(
     def __init__(
-        self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None
+        self,
+        config: Optional["StackedConfig"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
     ) -> None:
         """Initialize FilterRegistry.
         """Initialize FilterRegistry.
 
 
@@ -211,8 +213,12 @@ class FilterRegistry:
         required = self.config.get_boolean(("filter", name), "required", False)

         if clean_cmd or smudge_cmd:
-            # Get repository working directory
-            repo_path = self.repo.path if self.repo else None
+            # Get repository working directory (only for Repo, not BaseRepo)
+            from .repo import Repo
+
+            repo_path = (
+                self.repo.path if self.repo and isinstance(self.repo, Repo) else None
+            )
             return ProcessFilterDriver(clean_cmd, smudge_cmd, required, repo_path)

         return None
@@ -221,8 +227,10 @@ class FilterRegistry:
         """Create LFS filter driver."""
         """Create LFS filter driver."""
         from .lfs import LFSFilterDriver, LFSStore
         from .lfs import LFSFilterDriver, LFSStore
 
 
-        # If we have a repo, use its LFS store
-        if registry.repo is not None:
+        # If we have a Repo (not just BaseRepo), use its LFS store
+        from .repo import Repo
+
+        if registry.repo is not None and isinstance(registry.repo, Repo):
             lfs_store = LFSStore.from_repo(registry.repo, create=True)
             lfs_store = LFSStore.from_repo(registry.repo, create=True)
         else:
         else:
             # Fall back to creating a temporary LFS store
             # Fall back to creating a temporary LFS store
@@ -389,7 +397,7 @@ class FilterBlobNormalizer:
         config_stack: Optional["StackedConfig"],
         config_stack: Optional["StackedConfig"],
         gitattributes: GitAttributes,
         gitattributes: GitAttributes,
         filter_registry: Optional[FilterRegistry] = None,
         filter_registry: Optional[FilterRegistry] = None,
-        repo: Optional["Repo"] = None,
+        repo: Optional["BaseRepo"] = None,
     ) -> None:
     ) -> None:
         """Initialize FilterBlobNormalizer.
         """Initialize FilterBlobNormalizer.
 
 

+ 24 - 22
dulwich/index.py

@@ -32,13 +32,13 @@ from collections.abc import Generator, Iterable, Iterator
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    IO,
     TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,
     Optional,
     Union,
-    cast,
 )

 if TYPE_CHECKING:
@@ -588,7 +588,7 @@ def read_cache_time(f: BinaryIO) -> tuple[int, int]:
     return struct.unpack(">LL", f.read(8))
     return struct.unpack(">LL", f.read(8))
 
 
 
 
-def write_cache_time(f: BinaryIO, t: Union[int, float, tuple[int, int]]) -> None:
+def write_cache_time(f: IO[bytes], t: Union[int, float, tuple[int, int]]) -> None:
     """Write a cache time.
     """Write a cache time.
 
 
     Args:
     Args:
@@ -664,7 +664,7 @@ def read_cache_entry(
 
 
 
 
 def write_cache_entry(
-    f: BinaryIO, entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
+    f: IO[bytes], entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
 ) -> None:
     """Write an index entry to a file.
 
 
@@ -745,7 +745,7 @@ def read_index_header(f: BinaryIO) -> tuple[int, int]:
     return version, num_entries


-def write_index_extension(f: BinaryIO, extension: IndexExtension) -> None:
+def write_index_extension(f: IO[bytes], extension: IndexExtension) -> None:
     """Write an index extension.

     Args:
@@ -868,7 +868,7 @@ def read_index_dict(
 
 
 
 
 def write_index(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: list[SerializedIndexEntry],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -909,7 +909,7 @@ def write_index(
 
 
 
 
 def write_index_dict(
-    f: BinaryIO,
+    f: IO[bytes],
     entries: dict[bytes, Union[IndexEntry, ConflictedIndexEntry]],
     version: Optional[int] = None,
     extensions: Optional[list[IndexExtension]] = None,
@@ -1007,8 +1007,6 @@ class Index:
 
 
     def write(self) -> None:
         """Write current contents of index to disk."""
-        from typing import BinaryIO, cast
-
         f = GitFile(self._filename, "wb")
         try:
             # Filter out extensions with no meaningful data
@@ -1022,7 +1020,7 @@ class Index:
             if self._skip_hash:
                 # When skipHash is enabled, write the index without computing SHA1
                 write_index_dict(
-                    cast(BinaryIO, f),
+                    f,
                     self._byname,
                     version=self._version,
                     extensions=meaningful_extensions,
@@ -1031,9 +1029,9 @@ class Index:
                 f.write(b"\x00" * 20)
                 f.write(b"\x00" * 20)
                 f.close()
                 f.close()
             else:
             else:
-                sha1_writer = SHA1Writer(cast(BinaryIO, f))
+                sha1_writer = SHA1Writer(f)
                 write_index_dict(
                 write_index_dict(
-                    cast(BinaryIO, sha1_writer),
+                    sha1_writer,
                     self._byname,
                     self._byname,
                     version=self._version,
                     version=self._version,
                     extensions=meaningful_extensions,
                     extensions=meaningful_extensions,
@@ -1050,9 +1048,7 @@ class Index:
         f = GitFile(self._filename, "rb")
         f = GitFile(self._filename, "rb")
         try:
         try:
             sha1_reader = SHA1Reader(f)
             sha1_reader = SHA1Reader(f)
-            entries, version, extensions = read_index_dict_with_version(
-                cast(BinaryIO, sha1_reader)
-            )
+            entries, version, extensions = read_index_dict_with_version(sha1_reader)
             self._version = version
             self._version = version
             self._extensions = extensions
             self._extensions = extensions
             self.update(entries)
             self.update(entries)
@@ -1411,7 +1407,9 @@ def build_file_from_blob(
     *,
     honor_filemode: bool = True,
     tree_encoding: str = "utf-8",
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
 ) -> os.stat_result:
     """Build a file or symlink on disk based on a Git object.
 
 
@@ -1596,7 +1594,9 @@ def build_index_from_tree(
     tree_id: bytes,
     honor_filemode: bool = True,
     validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -1956,7 +1956,9 @@ def _transition_to_file(
     entry: IndexEntry,
     index: Index,
     honor_filemode: bool,
-    symlink_fn: Optional[Callable[[bytes, bytes], None]],
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ],
     blob_normalizer: Optional["BlobNormalizer"],
     tree_encoding: str = "utf-8",
 ) -> None:
@@ -2208,7 +2210,9 @@ def update_working_tree(
     change_iterator: Iterator["TreeChange"],
     change_iterator: Iterator["TreeChange"],
     honor_filemode: bool = True,
     honor_filemode: bool = True,
     validate_path_element: Optional[Callable[[bytes], bool]] = None,
     validate_path_element: Optional[Callable[[bytes], bool]] = None,
-    symlink_fn: Optional[Callable] = None,
+    symlink_fn: Optional[
+        Callable[[Union[str, bytes, os.PathLike], Union[str, bytes, os.PathLike]], None]
+    ] = None,
     force_remove_untracked: bool = False,
     force_remove_untracked: bool = False,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     blob_normalizer: Optional["BlobNormalizer"] = None,
     tree_encoding: str = "utf-8",
     tree_encoding: str = "utf-8",
@@ -2685,10 +2689,8 @@ class locked_index:
             self._file.abort()
             return
         try:
-            from typing import BinaryIO, cast
-
-            f = SHA1Writer(cast(BinaryIO, self._file))
-            write_index_dict(cast(BinaryIO, f), self._index._byname)
+            f = SHA1Writer(self._file)
+            write_index_dict(f, self._index._byname)
         except BaseException:
             self._file.abort()
         else:
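
A sketch of a callable matching the symlink_fn signature used throughout this file; os.symlink itself already fits, and this hypothetical wrapper just adds tracing (POSIX assumed):

import os
from typing import Union

def tracing_symlink_fn(
    source: Union[str, bytes, os.PathLike],
    target: Union[str, bytes, os.PathLike],
) -> None:
    # same (source, target) contract as os.symlink
    print(f"linking {target!r} -> {source!r}")
    os.symlink(source, target)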

+ 153 - 103
dulwich/object_store.py

@@ -33,11 +33,11 @@ from collections.abc import Iterable, Iterator, Sequence
 from contextlib import suppress
 from io import BytesIO
 from typing import (
+    TYPE_CHECKING,
     Callable,
     Optional,
     Protocol,
     Union,
-    cast,
 )

 from .errors import NotTreeError
@@ -82,6 +82,23 @@ from .pack import (
 from .protocol import DEPTH_INFINITE
 from .refs import PEELED_TAG_SUFFIX, Ref

+if TYPE_CHECKING:
+    from .commit_graph import CommitGraph
+    from .diff_tree import RenameDetector
+
+
+class GraphWalker(Protocol):
+    """Protocol for graph walker objects."""
+
+    def __next__(self) -> Optional[bytes]:
+        """Return the next object SHA to visit."""
+        ...
+
+    def ack(self, sha: bytes) -> None:
+        """Acknowledge that an object has been received."""
+        ...
+
+
 INFODIR = "info"
 PACKDIR = "pack"
 
 
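GraphWalker is a structural (Protocol) type: any object with compatible __next__ and ack methods satisfies it for the type checker, no inheritance needed. A hypothetical minimal walker:

from typing import Optional

class ListGraphWalker:
    """Hypothetical walker; satisfies GraphWalker by duck typing alone."""

    def __init__(self, shas: list[bytes]) -> None:
        self._shas = list(shas)
        self.acked: list[bytes] = []

    def __next__(self) -> Optional[bytes]:
        # Return the next SHA, or None when exhausted
        return self._shas.pop() if self._shas else None

    def ack(self, sha: bytes) -> None:
        self.acked.append(sha)

walker = ListGraphWalker([b"a" * 40])
assert next(walker) == b"a" * 40
assert next(walker) is None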
@@ -95,7 +112,9 @@ PACK_MODE = 0o444 if sys.platform != "win32" else 0o644
 DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks


-def find_shallow(store, heads, depth):
+def find_shallow(
+    store: ObjectContainer, heads: Iterable[bytes], depth: int
+) -> tuple[set[bytes], set[bytes]]:
     """Find shallow commits according to a given depth.

     Args:
@@ -107,10 +126,10 @@ def find_shallow(store, heads, depth):
         considered shallow and unshallow according to the arguments. Note that
         these sets may overlap if a commit is reachable along multiple paths.
     """
-    parents = {}
+    parents: dict[bytes, list[bytes]] = {}
     commit_graph = store.get_commit_graph()

-    def get_parents(sha):
+    def get_parents(sha: bytes) -> list[bytes]:
         result = parents.get(sha, None)
         if not result:
             # Try to use commit graph first if available
@@ -121,7 +140,9 @@ def find_shallow(store, heads, depth):
                     parents[sha] = result
                     return result
             # Fall back to loading the object
-            result = store[sha].parents
+            commit = store[sha]
+            assert isinstance(commit, Commit)
+            result = commit.parents
             parents[sha] = result
         return result
 
 
@@ -150,11 +171,11 @@ def find_shallow(store, heads, depth):
 
 
 
 
 def get_depth(
-    store,
-    head,
-    get_parents=lambda commit: commit.parents,
-    max_depth=None,
-):
+    store: ObjectContainer,
+    head: bytes,
+    get_parents: Callable = lambda commit: commit.parents,
+    max_depth: Optional[int] = None,
+) -> int:
     """Return the current available depth for the given head.

     For commits with multiple parents, the largest possible depth will be
@@ -206,17 +227,9 @@ class BaseObjectStore:
     def determine_wants_all(
         self, refs: dict[Ref, ObjectID], depth: Optional[int] = None
     ) -> list[ObjectID]:
-        """Determine all objects that are wanted by the client.
+        """Determine which objects are wanted based on refs."""

-        Args:
-          refs: Dictionary mapping ref names to object IDs
-          depth: Shallow fetch depth (None for full fetch)
-
-        Returns:
-          List of object IDs that are wanted
-        """
-
-        def _want_deepen(sha):
+        def _want_deepen(sha: bytes) -> bool:
             if not depth:
                 return False
             if depth == DEPTH_INFINITE:
@@ -231,7 +244,7 @@ class BaseObjectStore:
             and not sha == ZERO_SHA
         ]

-    def contains_loose(self, sha) -> bool:
+    def contains_loose(self, sha: bytes) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         raise NotImplementedError(self.contains_loose)
 
 
@@ -243,11 +256,11 @@ class BaseObjectStore:
         return self.contains_loose(sha1)

     @property
-    def packs(self):
+    def packs(self) -> list[Pack]:
         """Iterable of pack objects."""
         raise NotImplementedError

-    def get_raw(self, name) -> tuple[int, bytes]:
+    def get_raw(self, name: bytes) -> tuple[int, bytes]:
         """Obtain the raw text for an object.

         Args:
@@ -261,15 +274,19 @@ class BaseObjectStore:
         type_num, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)

-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)

-    def add_object(self, obj) -> None:
+    def add_object(self, obj: ShaFile) -> None:
         """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)

-    def add_objects(self, objects, progress=None) -> None:
+    def add_objects(
+        self,
+        objects: Sequence[tuple[ShaFile, Optional[str]]],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.

         Args:
@@ -280,14 +297,20 @@ class BaseObjectStore:
 
 
     def tree_changes(
         self,
-        source,
-        target,
-        want_unchanged=False,
-        include_trees=False,
-        change_type_same=False,
-        rename_detector=None,
-        paths=None,
-    ):
+        source: Optional[bytes],
+        target: Optional[bytes],
+        want_unchanged: bool = False,
+        include_trees: bool = False,
+        change_type_same: bool = False,
+        rename_detector: Optional["RenameDetector"] = None,
+        paths: Optional[list[bytes]] = None,
+    ) -> Iterator[
+        tuple[
+            tuple[Optional[bytes], Optional[bytes]],
+            tuple[Optional[int], Optional[int]],
+            tuple[Optional[bytes], Optional[bytes]],
+        ]
+    ]:
         """Find the differences between the contents of two trees.
 
         Args:
@@ -320,7 +343,9 @@ class BaseObjectStore:
                 (change.old.sha, change.new.sha),
             )
 
-    def iter_tree_contents(self, tree_id, include_trees=False):
+    def iter_tree_contents(
+        self, tree_id: bytes, include_trees: bool = False
+    ) -> Iterator[tuple[bytes, int, bytes]]:
         """Iterate the contents of a tree and all subtrees.
 
         Iteration is depth-first pre-order, as in e.g. os.walk.
@@ -362,13 +387,13 @@ class BaseObjectStore:
 
 
     def find_missing_objects(
         self,
-        haves,
-        wants,
-        shallow=None,
-        progress=None,
-        get_tagged=None,
-        get_parents=lambda commit: commit.parents,
-    ):
+        haves: Iterable[bytes],
+        wants: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
+        progress: Optional[Callable] = None,
+        get_tagged: Optional[Callable] = None,
+        get_parents: Callable = lambda commit: commit.parents,
+    ) -> Iterator[tuple[bytes, Optional[bytes]]]:
         """Find the missing objects required for a set of revisions.
 
         Args:
@@ -395,7 +420,7 @@ class BaseObjectStore:
         )
         return iter(finder)
 
-    def find_common_revisions(self, graphwalker):
+    def find_common_revisions(self, graphwalker: GraphWalker) -> list[bytes]:
         """Find which revisions this store has in common using graphwalker.
 
         Args:
@@ -412,7 +437,12 @@ class BaseObjectStore:
         return haves
 
     def generate_pack_data(
-        self, have, want, shallow=None, progress=None, ofs_delta=True
+        self,
+        have: Iterable[bytes],
+        want: Iterable[bytes],
+        shallow: Optional[set[bytes]] = None,
+        progress: Optional[Callable] = None,
+        ofs_delta: bool = True,
     ) -> tuple[int, Iterator[UnpackedObject]]:
         """Generate pack data objects for a set of wants/haves.
 
@@ -435,7 +465,7 @@ class BaseObjectStore:
             progress=progress,
         )
 
-    def peel_sha(self, sha):
+    def peel_sha(self, sha: bytes) -> bytes:
         """Peel all tags from a SHA.
 
         Args:
@@ -449,14 +479,14 @@ class BaseObjectStore:
             DeprecationWarning,
             stacklevel=2,
         )
-        return peel_sha(self, sha)[1]
+        return peel_sha(self, sha)[1].id
 
     def _get_depth(
         self,
-        head,
-        get_parents=lambda commit: commit.parents,
-        max_depth=None,
-    ):
+        head: bytes,
+        get_parents: Callable = lambda commit: commit.parents,
+        max_depth: Optional[int] = None,
+    ) -> int:
         """Return the current available depth for the given head.
 
         For commits with multiple parents, the largest possible depth will be
@@ -496,7 +526,7 @@ class BaseObjectStore:
             if sha.startswith(prefix):
                 yield sha
 
-    def get_commit_graph(self):
+    def get_commit_graph(self) -> Optional["CommitGraph"]:
         """Get the commit graph for this object store.
 
         Returns:
@@ -504,7 +534,9 @@ class BaseObjectStore:
         """
         """
         return None
         return None
 
 
-    def write_commit_graph(self, refs=None, reachable=True) -> None:
+    def write_commit_graph(
+        self, refs: Optional[list[bytes]] = None, reachable: bool = True
+    ) -> None:
         """Write a commit graph file for this object store.
         """Write a commit graph file for this object store.
 
 
         Args:
         Args:
@@ -518,7 +550,7 @@ class BaseObjectStore:
         """
         """
         raise NotImplementedError(self.write_commit_graph)
         raise NotImplementedError(self.write_commit_graph)
 
 
-    def get_object_mtime(self, sha):
+    def get_object_mtime(self, sha: bytes) -> float:
         """Get the modification time of an object.
         """Get the modification time of an object.
 
 
         Args:
         Args:
@@ -545,14 +577,14 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
 
 
     def __init__(
         self,
-        pack_compression_level=-1,
-        pack_index_version=None,
-        pack_delta_window_size=None,
-        pack_window_memory=None,
-        pack_delta_cache_size=None,
-        pack_depth=None,
-        pack_threads=None,
-        pack_big_file_threshold=None,
+        pack_compression_level: int = -1,
+        pack_index_version: Optional[int] = None,
+        pack_delta_window_size: Optional[int] = None,
+        pack_window_memory: Optional[int] = None,
+        pack_delta_cache_size: Optional[int] = None,
+        pack_depth: Optional[int] = None,
+        pack_threads: Optional[int] = None,
+        pack_big_file_threshold: Optional[int] = None,
     ) -> None:
         """Initialize a PackBasedObjectStore.
 
@@ -581,8 +613,11 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         raise NotImplementedError(self.add_pack)
 
     def add_pack_data(
-        self, count: int, unpacked_objects: Iterator[UnpackedObject], progress=None
-    ) -> None:
+        self,
+        count: int,
+        unpacked_objects: Iterator[UnpackedObject],
+        progress: Optional[Callable] = None,
+    ) -> Optional["Pack"]:
         """Add pack data to this object store.
 
         Args:
@@ -592,7 +627,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         """
         if count == 0:
         if count == 0:
             # Don't bother writing an empty pack file
             # Don't bother writing an empty pack file
-            return
+            return None
         f, commit, abort = self.add_pack()
         f, commit, abort = self.add_pack()
         try:
         try:
             write_pack_data(
             write_pack_data(
@@ -609,15 +644,11 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             return commit()
 
     @property
-    def alternates(self):
-        """Get the list of alternate object stores.
-
-        Returns:
-          List of alternate BaseObjectStore instances
-        """
+    def alternates(self) -> list:
+        """Return list of alternate object stores."""
         return []
 
-    def contains_packed(self, sha) -> bool:
+    def contains_packed(self, sha: bytes) -> bool:
         """Check if a particular object is present by SHA1 and is packed.
 
         This does not check alternates.
@@ -642,7 +673,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 return True
         return False
 
-    def _add_cached_pack(self, base_name, pack) -> None:
+    def _add_cached_pack(self, base_name: str, pack: Pack) -> None:
         """Add a newly appeared pack to the cache by path."""
         prev_pack = self._pack_cache.get(base_name)
         if prev_pack is not pack:
@@ -668,7 +699,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         remote_has = missing_objects.get_remote_has()
         object_ids = list(missing_objects)
         return len(object_ids), generate_unpacked_objects(
-            cast(PackedObjectContainer, self),
+            self,
             object_ids,
             progress=progress,
             ofs_delta=ofs_delta,
@@ -682,8 +713,8 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             (name, pack) = pack_cache.popitem()
             pack.close()
 
-    def _iter_cached_packs(self):
-        return self._pack_cache.values()
+    def _iter_cached_packs(self) -> Iterator[Pack]:
+        return iter(self._pack_cache.values())
 
     def _update_pack_cache(self) -> list[Pack]:
         raise NotImplementedError(self._update_pack_cache)
@@ -696,7 +727,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self._clear_cached_packs()
 
     @property
-    def packs(self):
+    def packs(self) -> list[Pack]:
         """List with pack objects."""
         return list(self._iter_cached_packs()) + list(self._update_pack_cache())
 
@@ -714,19 +745,19 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 count += 1
         return count
 
-    def _iter_alternate_objects(self):
+    def _iter_alternate_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all the objects in alternate stores."""
         for alternate in self.alternates:
             yield from alternate
 
-    def _iter_loose_objects(self):
+    def _iter_loose_objects(self) -> Iterator[bytes]:
         """Iterate over the SHAs of all loose objects."""
         raise NotImplementedError(self._iter_loose_objects)
 
-    def _get_loose_object(self, sha) -> Optional[ShaFile]:
+    def _get_loose_object(self, sha: bytes) -> Optional[ShaFile]:
         raise NotImplementedError(self._get_loose_object)
 
-    def delete_loose_object(self, sha) -> None:
+    def delete_loose_object(self, sha: bytes) -> None:
         """Delete a loose object.
 
         This method only handles loose objects. For packed objects,
@@ -734,23 +765,25 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         """
         raise NotImplementedError(self.delete_loose_object)
         raise NotImplementedError(self.delete_loose_object)
 
 
-    def _remove_pack(self, name) -> None:
+    def _remove_pack(self, pack: "Pack") -> None:
         raise NotImplementedError(self._remove_pack)
         raise NotImplementedError(self._remove_pack)
 
 
-    def pack_loose_objects(self):
+    def pack_loose_objects(self) -> int:
         """Pack loose objects.
         """Pack loose objects.
 
 
         Returns: Number of objects packed
         Returns: Number of objects packed
         """
         """
-        objects = set()
+        objects: list[tuple[ShaFile, None]] = []
         for sha in self._iter_loose_objects():
         for sha in self._iter_loose_objects():
-            objects.add((self._get_loose_object(sha), None))
-        self.add_objects(list(objects))
+            obj = self._get_loose_object(sha)
+            if obj is not None:
+                objects.append((obj, None))
+        self.add_objects(objects)
         for obj, path in objects:
         for obj, path in objects:
             self.delete_loose_object(obj.id)
             self.delete_loose_object(obj.id)
         return len(objects)
         return len(objects)
 
 
-    def repack(self, exclude=None):
+    def repack(self, exclude: Optional[set] = None) -> int:
         """Repack the packs in this repository.
         """Repack the packs in this repository.
 
 
         Note that this implementation is fairly naive and currently keeps all
         Note that this implementation is fairly naive and currently keeps all
@@ -766,11 +799,13 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         excluded_loose_objects = set()
         for sha in self._iter_loose_objects():
             if sha not in exclude:
-                loose_objects.add(self._get_loose_object(sha))
+                obj = self._get_loose_object(sha)
+                if obj is not None:
+                    loose_objects.add(obj)
             else:
                 excluded_loose_objects.add(sha)
 
-        objects = {(obj, None) for obj in loose_objects}
+        objects: set[tuple[ShaFile, None]] = {(obj, None) for obj in loose_objects}
         old_packs = {p.name(): p for p in self.packs}
         for name, pack in old_packs.items():
             objects.update(
@@ -782,12 +817,14 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             # The name of the consolidated pack might match the name of a
             # pre-existing pack. Take care not to remove the newly created
             # consolidated pack.
-            consolidated = self.add_objects(objects)
-            old_packs.pop(consolidated.name(), None)
+            consolidated = self.add_objects(list(objects))
+            if consolidated is not None:
+                old_packs.pop(consolidated.name(), None)
 
         # Delete loose objects that were packed
         for obj in loose_objects:
-            self.delete_loose_object(obj.id)
+            if obj is not None:
+                self.delete_loose_object(obj.id)
         # Delete excluded loose objects
         for sha in excluded_loose_objects:
             self.delete_loose_object(sha)
@@ -943,9 +980,9 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 yield o
                 todo.remove(o.id)
         for oid in todo:
-            o = self._get_loose_object(oid)
-            if o is not None:
-                yield o
+            loose_obj: Optional[ShaFile] = self._get_loose_object(oid)
+            if loose_obj is not None:
+                yield loose_obj
             elif not allow_missing:
                 raise KeyError(oid)
 
@@ -993,7 +1030,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self,
         objects: Sequence[tuple[ShaFile, Optional[str]]],
         progress: Optional[Callable[[str], None]] = None,
-    ) -> None:
+    ) -> Optional["Pack"]:
         """Add a set of objects to this object store.
 
         Args:
@@ -1012,6 +1049,8 @@ class DiskObjectStore(PackBasedObjectStore):
 
 
     path: Union[str, os.PathLike]
     pack_dir: Union[str, os.PathLike]
+    _alternates: Optional[list["DiskObjectStore"]]
+    _commit_graph: Optional["CommitGraph"]
 
     def __init__(
         self,
@@ -1244,7 +1283,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
 
     def _get_shafile_path(self, sha):
         # Check from object dir
-        return hex_to_filename(self.path, sha)
+        return hex_to_filename(os.fspath(self.path), sha)
 
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
@@ -1345,9 +1384,9 @@ class DiskObjectStore(PackBasedObjectStore):
         os.remove(pack.index.path)
 
     def _get_pack_basepath(self, entries):
-        suffix = iter_sha1(entry[0] for entry in entries)
+        suffix_bytes = iter_sha1(entry[0] for entry in entries)
         # TODO: Handle self.pack_dir being bytes
-        suffix = suffix.decode("ascii")
+        suffix = suffix_bytes.decode("ascii")
         return os.path.join(self.pack_dir, "pack-" + suffix)
 
     def _complete_pack(self, f, path, num_objects, indexer, progress=None):
@@ -1371,7 +1410,7 @@ class DiskObjectStore(PackBasedObjectStore):
 
 
         pack_sha, extra_entries = extend_pack(
             f,
-            indexer.ext_refs(),
+            indexer.ext_refs,
             get_raw=self.get_raw,
             compression_level=self.pack_compression_level,
             progress=progress,
@@ -1466,6 +1505,7 @@ class DiskObjectStore(PackBasedObjectStore):
         def commit():
             if f.tell() > 0:
                 f.seek(0)
+
                 with PackData(path, f) as pd:
                     indexer = PackIndexer.for_pack_data(
                         pd, resolve_ext_ref=self.get_raw
@@ -1798,6 +1838,7 @@ class MemoryObjectStore(BaseObjectStore):
             size = f.tell()
             if size > 0:
                 f.seek(0)
+
                 p = PackData.from_file(f, size)
                 for obj in PackInflater.for_pack_data(p, self.get_raw):
                     self.add_object(obj)
@@ -2134,7 +2175,7 @@ class ObjectStoreGraphWalker:
     heads: set[ObjectID]
     """Revisions without descendants in the local repo."""
 
-    get_parents: Callable[[ObjectID], ObjectID]
+    get_parents: Callable[[ObjectID], list[ObjectID]]
     """Function to retrieve parents in the local repo."""
 
     shallow: set[ObjectID]
@@ -2230,7 +2271,7 @@ def commit_tree_changes(object_store, tree, changes):
     """
     """
     # TODO(jelmer): Save up the objects and add them using .add_objects
     # TODO(jelmer): Save up the objects and add them using .add_objects
     # rather than with individual calls to .add_object.
     # rather than with individual calls to .add_object.
-    nested_changes = {}
+    nested_changes: dict[bytes, list[tuple[bytes, Optional[int], Optional[bytes]]]] = {}
     for path, new_mode, new_sha in changes:
     for path, new_mode, new_sha in changes:
         try:
         try:
             (dirname, subpath) = path.split(b"/", 1)
             (dirname, subpath) = path.split(b"/", 1)
@@ -2465,8 +2506,16 @@ class BucketBasedObjectStore(PackBasedObjectStore):
         """
         """
         # Doesn't exist..
         # Doesn't exist..
 
 
-    def _remove_pack(self, name) -> None:
-        raise NotImplementedError(self._remove_pack)
+    def pack_loose_objects(self) -> int:
+        """Pack loose objects. Returns number of objects packed.
+
+        BucketBasedObjectStore doesn't support loose objects, so this is a no-op.
+        """
+        return 0
+
+    def _remove_pack_by_name(self, name: str) -> None:
+        """Remove a pack by name. Subclasses should implement this."""
+        raise NotImplementedError(self._remove_pack_by_name)
 
 
     def _iter_pack_names(self) -> Iterator[str]:
     def _iter_pack_names(self) -> Iterator[str]:
         raise NotImplementedError(self._iter_pack_names)
         raise NotImplementedError(self._iter_pack_names)
@@ -2511,6 +2560,7 @@ class BucketBasedObjectStore(PackBasedObjectStore):
                 return None
 
             pf.seek(0)
+
             p = PackData(pf.name, pf)
             entries = p.sorted_entries()
             basename = iter_sha1(entry[0] for entry in entries).decode("ascii")
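Since `add_objects()` and `add_pack_data()` now return `Optional["Pack"]` instead of `None`, callers have to guard for the empty case, mirroring the check added in `repack()` above; a rough sketch:

```python
from dulwich.object_store import DiskObjectStore

def describe_pack(store: DiskObjectStore) -> str:
    pack = store.add_objects([])  # nothing to pack, so no pack file is written
    if pack is None:              # the guard the new annotation forces on callers
        return "no pack written"
    return pack.name()
```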

+ 2 - 1
dulwich/objects.py

@@ -430,7 +430,8 @@ class ShaFile:
             self._sha = None
             self._chunked_text = self._serialize()
             self._needs_serialization = False
-        return self._chunked_text  # type: ignore
+        assert self._chunked_text is not None
+        return self._chunked_text
 
     def as_raw_string(self) -> bytes:
         """Return raw string with serialization of the object.

+ 30 - 15
dulwich/objectspec.py

@@ -24,6 +24,7 @@
 from typing import TYPE_CHECKING, Optional, Union
 
 from .objects import Commit, ShaFile, Tag, Tree
+from .repo import BaseRepo
 
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
@@ -40,9 +41,9 @@ def to_bytes(text: Union[str, bytes]) -> bytes:
     Returns:
       Bytes representation of text
     """
-    if getattr(text, "encode", None) is not None:
-        text = text.encode("ascii")  # type: ignore
-    return text  # type: ignore
+    if isinstance(text, str):
+        return text.encode("ascii")
+    return text
 
 
 def _resolve_object(repo: "Repo", ref: bytes) -> "ShaFile":
@@ -136,7 +137,9 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
                         raise ValueError(
                             f"Commit {commit.id.decode('ascii', 'replace')} has no parents"
                         )
-                    commit = repo[commit.parents[0]]
+                    parent_obj = repo[commit.parents[0]]
+                    assert isinstance(parent_obj, Commit)
+                    commit = parent_obj
                 obj = commit
             else:  # sep == b"^"
                 # Get N-th parent (or commit itself if N=0)
@@ -157,11 +160,13 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
     return _resolve_object(repo, objectish)
 
 
-def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "Tree":
+def parse_tree(
+    repo: "BaseRepo", treeish: Union[bytes, str, Tree, Commit, Tag]
+) -> "Tree":
     """Parse a string referring to a tree.
 
     Args:
-      repo: A `Repo` object
+      repo: A repository object
       treeish: A string referring to a tree, or a Tree, Commit, or Tag object
     Returns: A Tree object
     Raises:
@@ -173,7 +178,9 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
 
 
     # If it's a Commit, return its tree
     if isinstance(treeish, Commit):
-        return repo[treeish.tree]
+        tree = repo[treeish.tree]
+        assert isinstance(tree, Tree)
+        return tree
 
     # For Tag objects or strings, use the existing logic
     if isinstance(treeish, Tag):
@@ -181,7 +188,7 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
     else:
         treeish = to_bytes(treeish)
     try:
-        treeish = parse_ref(repo, treeish)
+        treeish = parse_ref(repo.refs, treeish)
     except KeyError:  # treeish is commit sha
         pass
     try:
@@ -190,15 +197,21 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
         # Try parsing as commit (handles short hashes)
         try:
             commit = parse_commit(repo, treeish)
-            return repo[commit.tree]
+            assert isinstance(commit, Commit)
+            tree = repo[commit.tree]
+            assert isinstance(tree, Tree)
+            return tree
         except KeyError:
             raise KeyError(treeish)
-    if o.type_name == b"commit":
-        return repo[o.tree]
-    elif o.type_name == b"tag":
+    if isinstance(o, Commit):
+        tree = repo[o.tree]
+        assert isinstance(tree, Tree)
+        return tree
+    elif isinstance(o, Tag):
         # Tag handling - dereference and recurse
         obj_type, obj_sha = o.object
         return parse_tree(repo, obj_sha)
+    assert isinstance(o, Tree)
     return o
 
 
@@ -383,11 +396,13 @@ def scan_for_short_id(
     raise AmbiguousShortId(prefix, ret)
 
 
-def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "Commit":
+def parse_commit(
+    repo: "BaseRepo", committish: Union[str, bytes, Commit, Tag]
+) -> "Commit":
     """Parse a string referring to a single commit.
 
     Args:
-      repo: A` Repo` object
+      repo: A repository object
       committish: A string referring to a single commit, or a Commit or Tag object.
     Returns: A Commit object
     Raises:
@@ -426,7 +441,7 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "C
     else:
         return dereference_tag(obj)
     try:
-        obj = repo[parse_ref(repo, committish)]
+        obj = repo[parse_ref(repo.refs, committish)]
     except KeyError:
         pass
     else:
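The `isinstance`-based `to_bytes()` above lets mypy narrow `Union[str, bytes]` without casts; a standalone sketch of the same function with example calls:

```python
from typing import Union

def to_bytes(text: Union[str, bytes]) -> bytes:
    if isinstance(text, str):       # the branch narrows text to str
        return text.encode("ascii")
    return text                     # here mypy knows text is bytes

assert to_bytes("HEAD") == b"HEAD"
assert to_bytes(b"refs/heads/main") == b"refs/heads/main"
```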

File diff suppressed because it is too large
+ 243 - 169
dulwich/pack.py


+ 8 - 9
dulwich/patch.py

@@ -30,6 +30,7 @@ import time
 from collections.abc import Generator
 from difflib import SequenceMatcher
 from typing import (
+    IO,
     TYPE_CHECKING,
     BinaryIO,
     Optional,
@@ -48,7 +49,7 @@ FIRST_FEW_BYTES = 8000
 
 
 
 
 def write_commit_patch(
-    f: BinaryIO,
+    f: IO[bytes],
     commit: "Commit",
     contents: Union[str, bytes],
     progress: tuple[int, int],
@@ -231,7 +232,7 @@ def patch_filename(p: Optional[bytes], root: bytes) -> bytes:
 
 
 
 
 def write_object_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     old_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
     new_file: tuple[Optional[bytes], Optional[int], Optional[bytes]],
@@ -264,19 +265,17 @@ def write_object_diff(
         Returns:
             Blob object
         """
-        from typing import cast
-
         if hexsha is None:
-            return cast(Blob, Blob.from_string(b""))
+            return Blob.from_string(b"")
         elif mode is not None and S_ISGITLINK(mode):
-            return cast(Blob, Blob.from_string(b"Subproject commit " + hexsha + b"\n"))
+            return Blob.from_string(b"Subproject commit " + hexsha + b"\n")
         else:
             obj = store[hexsha]
             if isinstance(obj, Blob):
                 return obj
             else:
                 # Fallback for non-blob objects
-                return cast(Blob, Blob.from_string(obj.as_raw_string()))
+                return Blob.from_string(obj.as_raw_string())
 
     def lines(content: "Blob") -> list[bytes]:
         """Split blob content into lines.
@@ -356,7 +355,7 @@ def gen_diff_header(
 
 
 # TODO(jelmer): Support writing unicode, rather than bytes.
 def write_blob_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     old_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
     new_file: tuple[Optional[bytes], Optional[int], Optional["Blob"]],
 ) -> None:
@@ -403,7 +402,7 @@ def write_blob_diff(
 
 
 
 
 def write_tree_diff(
-    f: BinaryIO,
+    f: IO[bytes],
     store: "BaseObjectStore",
     old_tree: Optional[bytes],
     new_tree: Optional[bytes],
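`IO[bytes]` is the more general structural type, so in-memory buffers now satisfy these signatures just like real binary files; a small sketch:

```python
from io import BytesIO
from typing import IO

def emit_header(f: IO[bytes]) -> None:
    f.write(b"diff --git a/foo b/foo\n")  # any binary-writable object works

buf = BytesIO()
emit_header(buf)
assert buf.getvalue().startswith(b"diff --git")
```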

File diff suppressed because it is too large
+ 313 - 227
dulwich/porcelain.py


+ 0 - 12
dulwich/protocol.py

@@ -258,18 +258,6 @@ def pkt_seq(*seq: Optional[bytes]) -> bytes:
     return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)
 
 
-def filter_ref_prefix(
-    refs: dict[bytes, bytes], prefixes: Iterable[bytes]
-) -> dict[bytes, bytes]:
-    """Filter refs to only include those with a given prefix.
-
-    Args:
-      refs: A list of refs.
-      prefixes: The prefixes to filter by.
-    """
-    return {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
-
-
 class Protocol:
     """Class for interacting with a remote git process over the wire.
 
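With `filter_ref_prefix()` removed, the same filtering is a one-line dict comprehension at the call site (a sketch mirroring the deleted body):

```python
refs = {b"refs/heads/main": b"a" * 40, b"refs/tags/v1.0": b"b" * 40}
prefixes = [b"refs/heads/"]

filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
assert filtered == {b"refs/heads/main": b"a" * 40}
```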

+ 17 - 3
dulwich/rebase.py

@@ -32,7 +32,7 @@ from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
 from dulwich.objects import Commit
 from dulwich.objectspec import parse_commit
-from dulwich.repo import Repo
+from dulwich.repo import BaseRepo, Repo
 
 
 class RebaseError(Exception):
@@ -529,7 +529,7 @@ class DiskRebaseStateManager:
 class MemoryRebaseStateManager:
     """Manages rebase state in memory for MemoryRepo."""
 
-    def __init__(self, repo: Repo) -> None:
+    def __init__(self, repo: BaseRepo) -> None:
         """Initialize MemoryRebaseStateManager.
 
         Args:
@@ -642,6 +642,8 @@ class Rebaser:
         if branch is None:
             # Use current HEAD
             head_ref, head_sha = self.repo.refs.follow(b"HEAD")
+            if head_sha is None:
+                raise ValueError("HEAD does not point to a valid commit")
             branch_commit = self.repo[head_sha]
         else:
             # Parse the branch reference
@@ -664,6 +666,7 @@ class Rebaser:
         commits = []
         current = branch_commit
         while current.id != merge_base:
+            assert isinstance(current, Commit)
             commits.append(current)
             if not current.parents:
                 break
@@ -691,6 +694,9 @@ class Rebaser:
         parent = self.repo[commit.parents[0]]
         onto_commit = self.repo[onto]
 
+        assert isinstance(parent, Commit)
+        assert isinstance(onto_commit, Commit)
+
         # Perform three-way merge
         merged_tree, conflicts = three_way_merge(
             self.object_store, parent, onto_commit, commit
@@ -798,7 +804,9 @@ class Rebaser:
 
 
         if new_sha:
             # Success - add to done list
-            self._done.append(self.repo[new_sha])
+            new_commit = self.repo[new_sha]
+            assert isinstance(new_commit, Commit)
+            self._done.append(new_commit)
             self._save_rebase_state()
 
             # Continue with next commit if any
@@ -822,6 +830,8 @@ class Rebaser:
             raise RebaseError("No rebase in progress")
 
         # Restore original HEAD
+        if self._original_head is None:
+            raise RebaseError("No original HEAD to restore")
         self.repo.refs[b"HEAD"] = self._original_head
 
         # Clean up rebase state
@@ -1200,12 +1210,16 @@ def _squash_commits(
     if not entry.commit_sha:
         raise RebaseError("No commit SHA for squash/fixup operation")
     commit_to_squash = repo[entry.commit_sha]
+    if not isinstance(commit_to_squash, Commit):
+        raise RebaseError(f"Expected commit, got {type(commit_to_squash).__name__}")
 
     # Get the previous commit (target of squash)
     previous_commit = rebaser._done[-1]
 
     # Cherry-pick the changes onto the previous commit
     parent = repo[commit_to_squash.parents[0]]
+    if not isinstance(parent, Commit):
+        raise RebaseError(f"Expected parent commit, got {type(parent).__name__}")
 
     # Perform three-way merge for the tree
     merged_tree, conflicts = three_way_merge(
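`_squash_commits()` uses check-and-raise rather than `assert`, since asserts are stripped under `python -O`; a hypothetical helper in the same style (`require_commit` is not part of dulwich):

```python
from dulwich.objects import Commit, ShaFile

def require_commit(obj: ShaFile) -> Commit:
    # Raise a real error instead of asserting, so the check survives -O
    if not isinstance(obj, Commit):
        raise TypeError(f"Expected commit, got {type(obj).__name__}")
    return obj
```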

+ 259 - 242
dulwich/refs.py

@@ -25,9 +25,19 @@
 import os
 import types
 import warnings
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
+    BinaryIO,
+    Callable,
+    Optional,
+    TypeVar,
+    Union,
+    cast,
+)
 
 if TYPE_CHECKING:
     from .file import _GitFile
@@ -55,13 +65,8 @@ ANNOTATED_TAG_SUFFIX = PEELED_TAG_SUFFIX
 class SymrefLoop(Exception):
     """There is a loop between one or more symrefs."""
 
-    def __init__(self, ref, depth) -> None:
-        """Initialize a SymrefLoop exception.
-
-        Args:
-          ref: The ref that caused the loop
-          depth: Depth at which the loop was detected
-        """
+    def __init__(self, ref: bytes, depth: int) -> None:
+        """Initialize SymrefLoop exception."""
         self.ref = ref
         self.depth = depth
 
@@ -142,23 +147,35 @@ def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
 class RefsContainer:
     """A container for refs."""
 
-    def __init__(self, logger=None) -> None:
-        """Initialize a RefsContainer.
-
-        Args:
-          logger: Optional logger for reflog updates
-        """
+    def __init__(
+        self,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
+        """Initialize RefsContainer with optional logger function."""
         self._logger = logger
 
     def _log(
         self,
-        ref,
-        old_sha,
-        new_sha,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        ref: bytes,
+        old_sha: Optional[bytes],
+        new_sha: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         if self._logger is None:
             return
@@ -168,12 +185,12 @@ class RefsContainer:
 
 
     def set_symbolic_ref(
         self,
-        name,
-        other,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        other: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         """Make a ref point at another ref.
 
@@ -206,7 +223,7 @@ class RefsContainer:
         """
         """
         raise NotImplementedError(self.add_packed_refs)
         raise NotImplementedError(self.add_packed_refs)
 
 
-    def get_peeled(self, name) -> Optional[ObjectID]:
+    def get_peeled(self, name: bytes) -> Optional[ObjectID]:
         """Return the cached peeled value of a ref, if available.
         """Return the cached peeled value of a ref, if available.
 
 
         Args:
         Args:
@@ -257,12 +274,12 @@ class RefsContainer:
         for ref in to_delete:
             self.remove_if_equals(b"/".join((base, ref)), None, message=message)
 
-    def allkeys(self) -> Iterator[Ref]:
+    def allkeys(self) -> set[Ref]:
         """All refs present in this container."""
         raise NotImplementedError(self.allkeys)
 
-    def __iter__(self):
-        """Iterate over all ref names."""
+    def __iter__(self) -> Iterator[Ref]:
+        """Iterate over all reference keys."""
         return iter(self.allkeys())
 
     def keys(self, base=None):
@@ -278,7 +295,7 @@ class RefsContainer:
         else:
             return self.allkeys()
 
-    def subkeys(self, base):
+    def subkeys(self, base: bytes) -> set[bytes]:
         """Refs present in this container under a base.
 
         Args:
@@ -293,7 +310,7 @@ class RefsContainer:
                 keys.add(refname[base_len:])
         return keys
 
-    def as_dict(self, base=None) -> dict[Ref, ObjectID]:
+    def as_dict(self, base: Optional[bytes] = None) -> dict[Ref, ObjectID]:
         """Return the contents of this container as a dictionary."""
         ret = {}
         keys = self.keys(base)
@@ -309,7 +326,7 @@ class RefsContainer:
 
 
         return ret
 
-    def _check_refname(self, name) -> None:
+    def _check_refname(self, name: bytes) -> None:
         """Ensure a refname is valid and lives in refs or is HEAD.
 
         HEAD is not a valid refname according to git-check-ref-format, but this
@@ -328,7 +345,7 @@ class RefsContainer:
         if not name.startswith(b"refs/") or not check_ref_format(name[5:]):
             raise RefFormatError(name)
 
-    def read_ref(self, refname):
+    def read_ref(self, refname: bytes) -> Optional[bytes]:
         """Read a reference without following any references.
 
         Args:
@@ -341,7 +358,7 @@ class RefsContainer:
             contents = self.get_packed_refs().get(refname, None)
         return contents
 
-    def read_loose_ref(self, name) -> bytes:
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         """Read a loose reference and return its contents.
 
         Args:
@@ -351,16 +368,16 @@ class RefsContainer:
         """
         """
         raise NotImplementedError(self.read_loose_ref)
         raise NotImplementedError(self.read_loose_ref)
 
 
-    def follow(self, name) -> tuple[list[bytes], bytes]:
+    def follow(self, name: bytes) -> tuple[list[bytes], Optional[bytes]]:
         """Follow a reference name.
         """Follow a reference name.
 
 
         Returns: a tuple of (refnames, sha), wheres refnames are the names of
         Returns: a tuple of (refnames, sha), wheres refnames are the names of
             references in the chain
             references in the chain
         """
         """
-        contents = SYMREF + name
+        contents: Optional[bytes] = SYMREF + name
         depth = 0
         depth = 0
         refnames = []
         refnames = []
-        while contents.startswith(SYMREF):
+        while contents and contents.startswith(SYMREF):
             refname = contents[len(SYMREF) :]
             refname = contents[len(SYMREF) :]
             refnames.append(refname)
             refnames.append(refname)
             contents = self.read_ref(refname)
             contents = self.read_ref(refname)
@@ -371,20 +388,13 @@ class RefsContainer:
                 raise SymrefLoop(name, depth)
         return refnames, contents
 
-    def __contains__(self, refname) -> bool:
-        """Check if a ref exists.
-
-        Args:
-          refname: Name of the ref to check
-
-        Returns:
-          True if the ref exists
-        """
+    def __contains__(self, refname: bytes) -> bool:
+        """Check if a reference exists."""
         if self.read_ref(refname):
             return True
         return False
 
-    def __getitem__(self, name) -> ObjectID:
+    def __getitem__(self, name: bytes) -> ObjectID:
         """Get the SHA1 for a reference name.
 
         This method follows all symbolic references.
@@ -396,13 +406,13 @@ class RefsContainer:
 
 
     def set_if_equals(
         self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref.
 
@@ -424,7 +434,13 @@ class RefsContainer:
         raise NotImplementedError(self.set_if_equals)
 
     def add_if_new(
-        self, name, ref, committer=None, timestamp=None, timezone=None, message=None
+        self,
+        name: bytes,
+        ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Add a new reference only if it does not already exist.
 
@@ -438,7 +454,7 @@ class RefsContainer:
         """
         """
         raise NotImplementedError(self.add_if_new)
         raise NotImplementedError(self.add_if_new)
 
 
-    def __setitem__(self, name, ref) -> None:
+    def __setitem__(self, name: bytes, ref: bytes) -> None:
         """Set a reference name to point to the given SHA1.
         """Set a reference name to point to the given SHA1.
 
 
         This method follows all symbolic references if applicable for the
         This method follows all symbolic references if applicable for the
@@ -458,12 +474,12 @@ class RefsContainer:
 
 
     def remove_if_equals(
         self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
         """Remove a refname only if it currently equals old_ref.
 
@@ -483,7 +499,7 @@ class RefsContainer:
         """
         """
         raise NotImplementedError(self.remove_if_equals)
         raise NotImplementedError(self.remove_if_equals)
 
 
-    def __delitem__(self, name) -> None:
+    def __delitem__(self, name: bytes) -> None:
         """Remove a refname.
         """Remove a refname.
 
 
         This method does not follow symbolic references, even if applicable for
         This method does not follow symbolic references, even if applicable for
@@ -498,7 +514,7 @@ class RefsContainer:
         """
         """
         self.remove_if_equals(name, None)
         self.remove_if_equals(name, None)
 
 
-    def get_symrefs(self):
+    def get_symrefs(self) -> dict[bytes, bytes]:
         """Get a dict with all symrefs in this container.
         """Get a dict with all symrefs in this container.
 
 
         Returns: Dictionary mapping source ref to target ref
         Returns: Dictionary mapping source ref to target ref
@@ -506,7 +522,9 @@ class RefsContainer:
         ret = {}
         for src in self.allkeys():
             try:
-                dst = parse_symref_value(self.read_ref(src))
+                ref_value = self.read_ref(src)
+                assert ref_value is not None
+                dst = parse_symref_value(ref_value)
             except ValueError:
                 pass
             else:
@@ -529,41 +547,43 @@ class DictRefsContainer(RefsContainer):
     threadsafe.
     """
 
-    def __init__(self, refs, logger=None) -> None:
-        """Initialize DictRefsContainer."""
+    def __init__(
+        self,
+        refs: dict[bytes, bytes],
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
+    ) -> None:
+        """Initialize DictRefsContainer with refs dictionary and optional logger."""
         super().__init__(logger=logger)
         self._refs = refs
         self._peeled: dict[bytes, ObjectID] = {}
         self._watchers: set[Any] = set()
 
-    def allkeys(self):
-        """Get all ref names.
-
-        Returns:
-          All ref names in the container
-        """
-        return self._refs.keys()
-
-    def read_loose_ref(self, name):
-        """Read a reference from the refs dictionary.
-
-        Args:
-          name: The ref name to read
-
-        Returns:
-          The ref value or None if not found
-        """
+    def allkeys(self) -> set[bytes]:
+        """Return all reference keys."""
+        return set(self._refs.keys())
+
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
+        """Read a loose reference."""
         return self._refs.get(name, None)
 
-    def get_packed_refs(self):
-        """Get packed refs (always empty for DictRefsContainer).
-
-        Returns:
-          Empty dictionary
-        """
+    def get_packed_refs(self) -> dict[bytes, bytes]:
+        """Get packed references."""
         return {}
 
-    def _notify(self, ref, newsha) -> None:
+    def _notify(self, ref: bytes, newsha: Optional[bytes]) -> None:
         for watcher in self._watchers:
             watcher._notify((ref, newsha))
 
@@ -571,10 +591,10 @@ class DictRefsContainer(RefsContainer):
         self,
         self,
         name: Ref,
         name: Ref,
         other: Ref,
         other: Ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
     ) -> None:
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
@@ -602,13 +622,13 @@ class DictRefsContainer(RefsContainer):
 
 
     def set_if_equals(
     def set_if_equals(
         self,
         self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref.
         """Set a refname to new_ref only if it currently equals old_ref.
 
 
@@ -650,9 +670,9 @@ class DictRefsContainer(RefsContainer):
         self,
         self,
         name: Ref,
         name: Ref,
         ref: ObjectID,
         ref: ObjectID,
-        committer=None,
-        timestamp=None,
-        timezone=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
         message: Optional[bytes] = None,
         message: Optional[bytes] = None,
     ) -> bool:
     ) -> bool:
         """Add a new reference only if it does not already exist.
         """Add a new reference only if it does not already exist.
@@ -685,12 +705,12 @@ class DictRefsContainer(RefsContainer):
 
 
     def remove_if_equals(
     def remove_if_equals(
         self,
         self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
     ) -> bool:
         """Remove a refname only if it currently equals old_ref.
         """Remove a refname only if it currently equals old_ref.
 
 
@@ -728,25 +748,18 @@ class DictRefsContainer(RefsContainer):
             )
             )
         return True
         return True
 
 
-    def get_peeled(self, name):
-        """Get the peeled value of a ref.
-
-        Args:
-          name: Ref name to get peeled value for
-
-        Returns:
-          The peeled SHA or None if not available
-        """
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
+        """Get peeled version of a reference."""
         return self._peeled.get(name)
         return self._peeled.get(name)
 
 
-    def _update(self, refs) -> None:
+    def _update(self, refs: dict[bytes, bytes]) -> None:
         """Update multiple refs; intended only for testing."""
         """Update multiple refs; intended only for testing."""
         # TODO(dborowitz): replace this with a public function that uses
         # TODO(dborowitz): replace this with a public function that uses
         # set_if_equal.
         # set_if_equal.
         for ref, sha in refs.items():
         for ref, sha in refs.items():
             self.set_if_equals(ref, None, sha)
             self.set_if_equals(ref, None, sha)
 
 
-    def _update_peeled(self, peeled) -> None:
+    def _update_peeled(self, peeled: dict[bytes, bytes]) -> None:
         """Update cached peeled refs; intended only for testing."""
         """Update cached peeled refs; intended only for testing."""
         self._peeled.update(peeled)
         self._peeled.update(peeled)
 
 
@@ -754,56 +767,27 @@ class DictRefsContainer(RefsContainer):
 class InfoRefsContainer(RefsContainer):
 class InfoRefsContainer(RefsContainer):
     """Refs container that reads refs from a info/refs file."""
     """Refs container that reads refs from a info/refs file."""
 
 
-    def __init__(self, f) -> None:
-        """Initialize an InfoRefsContainer.
-
-        Args:
-          f: File-like object containing info/refs data
-        """
-        self._refs = {}
-        self._peeled = {}
+    def __init__(self, f: BinaryIO) -> None:
+        """Initialize InfoRefsContainer from info/refs file."""
+        self._refs: dict[bytes, bytes] = {}
+        self._peeled: dict[bytes, bytes] = {}
         refs = read_info_refs(f)
         refs = read_info_refs(f)
         (self._refs, self._peeled) = split_peeled_refs(refs)
         (self._refs, self._peeled) = split_peeled_refs(refs)
 
 
-    def allkeys(self):
-        """Get all ref names.
+    def allkeys(self) -> set[bytes]:
+        """Return all reference keys."""
+        return set(self._refs.keys())
 
 
-        Returns:
-          All ref names in the info/refs file
-        """
-        return self._refs.keys()
-
-    def read_loose_ref(self, name):
-        """Read a reference from the parsed info/refs.
-
-        Args:
-          name: The ref name to read
-
-        Returns:
-          The ref value or None if not found
-        """
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
+        """Read a loose reference."""
         return self._refs.get(name, None)
         return self._refs.get(name, None)
 
 
-    def get_packed_refs(self):
-        """Get packed refs (always empty for InfoRefsContainer).
-
-        Returns:
-          Empty dictionary
-        """
+    def get_packed_refs(self) -> dict[bytes, bytes]:
+        """Get packed references."""
         return {}
         return {}
 
 
-    def get_peeled(self, name):
-        """Get the peeled value of a ref.
-
-        Args:
-          name: Ref name to get peeled value for
-
-        Returns:
-          The peeled SHA if available, otherwise the ref value itself
-
-        Raises:
-          KeyError: If the ref doesn't exist
-        """
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
+        """Get peeled version of a reference."""
         try:
         try:
             return self._peeled[name]
             return self._peeled[name]
         except KeyError:
         except KeyError:
@@ -817,7 +801,20 @@ class DiskRefsContainer(RefsContainer):
         self,
         self,
         path: Union[str, bytes, os.PathLike],
         path: Union[str, bytes, os.PathLike],
         worktree_path: Optional[Union[str, bytes, os.PathLike]] = None,
         worktree_path: Optional[Union[str, bytes, os.PathLike]] = None,
-        logger=None,
+        logger: Optional[
+            Callable[
+                [
+                    bytes,
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[bytes],
+                    Optional[int],
+                    Optional[int],
+                    Optional[bytes],
+                ],
+                None,
+            ]
+        ] = None,
     ) -> None:
     ) -> None:
         """Initialize DiskRefsContainer."""
         """Initialize DiskRefsContainer."""
         super().__init__(logger=logger)
         super().__init__(logger=logger)
@@ -827,22 +824,15 @@ class DiskRefsContainer(RefsContainer):
             self.worktree_path = self.path
             self.worktree_path = self.path
         else:
         else:
             self.worktree_path = os.fsencode(os.fspath(worktree_path))
             self.worktree_path = os.fsencode(os.fspath(worktree_path))
-        self._packed_refs = None
-        self._peeled_refs = None
+        self._packed_refs: Optional[dict[bytes, bytes]] = None
+        self._peeled_refs: Optional[dict[bytes, bytes]] = None
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         """Return string representation of DiskRefsContainer."""
         """Return string representation of DiskRefsContainer."""
         return f"{self.__class__.__name__}({self.path!r})"
         return f"{self.__class__.__name__}({self.path!r})"
 
 
-    def subkeys(self, base):
-        """Get all ref names under a base ref.
-
-        Args:
-          base: Base ref path to search under
-
-        Returns:
-          Set of ref names under the base (without base prefix)
-        """
+    def subkeys(self, base: bytes) -> set[bytes]:
+        """Return subkeys under a given base reference path."""
         subkeys = set()
         subkeys = set()
         path = self.refpath(base)
         path = self.refpath(base)
         for root, unused_dirs, files in os.walk(path):
         for root, unused_dirs, files in os.walk(path):
@@ -861,12 +851,8 @@ class DiskRefsContainer(RefsContainer):
                 subkeys.add(key[len(base) :].strip(b"/"))
                 subkeys.add(key[len(base) :].strip(b"/"))
         return subkeys
         return subkeys
 
 
-    def allkeys(self):
-        """Get all ref names from disk.
-
-        Returns:
-          Set of all ref names (both loose and packed)
-        """
+    def allkeys(self) -> set[bytes]:
+        """Return all reference keys."""
         allkeys = set()
         allkeys = set()
         if os.path.exists(self.refpath(HEADREF)):
         if os.path.exists(self.refpath(HEADREF)):
             allkeys.add(HEADREF)
             allkeys.add(HEADREF)
@@ -883,7 +869,7 @@ class DiskRefsContainer(RefsContainer):
         allkeys.update(self.get_packed_refs())
         allkeys.update(self.get_packed_refs())
         return allkeys
         return allkeys
 
 
-    def refpath(self, name):
+    def refpath(self, name: bytes) -> bytes:
         """Return the disk path of a ref."""
         """Return the disk path of a ref."""
         if os.path.sep != "/":
         if os.path.sep != "/":
             name = name.replace(b"/", os.fsencode(os.path.sep))
             name = name.replace(b"/", os.fsencode(os.path.sep))
@@ -894,7 +880,7 @@ class DiskRefsContainer(RefsContainer):
         else:
         else:
             return os.path.join(self.path, name)
             return os.path.join(self.path, name)
 
 
-    def get_packed_refs(self):
+    def get_packed_refs(self) -> dict[bytes, bytes]:
         """Get contents of the packed-refs file.
         """Get contents of the packed-refs file.
 
 
         Returns: Dictionary mapping ref names to SHA1s
         Returns: Dictionary mapping ref names to SHA1s
@@ -962,7 +948,7 @@ class DiskRefsContainer(RefsContainer):
 
 
             self._packed_refs = packed_refs
             self._packed_refs = packed_refs
 
 
-    def get_peeled(self, name):
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
         """Return the cached peeled value of a ref, if available.
         """Return the cached peeled value of a ref, if available.
 
 
         Args:
         Args:
@@ -972,7 +958,11 @@ class DiskRefsContainer(RefsContainer):
             to a tag, but no cached information is available, None is returned.
             to a tag, but no cached information is available, None is returned.
         """
         """
         self.get_packed_refs()
         self.get_packed_refs()
-        if self._peeled_refs is None or name not in self._packed_refs:
+        if (
+            self._peeled_refs is None
+            or self._packed_refs is None
+            or name not in self._packed_refs
+        ):
             # No cache: no peeled refs were read, or this ref is loose
             # No cache: no peeled refs were read, or this ref is loose
             return None
             return None
         if name in self._peeled_refs:
         if name in self._peeled_refs:
@@ -981,7 +971,7 @@ class DiskRefsContainer(RefsContainer):
             # Known not peelable
             # Known not peelable
             return self[name]
             return self[name]
 
 
-    def read_loose_ref(self, name):
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         """Read a reference file and return its contents.
         """Read a reference file and return its contents.
 
 
         If the reference file is a symbolic reference, only read the first line of
         If the reference file is a symbolic reference, only read the first line of
@@ -1011,7 +1001,7 @@ class DiskRefsContainer(RefsContainer):
             # errors depending on the specific operating system
             # errors depending on the specific operating system
             return None
             return None
 
 
-    def _remove_packed_ref(self, name) -> None:
+    def _remove_packed_ref(self, name: bytes) -> None:
         if self._packed_refs is None:
         if self._packed_refs is None:
             return
             return
         filename = os.path.join(self.path, b"packed-refs")
         filename = os.path.join(self.path, b"packed-refs")
@@ -1021,13 +1011,14 @@ class DiskRefsContainer(RefsContainer):
             self._packed_refs = None
             self._packed_refs = None
             self.get_packed_refs()
             self.get_packed_refs()
 
 
-            if name not in self._packed_refs:
+            if self._packed_refs is None or name not in self._packed_refs:
                 f.abort()
                 f.abort()
                 return
                 return
 
 
             del self._packed_refs[name]
             del self._packed_refs[name]
-            with suppress(KeyError):
-                del self._peeled_refs[name]
+            if self._peeled_refs is not None:
+                with suppress(KeyError):
+                    del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
             f.close()
         except BaseException:
         except BaseException:
@@ -1036,12 +1027,12 @@ class DiskRefsContainer(RefsContainer):
 
 
     def set_symbolic_ref(
     def set_symbolic_ref(
         self,
         self,
-        name,
-        other,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        other: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
     ) -> None:
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
@@ -1077,13 +1068,13 @@ class DiskRefsContainer(RefsContainer):
 
 
     def set_if_equals(
     def set_if_equals(
         self,
         self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
     ) -> bool:
         """Set a refname to new_ref only if it currently equals old_ref.
         """Set a refname to new_ref only if it currently equals old_ref.
 
 
@@ -1163,9 +1154,9 @@ class DiskRefsContainer(RefsContainer):
         self,
         self,
         name: bytes,
         name: bytes,
         ref: bytes,
         ref: bytes,
-        committer=None,
-        timestamp=None,
-        timezone=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
         message: Optional[bytes] = None,
         message: Optional[bytes] = None,
     ) -> bool:
     ) -> bool:
         """Add a new reference only if it does not already exist.
         """Add a new reference only if it does not already exist.
@@ -1215,12 +1206,12 @@ class DiskRefsContainer(RefsContainer):
 
 
     def remove_if_equals(
     def remove_if_equals(
         self,
         self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> bool:
     ) -> bool:
         """Remove a refname only if it currently equals old_ref.
         """Remove a refname only if it currently equals old_ref.
 
 
@@ -1321,7 +1312,7 @@ class DiskRefsContainer(RefsContainer):
             self.add_packed_refs(refs_to_pack)
             self.add_packed_refs(refs_to_pack)
 
 
 
 
-def _split_ref_line(line):
+def _split_ref_line(line: bytes) -> tuple[bytes, bytes]:
     """Split a single ref line into a tuple of SHA1 and name."""
     """Split a single ref line into a tuple of SHA1 and name."""
     fields = line.rstrip(b"\n\r").split(b" ")
     fields = line.rstrip(b"\n\r").split(b" ")
     if len(fields) != 2:
     if len(fields) != 2:
@@ -1334,7 +1325,7 @@ def _split_ref_line(line):
     return (sha, name)
     return (sha, name)
 
 
 
 
-def read_packed_refs(f):
+def read_packed_refs(f: IO[bytes]) -> Iterator[tuple[bytes, bytes]]:
     """Read a packed refs file.
     """Read a packed refs file.
 
 
     Args:
     Args:
@@ -1350,7 +1341,9 @@ def read_packed_refs(f):
         yield _split_ref_line(line)
         yield _split_ref_line(line)
 
 
 
 
-def read_packed_refs_with_peeled(f):
+def read_packed_refs_with_peeled(
+    f: IO[bytes],
+) -> Iterator[tuple[bytes, bytes, Optional[bytes]]]:
     """Read a packed refs file including peeled refs.
     """Read a packed refs file including peeled refs.
 
 
     Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
     Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
@@ -1382,7 +1375,11 @@ def read_packed_refs_with_peeled(f):
         yield (sha, name, None)
         yield (sha, name, None)
 
 
 
 
-def write_packed_refs(f, packed_refs, peeled_refs=None) -> None:
+def write_packed_refs(
+    f: IO[bytes],
+    packed_refs: dict[bytes, bytes],
+    peeled_refs: Optional[dict[bytes, bytes]] = None,
+) -> None:
     """Write a packed refs file.
     """Write a packed refs file.
 
 
     Args:
     Args:
@@ -1400,7 +1397,7 @@ def write_packed_refs(f, packed_refs, peeled_refs=None) -> None:
             f.write(b"^" + peeled_refs[refname] + b"\n")
             f.write(b"^" + peeled_refs[refname] + b"\n")
 
 
 
 
-def read_info_refs(f):
+def read_info_refs(f: BinaryIO) -> dict[bytes, bytes]:
     """Read info/refs file.
     """Read info/refs file.
 
 
     Args:
     Args:
@@ -1416,7 +1413,9 @@ def read_info_refs(f):
     return ret
     return ret
 
 
 
 
-def write_info_refs(refs, store: ObjectContainer):
+def write_info_refs(
+    refs: dict[bytes, bytes], store: ObjectContainer
+) -> Iterator[bytes]:
     """Generate info refs."""
     """Generate info refs."""
     # TODO: Avoid recursive import :(
     # TODO: Avoid recursive import :(
     from .object_store import peel_sha
     from .object_store import peel_sha
@@ -1436,38 +1435,38 @@ def write_info_refs(refs, store: ObjectContainer):
             yield peeled.id + b"\t" + name + PEELED_TAG_SUFFIX + b"\n"
             yield peeled.id + b"\t" + name + PEELED_TAG_SUFFIX + b"\n"
 
 
 
 
-def is_local_branch(x):
-    """Check if a ref name refers to a local branch.
+def is_local_branch(x: bytes) -> bool:
+    """Check if a ref name is a local branch."""
+    return x.startswith(LOCAL_BRANCH_PREFIX)
 
 
-    Args:
-      x: Ref name to check
 
 
-    Returns:
-      True if ref is a local branch (refs/heads/...)
-    """
-    return x.startswith(LOCAL_BRANCH_PREFIX)
+T = TypeVar("T", dict[bytes, bytes], dict[bytes, Optional[bytes]])
 
 
 
 
-def strip_peeled_refs(refs):
+def strip_peeled_refs(refs: T) -> T:
     """Remove all peeled refs."""
     """Remove all peeled refs."""
     return {
     return {
         ref: sha for (ref, sha) in refs.items() if not ref.endswith(PEELED_TAG_SUFFIX)
         ref: sha for (ref, sha) in refs.items() if not ref.endswith(PEELED_TAG_SUFFIX)
     }
     }
 
 
 
 
-def split_peeled_refs(refs):
+def split_peeled_refs(refs: T) -> tuple[T, dict[bytes, bytes]]:
     """Split peeled refs from regular refs."""
     """Split peeled refs from regular refs."""
-    peeled = {}
-    regular = {}
+    peeled: dict[bytes, bytes] = {}
+    regular = {k: v for k, v in refs.items() if not k.endswith(PEELED_TAG_SUFFIX)}
+
     for ref, sha in refs.items():
     for ref, sha in refs.items():
         if ref.endswith(PEELED_TAG_SUFFIX):
         if ref.endswith(PEELED_TAG_SUFFIX):
-            peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
-        else:
-            regular[ref] = sha
+            # Only add to peeled dict if sha is not None
+            if sha is not None:
+                peeled[ref[: -len(PEELED_TAG_SUFFIX)]] = sha
+
     return regular, peeled
     return regular, peeled
 
 
 
 
-def _set_origin_head(refs, origin, origin_head) -> None:
+def _set_origin_head(
+    refs: RefsContainer, origin: bytes, origin_head: Optional[bytes]
+) -> None:
     # set refs/remotes/origin/HEAD
     # set refs/remotes/origin/HEAD
     origin_base = b"refs/remotes/" + origin + b"/"
     origin_base = b"refs/remotes/" + origin + b"/"
     if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
     if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
@@ -1511,7 +1510,9 @@ def _set_default_branch(
     return head_ref
     return head_ref
 
 
 
 
-def _set_head(refs, head_ref, ref_message):
+def _set_head(
+    refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]
+) -> Optional[bytes]:
     if head_ref.startswith(LOCAL_TAG_PREFIX):
     if head_ref.startswith(LOCAL_TAG_PREFIX):
         # detach HEAD at specified tag
         # detach HEAD at specified tag
         head = refs[head_ref]
         head = refs[head_ref]
@@ -1534,7 +1535,7 @@ def _set_head(refs, head_ref, ref_message):
 def _import_remote_refs(
 def _import_remote_refs(
     refs_container: RefsContainer,
     refs_container: RefsContainer,
     remote_name: str,
     remote_name: str,
-    refs: dict[str, str],
+    refs: dict[bytes, Optional[bytes]],
     message: Optional[bytes] = None,
     message: Optional[bytes] = None,
     prune: bool = False,
     prune: bool = False,
     prune_tags: bool = False,
     prune_tags: bool = False,
@@ -1543,7 +1544,7 @@ def _import_remote_refs(
     branches = {
     branches = {
         n[len(LOCAL_BRANCH_PREFIX) :]: v
         n[len(LOCAL_BRANCH_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_BRANCH_PREFIX)
+        if n.startswith(LOCAL_BRANCH_PREFIX) and v is not None
     }
     }
     refs_container.import_refs(
     refs_container.import_refs(
         b"refs/remotes/" + remote_name.encode(),
         b"refs/remotes/" + remote_name.encode(),
@@ -1554,14 +1555,18 @@ def _import_remote_refs(
     tags = {
     tags = {
         n[len(LOCAL_TAG_PREFIX) :]: v
         n[len(LOCAL_TAG_PREFIX) :]: v
         for (n, v) in stripped_refs.items()
         for (n, v) in stripped_refs.items()
-        if n.startswith(LOCAL_TAG_PREFIX) and not n.endswith(PEELED_TAG_SUFFIX)
+        if n.startswith(LOCAL_TAG_PREFIX)
+        and not n.endswith(PEELED_TAG_SUFFIX)
+        and v is not None
     }
     }
     refs_container.import_refs(
     refs_container.import_refs(
         LOCAL_TAG_PREFIX, tags, message=message, prune=prune_tags
         LOCAL_TAG_PREFIX, tags, message=message, prune=prune_tags
     )
     )
 
 
 
 
-def serialize_refs(store, refs):
+def serialize_refs(
+    store: ObjectContainer, refs: dict[bytes, bytes]
+) -> dict[bytes, bytes]:
     """Serialize refs with peeled refs.
     """Serialize refs with peeled refs.
 
 
     Args:
     Args:
@@ -1658,6 +1663,7 @@ class locked_ref:
         if not self._file:
         if not self._file:
             raise RuntimeError("locked_ref not in context")
             raise RuntimeError("locked_ref not in context")
 
 
+        assert self._realname is not None
         current_ref = self._refs_container.read_loose_ref(self._realname)
         current_ref = self._refs_container.read_loose_ref(self._realname)
         if current_ref is None:
         if current_ref is None:
             current_ref = self._refs_container.get_packed_refs().get(
             current_ref = self._refs_container.get_packed_refs().get(
@@ -1724,3 +1730,14 @@ class locked_ref:
             self._refs_container._remove_packed_ref(self._realname)
             self._refs_container._remove_packed_ref(self._realname)
 
 
         self._deleted = True
         self._deleted = True
+
+
+def filter_ref_prefix(refs: T, prefixes: Iterable[bytes]) -> T:
+    """Filter refs to only include those with a given prefix.
+
+    Args:
+      refs: A dictionary of refs.
+      prefixes: The prefixes to filter by.
+    """
+    filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
+    return cast(T, filtered)

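The constrained TypeVar `T = TypeVar("T", dict[bytes, bytes], dict[bytes, Optional[bytes]])` lets `strip_peeled_refs`, `split_peeled_refs`, and the new `filter_ref_prefix` hand back the same dict flavour they were given. A usage sketch for `filter_ref_prefix`, assuming this branch is applied (ref names and SHAs below are made up):

refs = {
    b"refs/heads/main": b"1" * 40,
    b"refs/tags/v1.0": b"2" * 40,
    b"refs/remotes/origin/main": b"3" * 40,
}
# Equivalent by hand, mirroring the implementation above:
filtered = {
    k: v
    for k, v in refs.items()
    if any(k.startswith(p) for p in (b"refs/heads/", b"refs/tags/"))
}
print(sorted(filtered))  # [b'refs/heads/main', b'refs/tags/v1.0']
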
+ 1 - 1
dulwich/reftable.py

@@ -700,7 +700,7 @@ class ReftableReader:
         # Read magic bytes
         # Read magic bytes
         magic = self.f.read(4)
         magic = self.f.read(4)
         if magic != REFTABLE_MAGIC:
         if magic != REFTABLE_MAGIC:
-            raise ValueError(f"Invalid reftable magic: {magic}")
+            raise ValueError(f"Invalid reftable magic: {magic!r}")
 
 
         # Read version + block size (4 bytes total, big-endian network order)
         # Read version + block size (4 bytes total, big-endian network order)
         # Format: uint8(version) + uint24(block_size)
         # Format: uint8(version) + uint24(block_size)

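The `!r` here is more than cosmetic: formatting bytes with plain `str()` triggers a BytesWarning under `python -bb`, whereas `!r` requests the repr explicitly. Quick check:

magic = b"\x00BAD"
print(f"Invalid reftable magic: {magic!r}")  # Invalid reftable magic: b'\x00BAD'
# f"...{magic}" without !r would call str(magic), which `python -bb` rejects.
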
+ 87 - 74
dulwich/repo.py

@@ -34,7 +34,7 @@ import stat
 import sys
 import sys
 import time
 import time
 import warnings
 import warnings
-from collections.abc import Iterable
+from collections.abc import Iterable, Iterator
 from io import BytesIO
 from io import BytesIO
 from typing import (
 from typing import (
     TYPE_CHECKING,
     TYPE_CHECKING,
@@ -42,6 +42,7 @@ from typing import (
     BinaryIO,
     BinaryIO,
     Callable,
     Callable,
     Optional,
     Optional,
+    TypeVar,
     Union,
     Union,
 )
 )
 
 
@@ -52,7 +53,11 @@ if TYPE_CHECKING:
     from .attrs import GitAttributes
     from .attrs import GitAttributes
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .index import Index
     from .index import Index
+    from .line_ending import BlobNormalizer
     from .notes import Notes
     from .notes import Notes
+    from .object_store import BaseObjectStore, GraphWalker, UnpackedObject
+    from .rebase import RebaseStateManager
+    from .walk import Walker
     from .worktree import WorkTree
     from .worktree import WorkTree
 
 
 from . import replace_me
 from . import replace_me
@@ -115,6 +120,8 @@ from .refs import (
 
 
 CONTROLDIR = ".git"
 CONTROLDIR = ".git"
 OBJECTDIR = "objects"
 OBJECTDIR = "objects"
+
+T = TypeVar("T", bound="ShaFile")
 REFSDIR = "refs"
 REFSDIR = "refs"
 REFSDIR_TAGS = "tags"
 REFSDIR_TAGS = "tags"
 REFSDIR_HEADS = "heads"
 REFSDIR_HEADS = "heads"
@@ -138,12 +145,8 @@ DEFAULT_BRANCH = b"master"
 class InvalidUserIdentity(Exception):
 class InvalidUserIdentity(Exception):
     """User identity is not of the format 'user <email>'."""
     """User identity is not of the format 'user <email>'."""
 
 
-    def __init__(self, identity) -> None:
-        """Initialize InvalidUserIdentity exception.
-
-        Args:
-            identity: The invalid identity string
-        """
+    def __init__(self, identity: str) -> None:
+        """Initialize InvalidUserIdentity exception."""
         self.identity = identity
         self.identity = identity
 
 
 
 
@@ -241,7 +244,7 @@ def get_user_identity(config: "StackedConfig", kind: Optional[str] = None) -> by
     return user + b" <" + email + b">"
     return user + b" <" + email + b">"
 
 
 
 
-def check_user_identity(identity) -> None:
+def check_user_identity(identity: bytes) -> None:
     """Verify that a user identity is formatted correctly.
     """Verify that a user identity is formatted correctly.
 
 
     Args:
     Args:
@@ -252,11 +255,11 @@ def check_user_identity(identity) -> None:
     try:
     try:
         fst, snd = identity.split(b" <", 1)
         fst, snd = identity.split(b" <", 1)
     except ValueError as exc:
     except ValueError as exc:
-        raise InvalidUserIdentity(identity) from exc
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace")) from exc
     if b">" not in snd:
     if b">" not in snd:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
     if b"\0" in identity or b"\n" in identity:
     if b"\0" in identity or b"\n" in identity:
-        raise InvalidUserIdentity(identity)
+        raise InvalidUserIdentity(identity.decode("utf-8", "replace"))
 
 
 
 
 def parse_graftpoints(
 def parse_graftpoints(
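
With `InvalidUserIdentity.__init__` now annotated as taking `str`, the raise sites decode the bytes identity using errors="replace" first, so reporting a malformed identity can never itself raise UnicodeDecodeError. Reduced sketch:

identity = b"Example User example.com\xff"  # missing " <", plus an invalid UTF-8 byte
try:
    fst, snd = identity.split(b" <", 1)
except ValueError:
    decoded = identity.decode("utf-8", "replace")
    print(f"invalid identity: {decoded}")  # bad byte becomes U+FFFD, no crash
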
@@ -313,7 +316,7 @@ def serialize_graftpoints(graftpoints: dict[bytes, list[bytes]]) -> bytes:
     return b"\n".join(graft_lines)
     return b"\n".join(graft_lines)
 
 
 
 
-def _set_filesystem_hidden(path) -> None:
+def _set_filesystem_hidden(path: str) -> None:
     """Mark path as to be hidden if supported by platform and filesystem.
     """Mark path as to be hidden if supported by platform and filesystem.
 
 
     On win32 uses SetFileAttributesW api:
     On win32 uses SetFileAttributesW api:
@@ -337,15 +340,20 @@ def _set_filesystem_hidden(path) -> None:
 
 
 
 
 class ParentsProvider:
 class ParentsProvider:
-    """Provides parents for commits, handling grafts and shallow commits."""
+    """Provider for commit parent information."""
 
 
-    def __init__(self, store, grafts={}, shallows=[]) -> None:
+    def __init__(
+        self,
+        store: "BaseObjectStore",
+        grafts: dict = {},
+        shallows: Iterable[bytes] = [],
+    ) -> None:
         """Initialize ParentsProvider.
         """Initialize ParentsProvider.
 
 
         Args:
         Args:
-            store: Object store to get commits from
-            grafts: Dictionary mapping commit ids to parent ids
-            shallows: List of shallow commit ids
+            store: Object store to use
+            grafts: Graft information
+            shallows: Shallow commit SHAs
         """
         """
         self.store = store
         self.store = store
         self.grafts = grafts
         self.grafts = grafts
@@ -354,16 +362,10 @@ class ParentsProvider:
         # Get commit graph once at initialization for performance
         # Get commit graph once at initialization for performance
         self.commit_graph = store.get_commit_graph()
         self.commit_graph = store.get_commit_graph()
 
 
-    def get_parents(self, commit_id, commit=None):
-        """Get the parents of a commit.
-
-        Args:
-          commit_id: The commit SHA to get parents for
-          commit: Optional commit object to avoid fetching
-
-        Returns:
-          List of parent commit SHAs
-        """
+    def get_parents(
+        self, commit_id: bytes, commit: Optional[Commit] = None
+    ) -> list[bytes]:
+        """Get parents for a commit using the parents provider."""
         try:
         try:
             return self.grafts[commit_id]
             return self.grafts[commit_id]
         except KeyError:
         except KeyError:
@@ -379,7 +381,9 @@ class ParentsProvider:
 
 
         # Fallback to reading the commit object
         # Fallback to reading the commit object
         if commit is None:
         if commit is None:
-            commit = self.store[commit_id]
+            obj = self.store[commit_id]
+            assert isinstance(obj, Commit)
+            commit = obj
         return commit.parents
         return commit.parents
 
 
 
 
@@ -494,8 +498,12 @@ class BaseRepo:
         raise NotImplementedError(self.open_index)
         raise NotImplementedError(self.open_index)
 
 
     def fetch(
     def fetch(
-        self, target, determine_wants=None, progress=None, depth: Optional[int] = None
-    ):
+        self,
+        target: "BaseRepo",
+        determine_wants: Optional[Callable] = None,
+        progress: Optional[Callable] = None,
+        depth: Optional[int] = None,
+    ) -> dict:
         """Fetch objects into another repository.
         """Fetch objects into another repository.
 
 
         Args:
         Args:
@@ -519,13 +527,13 @@ class BaseRepo:
 
 
     def fetch_pack_data(
     def fetch_pack_data(
         self,
         self,
-        determine_wants,
-        graph_walker,
-        progress,
+        determine_wants: Callable,
+        graph_walker: "GraphWalker",
+        progress: Optional[Callable],
         *,
         *,
-        get_tagged=None,
+        get_tagged: Optional[Callable] = None,
         depth: Optional[int] = None,
         depth: Optional[int] = None,
-    ):
+    ) -> tuple:
         """Fetch the pack data required for a set of revisions.
         """Fetch the pack data required for a set of revisions.
 
 
         Args:
         Args:
@@ -554,11 +562,11 @@ class BaseRepo:
 
 
     def find_missing_objects(
     def find_missing_objects(
         self,
         self,
-        determine_wants,
-        graph_walker,
-        progress,
+        determine_wants: Callable,
+        graph_walker: "GraphWalker",
+        progress: Optional[Callable],
         *,
         *,
-        get_tagged=None,
+        get_tagged: Optional[Callable] = None,
         depth: Optional[int] = None,
         depth: Optional[int] = None,
     ) -> Optional[MissingObjectFinder]:
     ) -> Optional[MissingObjectFinder]:
         """Fetch the missing objects required for a set of revisions.
         """Fetch the missing objects required for a set of revisions.
@@ -585,16 +593,17 @@ class BaseRepo:
         current_shallow = set(getattr(graph_walker, "shallow", set()))
         current_shallow = set(getattr(graph_walker, "shallow", set()))
 
 
         if depth not in (None, 0):
         if depth not in (None, 0):
+            assert depth is not None
             shallow, not_shallow = find_shallow(self.object_store, wants, depth)
             shallow, not_shallow = find_shallow(self.object_store, wants, depth)
             # Only update if graph_walker has shallow attribute
             # Only update if graph_walker has shallow attribute
             if hasattr(graph_walker, "shallow"):
             if hasattr(graph_walker, "shallow"):
                 graph_walker.shallow.update(shallow - not_shallow)
                 graph_walker.shallow.update(shallow - not_shallow)
                 new_shallow = graph_walker.shallow - current_shallow
                 new_shallow = graph_walker.shallow - current_shallow
-                unshallow = graph_walker.unshallow = not_shallow & current_shallow
+                unshallow = graph_walker.unshallow = not_shallow & current_shallow  # type: ignore[attr-defined]
                 if hasattr(graph_walker, "update_shallow"):
                 if hasattr(graph_walker, "update_shallow"):
                     graph_walker.update_shallow(new_shallow, unshallow)
                     graph_walker.update_shallow(new_shallow, unshallow)
         else:
         else:
-            unshallow = getattr(graph_walker, "unshallow", frozenset())
+            unshallow = getattr(graph_walker, "unshallow", set())
 
 
         if wants == []:
         if wants == []:
             # TODO(dborowitz): find a way to short-circuit that doesn't change
             # TODO(dborowitz): find a way to short-circuit that doesn't change
@@ -618,7 +627,7 @@ class BaseRepo:
                 def __len__(self) -> int:
                 def __len__(self) -> int:
                     return 0
                     return 0
 
 
-                def __iter__(self):
+                def __iter__(self) -> Iterator[tuple[bytes, Optional[bytes]]]:
                     yield from []
                     yield from []
 
 
             return DummyMissingObjectFinder()  # type: ignore
             return DummyMissingObjectFinder()  # type: ignore
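
Two hunks up, the `getattr` fallback changed from `frozenset()` to `set()`, presumably so `unshallow` has one concrete type on both branches now that the surrounding code is typed (the depth branch builds it via set intersection). Reduced sketch with a dummy walker:

def get_unshallow(graph_walker: object) -> set[bytes]:
    # Mutable set() default matches the `not_shallow & current_shallow`
    # result produced on the depth-handling branch.
    return getattr(graph_walker, "unshallow", set())

print(get_unshallow(object()))  # set()
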
@@ -637,7 +646,7 @@ class BaseRepo:
 
 
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
         parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)
 
 
-        def get_parents(commit):
+        def get_parents(commit: Commit) -> list[bytes]:
             """Get parents for a commit using the parents provider.
             """Get parents for a commit using the parents provider.
 
 
             Args:
             Args:
@@ -660,11 +669,11 @@ class BaseRepo:
 
 
     def generate_pack_data(
     def generate_pack_data(
         self,
         self,
-        have: list[ObjectID],
-        want: list[ObjectID],
+        have: Iterable[ObjectID],
+        want: Iterable[ObjectID],
         progress: Optional[Callable[[str], None]] = None,
         progress: Optional[Callable[[str], None]] = None,
         ofs_delta: Optional[bool] = None,
         ofs_delta: Optional[bool] = None,
-    ):
+    ) -> tuple[int, Iterator["UnpackedObject"]]:
         """Generate pack data objects for a set of wants/haves.
         """Generate pack data objects for a set of wants/haves.
 
 
         Args:
         Args:
@@ -719,18 +728,18 @@ class BaseRepo:
         # TODO: move this method to WorkTree
         # TODO: move this method to WorkTree
         return self.refs[b"HEAD"]
         return self.refs[b"HEAD"]
 
 
-    def _get_object(self, sha, cls):
+    def _get_object(self, sha: bytes, cls: type[T]) -> T:
         assert len(sha) in (20, 40)
         assert len(sha) in (20, 40)
         ret = self.get_object(sha)
         ret = self.get_object(sha)
         if not isinstance(ret, cls):
         if not isinstance(ret, cls):
             if cls is Commit:
             if cls is Commit:
-                raise NotCommitError(ret)
+                raise NotCommitError(ret.id)
             elif cls is Blob:
             elif cls is Blob:
-                raise NotBlobError(ret)
+                raise NotBlobError(ret.id)
             elif cls is Tree:
             elif cls is Tree:
-                raise NotTreeError(ret)
+                raise NotTreeError(ret.id)
             elif cls is Tag:
             elif cls is Tag:
-                raise NotTagError(ret)
+                raise NotTagError(ret.id)
             else:
             else:
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
                 raise Exception(f"Type invalid: {ret.type_name!r} != {cls.type_name!r}")
         return ret
         return ret
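
`_get_object` now pairs `cls: type[T]` with a return type of `T` (bound to ShaFile), so a call like `repo._get_object(sha, Commit)` type-checks as Commit with no cast at the call site. The pattern in isolation (simplified stand-ins, not dulwich's real classes):

from typing import TypeVar

class ShaFile: ...
class Commit(ShaFile): ...

T = TypeVar("T", bound=ShaFile)

def get_typed(obj: ShaFile, cls: type[T]) -> T:
    if not isinstance(obj, cls):
        raise TypeError(f"expected {cls.__name__}, got {type(obj).__name__}")
    return obj  # the isinstance check above narrows obj to T

commit = get_typed(Commit(), Commit)  # inferred type: Commit
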
@@ -790,7 +799,7 @@ class BaseRepo:
         """
         """
         raise NotImplementedError(self.get_description)
         raise NotImplementedError(self.get_description)
 
 
-    def set_description(self, description) -> None:
+    def set_description(self, description: bytes) -> None:
         """Set the description for this repository.
         """Set the description for this repository.
 
 
         Args:
         Args:
@@ -798,14 +807,14 @@ class BaseRepo:
         """
         """
         raise NotImplementedError(self.set_description)
         raise NotImplementedError(self.set_description)
 
 
-    def get_rebase_state_manager(self):
+    def get_rebase_state_manager(self) -> "RebaseStateManager":
         """Get the appropriate rebase state manager for this repository.
         """Get the appropriate rebase state manager for this repository.
 
 
         Returns: RebaseStateManager instance
         Returns: RebaseStateManager instance
         """
         """
         raise NotImplementedError(self.get_rebase_state_manager)
         raise NotImplementedError(self.get_rebase_state_manager)
 
 
-    def get_blob_normalizer(self):
+    def get_blob_normalizer(self) -> "BlobNormalizer":
         """Return a BlobNormalizer object for checkin/checkout operations.
         """Return a BlobNormalizer object for checkin/checkout operations.
 
 
         Returns: BlobNormalizer instance
         Returns: BlobNormalizer instance
@@ -853,7 +862,9 @@ class BaseRepo:
         with f:
         with f:
             return {line.strip() for line in f}
             return {line.strip() for line in f}
 
 
-    def update_shallow(self, new_shallow, new_unshallow) -> None:
+    def update_shallow(
+        self, new_shallow: Optional[set[bytes]], new_unshallow: Optional[set[bytes]]
+    ) -> None:
         """Update the list of shallow objects.
         """Update the list of shallow objects.
 
 
         Args:
         Args:
@@ -895,7 +906,7 @@ class BaseRepo:
 
 
         return Notes(self.object_store, self.refs)
         return Notes(self.object_store, self.refs)
 
 
-    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs):
+    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs) -> "Walker":
         """Obtain a walker for this repository.
         """Obtain a walker for this repository.
 
 
         Args:
         Args:
@@ -932,7 +943,7 @@ class BaseRepo:
 
 
         return Walker(self.object_store, include, **kwargs)
         return Walker(self.object_store, include, **kwargs)
 
 
-    def __getitem__(self, name: Union[ObjectID, Ref]):
+    def __getitem__(self, name: Union[ObjectID, Ref]) -> "ShaFile":
         """Retrieve a Git object by SHA1 or ref.
         """Retrieve a Git object by SHA1 or ref.
 
 
         Args:
         Args:
@@ -1024,7 +1035,7 @@ class BaseRepo:
         for sha in to_remove:
         for sha in to_remove:
             del self._graftpoints[sha]
             del self._graftpoints[sha]
 
 
-    def _read_heads(self, name):
+    def _read_heads(self, name: str) -> list[bytes]:
         f = self.get_named_file(name)
         f = self.get_named_file(name)
         if f is None:
         if f is None:
             return []
             return []
@@ -1050,17 +1061,17 @@ class BaseRepo:
         message: Optional[bytes] = None,
         message: Optional[bytes] = None,
         committer: Optional[bytes] = None,
         committer: Optional[bytes] = None,
         author: Optional[bytes] = None,
         author: Optional[bytes] = None,
-        commit_timestamp=None,
-        commit_timezone=None,
-        author_timestamp=None,
-        author_timezone=None,
+        commit_timestamp: Optional[float] = None,
+        commit_timezone: Optional[int] = None,
+        author_timestamp: Optional[float] = None,
+        author_timezone: Optional[int] = None,
         tree: Optional[ObjectID] = None,
         tree: Optional[ObjectID] = None,
         encoding: Optional[bytes] = None,
         encoding: Optional[bytes] = None,
         ref: Optional[Ref] = b"HEAD",
         ref: Optional[Ref] = b"HEAD",
         merge_heads: Optional[list[ObjectID]] = None,
         merge_heads: Optional[list[ObjectID]] = None,
         no_verify: bool = False,
         no_verify: bool = False,
         sign: bool = False,
         sign: bool = False,
-    ):
+    ) -> bytes:
         """Create a new commit.
         """Create a new commit.
 
 
         If not specified, committer and author default to
         If not specified, committer and author default to
@@ -1109,7 +1120,7 @@ class BaseRepo:
         )
         )
 
 
 
 
-def read_gitfile(f):
+def read_gitfile(f: BinaryIO) -> str:
     """Read a ``.git`` file.
     """Read a ``.git`` file.
 
 
     The first line of the file should start with "gitdir: "
     The first line of the file should start with "gitdir: "
@@ -1119,9 +1130,9 @@ def read_gitfile(f):
     Returns: A path
     Returns: A path
     """
     """
     cs = f.read()
     cs = f.read()
-    if not cs.startswith("gitdir: "):
+    if not cs.startswith(b"gitdir: "):
         raise ValueError("Expected file to start with 'gitdir: '")
         raise ValueError("Expected file to start with 'gitdir: '")
-    return cs[len("gitdir: ") :].rstrip("\n")
+    return cs[len(b"gitdir: ") :].rstrip(b"\n").decode("utf-8")
 
 
 
 
 class UnsupportedVersion(Exception):
 class UnsupportedVersion(Exception):
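
`read_gitfile` now consumes bytes and decodes only at the end, which is why the caller in the next hunk switches to open(hidden_path, "rb"). Round trip with an in-memory file (the gitdir path is hypothetical):

from io import BytesIO

f = BytesIO(b"gitdir: ../.git/worktrees/feature\n")
cs = f.read()
assert cs.startswith(b"gitdir: ")
print(cs[len(b"gitdir: "):].rstrip(b"\n").decode("utf-8"))
# ../.git/worktrees/feature
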
@@ -1205,7 +1216,7 @@ class Repo(BaseRepo):
         self.bare = bare
         self.bare = bare
         if bare is False:
         if bare is False:
             if os.path.isfile(hidden_path):
             if os.path.isfile(hidden_path):
-                with open(hidden_path) as f:
+                with open(hidden_path, "rb") as f:
                     path = read_gitfile(f)
                     path = read_gitfile(f)
                 self._controldir = os.path.join(root, path)
                 self._controldir = os.path.join(root, path)
             else:
             else:
@@ -1364,11 +1375,11 @@ class Repo(BaseRepo):
             "No git repository was found at {path}".format(**dict(path=start))
             "No git repository was found at {path}".format(**dict(path=start))
         )
         )
 
 
-    def controldir(self):
+    def controldir(self) -> str:
         """Return the path of the control directory."""
         """Return the path of the control directory."""
         return self._controldir
         return self._controldir
 
 
-    def commondir(self):
+    def commondir(self) -> str:
         """Return the path of the common directory.
         """Return the path of the common directory.
 
 
         For a main working tree, it is identical to controldir().
         For a main working tree, it is identical to controldir().
@@ -1378,7 +1389,7 @@ class Repo(BaseRepo):
         """
         """
         return self._commondir
         return self._commondir
 
 
-    def _determine_file_mode(self):
+    def _determine_file_mode(self) -> bool:
         """Probe the file-system to determine whether permissions can be trusted.
         """Probe the file-system to determine whether permissions can be trusted.
 
 
         Returns: True if permissions can be trusted, False otherwise.
         Returns: True if permissions can be trusted, False otherwise.
@@ -1401,7 +1412,7 @@ class Repo(BaseRepo):
 
 
         return mode_differs and st2_has_exec
         return mode_differs and st2_has_exec
 
 
-    def _determine_symlinks(self):
+    def _determine_symlinks(self) -> bool:
         """Probe the filesystem to determine whether symlinks can be created.
         """Probe the filesystem to determine whether symlinks can be created.
 
 
         Returns: True if symlinks can be created, False otherwise.
         Returns: True if symlinks can be created, False otherwise.
@@ -1409,7 +1420,7 @@ class Repo(BaseRepo):
         # TODO(jelmer): Actually probe disk / look at filesystem
         # TODO(jelmer): Actually probe disk / look at filesystem
         return sys.platform != "win32"
         return sys.platform != "win32"
 
 
-    def _put_named_file(self, path, contents) -> None:
+    def _put_named_file(self, path: str, contents: bytes) -> None:
         """Write a file to the control dir with the given name and contents.
         """Write a file to the control dir with the given name and contents.
 
 
         Args:
         Args:
@@ -1420,7 +1431,7 @@ class Repo(BaseRepo):
         with GitFile(os.path.join(self.controldir(), path), "wb") as f:
         with GitFile(os.path.join(self.controldir(), path), "wb") as f:
             f.write(contents)
             f.write(contents)
 
 
-    def _del_named_file(self, path) -> None:
+    def _del_named_file(self, path: str) -> None:
         try:
         try:
             os.unlink(os.path.join(self.controldir(), path))
             os.unlink(os.path.join(self.controldir(), path))
         except FileNotFoundError:
         except FileNotFoundError:
@@ -2040,6 +2051,7 @@ class Repo(BaseRepo):
                 if isinstance(head, Tag):
                 if isinstance(head, Tag):
                     _cls, obj = head.object
                     _cls, obj = head.object
                     head = self.get_object(obj)
                     head = self.get_object(obj)
+                assert isinstance(head, Commit)
                 tree = head.tree
                 tree = head.tree
             except KeyError:
             except KeyError:
                 # No HEAD, no attributes from tree
                 # No HEAD, no attributes from tree
@@ -2048,6 +2060,7 @@ class Repo(BaseRepo):
         if tree is not None:
         if tree is not None:
             try:
             try:
                 tree_obj = self[tree]
                 tree_obj = self[tree]
+                assert isinstance(tree_obj, Tree)
                 if b".gitattributes" in tree_obj:
                 if b".gitattributes" in tree_obj:
                     _, attrs_sha = tree_obj[b".gitattributes"]
                     _, attrs_sha = tree_obj[b".gitattributes"]
                     attrs_blob = self[attrs_sha]
                     attrs_blob = self[attrs_sha]
@@ -2136,7 +2149,7 @@ class MemoryRepo(BaseRepo):
 
 
         self._reflog: list[Any] = []
         self._reflog: list[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
-        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore[arg-type]
         self._named_files: dict[str, bytes] = {}
         self._named_files: dict[str, bytes] = {}
         self.bare = True
         self.bare = True
         self._config = ConfigFile()
         self._config = ConfigFile()

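Note the bare `# type: ignore` above tightened to `# type: ignore[arg-type]`: scoping the suppression to one error code means mypy still reports any other problem that later appears on that line. For example:

x: int = "oops"  # type: ignore[assignment]  # silences only the assignment error
# A bare "# type: ignore" would also swallow unrelated future errors here.
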
+ 16 - 3
dulwich/server.py

@@ -52,7 +52,7 @@ import time
 import zlib
 import zlib
 from collections.abc import Iterable, Iterator
 from collections.abc import Iterable, Iterator
 from functools import partial
 from functools import partial
-from typing import TYPE_CHECKING, Optional, cast
+from typing import TYPE_CHECKING, Optional
 from typing import Protocol as TypingProtocol
 from typing import Protocol as TypingProtocol
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
@@ -551,10 +551,12 @@ def _split_proto_line(line, allowed):
             COMMAND_SHALLOW,
             COMMAND_SHALLOW,
             COMMAND_UNSHALLOW,
             COMMAND_UNSHALLOW,
         ):
         ):
+            assert fields[1] is not None
             if not valid_hexsha(fields[1]):
             if not valid_hexsha(fields[1]):
                 raise GitProtocolError("Invalid sha")
                 raise GitProtocolError("Invalid sha")
             return tuple(fields)
             return tuple(fields)
         elif command == COMMAND_DEEPEN:
         elif command == COMMAND_DEEPEN:
+            assert fields[1] is not None
             return command, int(fields[1])
             return command, int(fields[1])
     raise GitProtocolError(f"Received invalid line from client: {line!r}")
     raise GitProtocolError(f"Received invalid line from client: {line!r}")
 
 
@@ -633,6 +635,10 @@ class AckGraphWalkerImpl:
         """
         """
         raise NotImplementedError
         raise NotImplementedError
 
 
+    def handle_done(self, done_required, done_received):
+        """Handle 'done' packet from client."""
+        raise NotImplementedError
+
 
 
 class _ProtocolGraphWalker:
 class _ProtocolGraphWalker:
     """A graph walker that knows the git protocol.
     """A graph walker that knows the git protocol.
@@ -784,6 +790,7 @@ class _ProtocolGraphWalker:
         """
         """
         if len(have_ref) != 40:
         if len(have_ref) != 40:
             raise ValueError(f"invalid sha {have_ref!r}")
             raise ValueError(f"invalid sha {have_ref!r}")
+        assert self._impl is not None
         return self._impl.ack(have_ref)
         return self._impl.ack(have_ref)
 
 
     def reset(self) -> None:
     def reset(self) -> None:
@@ -799,7 +806,8 @@ class _ProtocolGraphWalker:
         if not self._cached:
         if not self._cached:
             if not self._impl and self.stateless_rpc:
             if not self._impl and self.stateless_rpc:
                 return None
                 return None
-            return next(self._impl)
+            assert self._impl is not None
+            return next(self._impl)  # type: ignore[call-overload]
         self._cache_index += 1
         self._cache_index += 1
         if self._cache_index > len(self._cache):
         if self._cache_index > len(self._cache):
             return None
             return None
@@ -884,6 +892,7 @@ class _ProtocolGraphWalker:
         Returns: True if done handling succeeded
         Returns: True if done handling succeeded
         """
         """
         # Delegate this to the implementation.
         # Delegate this to the implementation.
+        assert self._impl is not None
         return self._impl.handle_done(done_required, done_received)
         return self._impl.handle_done(done_required, done_received)
 
 
     def set_wants(self, wants) -> None:
     def set_wants(self, wants) -> None:
@@ -1400,7 +1409,11 @@ class UploadArchiveHandler(Handler):
                 format = arguments[i].decode("ascii")
                 format = arguments[i].decode("ascii")
             else:
             else:
                 commit_sha = self.repo.refs[argument]
                 commit_sha = self.repo.refs[argument]
-                tree = cast(Tree, store[cast(Commit, store[commit_sha]).tree])
+                commit_obj = store[commit_sha]
+                assert isinstance(commit_obj, Commit)
+                tree_obj = store[commit_obj.tree]
+                assert isinstance(tree_obj, Tree)
+                tree = tree_obj
             i += 1
             i += 1
         self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(b"ACK")
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)

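A recurring move in this file is replacing `typing.cast` with `assert isinstance`: a cast is erased at runtime and stays silently wrong if the object store hands back an unexpected type, while the assert fails immediately. Contrast, with stand-in classes:

from typing import cast

class Commit: ...
class Blob: ...

obj: object = Blob()
c = cast(Commit, obj)           # type-checks, but c is still a Blob at runtime
assert isinstance(obj, Commit)  # raises AssertionError here instead
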
+ 12 - 6
dulwich/stash.py

@@ -23,7 +23,7 @@
 
 
 import os
 import os
 import sys
 import sys
-from typing import TYPE_CHECKING, Optional, TypedDict
+from typing import TYPE_CHECKING, Optional, TypedDict, Union
 
 
 from .diff_tree import tree_changes
 from .diff_tree import tree_changes
 from .file import GitFile
 from .file import GitFile
@@ -162,10 +162,16 @@ class Stash:
             symlink_fn = symlink
             symlink_fn = symlink
         else:
         else:
 
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                mode = "w" + ("b" if isinstance(source, bytes) else "")
-                with open(target, mode) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: Optional[int] = None,
+            ) -> None:
+                mode = "w" + ("b" if isinstance(src, bytes) else "")
+                with open(dst, mode) as f:
+                    f.write(src)
 
 
         # Get blob normalizer for line ending conversion
         # Get blob normalizer for line ending conversion
         blob_normalizer = self._repo.get_blob_normalizer()
         blob_normalizer = self._repo.get_blob_normalizer()
@@ -228,7 +234,7 @@ class Stash:
                     entry.mode,
                     entry.mode,
                     full_path,
                     full_path,
                     honor_filemode=honor_filemode,
                     honor_filemode=honor_filemode,
-                    symlink_fn=symlink_fn,
+                    symlink_fn=symlink_fn,  # type: ignore[arg-type]
                 )
                 )
 
 
             # Update index if the file wasn't already staged
             # Update index if the file wasn't already staged

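The rewritten fallback `symlink_fn` copies os.symlink's full signature (positional src/dst, `target_is_directory`, keyword-only `dir_fd`), so the same variable can hold either the real os.symlink or this plain-file stand-in. Self-contained sketch:

import os
import tempfile
from typing import Optional, Union

def symlink_fn(
    src: Union[str, bytes, os.PathLike],
    dst: Union[str, bytes, os.PathLike],
    target_is_directory: bool = False,
    *,
    dir_fd: Optional[int] = None,
) -> None:
    # Fallback for platforms without symlinks: store the link target
    # as ordinary file contents instead.
    with open(dst, "wb" if isinstance(src, bytes) else "w") as f:
        f.write(src)

with tempfile.TemporaryDirectory() as d:
    symlink_fn(b"target-path", os.path.join(d, "link"))
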
+ 1 - 1
dulwich/tests/test_object_store.py

@@ -444,7 +444,7 @@ class FindShallowTests(TestCase):
     def make_linear_commits(self, n, message=b""):
         """Create a linear chain of commits."""
         commits = []
-        parents = []
+        parents: list[bytes] = []
         for _ in range(n):
             commits.append(self.make_commit(parents=parents, message=message))
             parents = [commits[-1].id]

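A one-line annotation such as parents: list[bytes] = [] matters because an empty literal gives the checker nothing to infer from; a tiny illustration:

    parents: list[bytes] = []   # without the annotation, mypy typically asks:
                                # "Need type annotation for 'parents'"
    parents.append(b"a" * 40)   # OK: elements are known to be bytes
    parents = [b"b" * 40]       # reassignment still checks against list[bytes]
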
+ 51 - 28
dulwich/tests/utils.py

@@ -21,6 +21,8 @@
 
 """Utility functions common to Dulwich tests."""
 
+# ruff: noqa: ANN401
+
 import datetime
 import os
 import shutil
@@ -28,10 +30,12 @@ import tempfile
 import time
 import types
 import warnings
+from typing import Any, BinaryIO, Callable, Optional, TypeVar, Union
 from unittest import SkipTest
 
 from dulwich.index import commit_tree
-from dulwich.objects import Commit, FixedSha, Tag, object_class
+from dulwich.object_store import BaseObjectStore
+from dulwich.objects import Commit, FixedSha, ShaFile, Tag, object_class
 from dulwich.pack import (
     DELTA_TYPES,
     OFS_DELTA,
@@ -47,8 +51,10 @@ from dulwich.repo import Repo
 # Plain files are very frequently used in tests, so let the mode be very short.
 F = 0o100644  # Shorthand mode for Files.
 
+T = TypeVar("T", bound=ShaFile)
+
 
-def open_repo(name, temp_dir=None):
+def open_repo(name: str, temp_dir: Optional[str] = None) -> Repo:
     """Open a copy of a repo in a temporary directory.
 
     Use this function for accessing repos in dulwich/tests/data/repos to avoid
@@ -72,14 +78,14 @@ def open_repo(name, temp_dir=None):
     return Repo(temp_repo_dir)
 
 
-def tear_down_repo(repo) -> None:
+def tear_down_repo(repo: Repo) -> None:
     """Tear down a test repository."""
     repo.close()
     temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
     shutil.rmtree(temp_dir)
 
 
-def make_object(cls, **attrs):
+def make_object(cls: type[T], **attrs: Any) -> T:
     """Make an object for testing and assign some members.
 
     This method creates a new subclass to allow arbitrary attribute
@@ -92,7 +98,7 @@ def make_object(cls, **attrs):
     Returns: A newly initialized object of type cls.
     """
 
-    class TestObject(cls):
+    class TestObject(cls):  # type: ignore[misc,valid-type]
         """Class that inherits from the given class, but without __slots__.
 
         Note that classes with __slots__ can't have arbitrary attributes
@@ -102,7 +108,7 @@ def make_object(cls, **attrs):
 
     TestObject.__name__ = "TestObject_" + cls.__name__
 
-    obj = TestObject()
+    obj = TestObject()  # type: ignore[abstract]
     for name, value in attrs.items():
         if name == "id":
             # id property is read-only, so we overwrite sha instead.
@@ -113,7 +119,7 @@ def make_object(cls, **attrs):
     return obj
 
 
-def make_commit(**attrs):
+def make_commit(**attrs: Any) -> Commit:
     """Make a Commit object with a default set of members.
 
     Args:
@@ -136,7 +142,7 @@ def make_commit(**attrs):
     return make_object(Commit, **all_attrs)
 
 
-def make_tag(target, **attrs):
+def make_tag(target: ShaFile, **attrs: Any) -> Tag:
     """Make a Tag object with a default set of values.
 
     Args:
@@ -159,16 +165,20 @@ def make_tag(target, **attrs):
     return make_object(Tag, **all_attrs)
 
 
-def functest_builder(method, func):
+def functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given function."""
 
-    def do_test(self) -> None:
+    def do_test(self: Any) -> None:
        method(self, func)
 
     return do_test
 
 
-def ext_functest_builder(method, func):
+def ext_functest_builder(
+    method: Callable[[Any, Any], None], func: Any
+) -> Callable[[Any], None]:
     """Generate a test method that tests the given extension function.
 
     This is intended to generate test methods that test both a pure-Python
@@ -190,7 +200,7 @@ def ext_functest_builder(method, func):
       func: The function implementation to pass to method.
     """
 
-    def do_test(self) -> None:
+    def do_test(self: Any) -> None:
         if not isinstance(func, types.BuiltinFunctionType):
             raise SkipTest(f"{func} extension not found")
         method(self, func)
@@ -198,7 +208,11 @@ def ext_functest_builder(method, func):
     return do_test
 
 
-def build_pack(f, objects_spec, store=None):
+def build_pack(
+    f: BinaryIO,
+    objects_spec: list[tuple[int, Any]],
+    store: Optional[BaseObjectStore] = None,
+) -> list[tuple[int, int, bytes, bytes, int]]:
     """Write test pack data from a concise spec.
 
     Args:
@@ -221,14 +235,14 @@ def build_pack(f, objects_spec, store=None):
     num_objects = len(objects_spec)
     write_pack_header(sf.write, num_objects)
 
-    full_objects = {}
-    offsets = {}
-    crc32s = {}
+    full_objects: dict[int, tuple[int, bytes, bytes]] = {}
+    offsets: dict[int, int] = {}
+    crc32s: dict[int, int] = {}
 
     while len(full_objects) < num_objects:
         for i, (type_num, data) in enumerate(objects_spec):
             if type_num not in DELTA_TYPES:
-                full_objects[i] = (type_num, data, obj_sha(type_num, [data]))
+                full_objects[i] = (type_num, data, obj_sha(type_num, [data]))  # type: ignore[no-untyped-call]
                 continue
             base, data = data
             if isinstance(base, int):
@@ -236,11 +250,12 @@ def build_pack(f, objects_spec, store=None):
                     continue
                 base_type_num, _, _ = full_objects[base]
             else:
+                assert store is not None
                 base_type_num, _ = store.get_raw(base)
             full_objects[i] = (
                 base_type_num,
                 data,
-                obj_sha(base_type_num, [data]),
+                obj_sha(base_type_num, [data]),  # type: ignore[no-untyped-call]
             )
 
     for i, (type_num, obj) in enumerate(objects_spec):
@@ -249,17 +264,18 @@ def build_pack(f, objects_spec, store=None):
             base_index, data = obj
             base = offset - offsets[base_index]
             _, base_data, _ = full_objects[base_index]
-            obj = (base, list(create_delta(base_data, data)))
+            obj = (base, list(create_delta(base_data, data)))  # type: ignore[no-untyped-call]
         elif type_num == REF_DELTA:
             base_ref, data = obj
             if isinstance(base_ref, int):
                 _, base_data, base = full_objects[base_ref]
             else:
+                assert store is not None
                 base_type_num, base_data = store.get_raw(base_ref)
-                base = obj_sha(base_type_num, base_data)
-            obj = (base, list(create_delta(base_data, data)))
+                base = obj_sha(base_type_num, base_data)  # type: ignore[no-untyped-call]
+            obj = (base, list(create_delta(base_data, data)))  # type: ignore[no-untyped-call]
 
-        crc32 = write_pack_object(sf.write, type_num, obj)
+        crc32 = write_pack_object(sf.write, type_num, obj)  # type: ignore[no-untyped-call]
         offsets[i] = offset
         crc32s[i] = crc32
 
@@ -269,12 +285,19 @@ def build_pack(f, objects_spec, store=None):
         assert len(sha) == 20
         expected.append((offsets[i], type_num, data, sha, crc32s[i]))
 
-    sf.write_sha()
+    sf.write_sha()  # type: ignore[no-untyped-call]
     f.seek(0)
     return expected
 
 
-def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
+def build_commit_graph(
+    object_store: BaseObjectStore,
+    commit_spec: list[list[int]],
+    trees: Optional[
+        dict[int, list[Union[tuple[bytes, ShaFile], tuple[bytes, ShaFile, int]]]]
+    ] = None,
+    attrs: Optional[dict[int, dict[str, Any]]] = None,
+) -> list[Commit]:
     """Build a commit graph from a concise specification.
 
     Sample usage:
@@ -311,7 +334,7 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
     if attrs is None:
         attrs = {}
     commit_time = 0
-    nums = {}
+    nums: dict[int, bytes] = {}
     commits = []
 
     for commit in commit_spec:
@@ -343,7 +366,7 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
 
         # By default, increment the time by a lot. Out-of-order commits should
         # be closer together than this because their main cause is clock skew.
-        commit_time = commit_attrs["commit_time"] + 100
+        commit_time = commit_attrs["commit_time"] + 100  # type: ignore[operator]
         nums[commit_num] = commit_obj.id
         object_store.add_object(commit_obj)
         commits.append(commit_obj)
@@ -351,12 +374,12 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
     return commits
 
 
-def setup_warning_catcher():
+def setup_warning_catcher() -> tuple[list[Warning], Callable[[], None]]:
     """Wrap warnings.showwarning with code that records warnings."""
     caught_warnings = []
     original_showwarning = warnings.showwarning
 
-    def custom_showwarning(*args, **kwargs) -> None:
+    def custom_showwarning(*args: Any, **kwargs: Any) -> None:
         caught_warnings.append(args[0])
 
     warnings.showwarning = custom_showwarning

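make_object's new signature, (cls: type[T], **attrs) -> T with T bound to ShaFile, means each call site gets back the concrete class it passed in rather than the base type. A condensed sketch of that generic-factory pattern with standalone names (not dulwich's own classes):

    from typing import Any, TypeVar

    class Base:
        pass

    class Child(Base):
        pass

    T = TypeVar("T", bound=Base)

    def make(cls: type[T], **attrs: Any) -> T:
        # Returning T rather than Base means make(Child) is typed as Child,
        # so call sites need no cast to use Child-specific attributes.
        obj = cls()
        for name, value in attrs.items():
            setattr(obj, name, value)
        return obj

    child: Child = make(Child)  # accepted by the checker without a cast
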
+ 12 - 5
dulwich/walk.py

@@ -87,7 +87,9 @@ class WalkEntry:
                 parent = None
             elif len(self._get_parents(commit)) == 1:
                 changes_func = tree_changes
-                parent = cast(Commit, self._store[self._get_parents(commit)[0]]).tree
+                parent_commit = self._store[self._get_parents(commit)[0]]
+                assert isinstance(parent_commit, Commit)
+                parent = parent_commit.tree
                 if path_prefix:
                     mode, subtree_sha = parent.lookup_path(
                         self._store.__getitem__,
@@ -96,9 +98,12 @@ class WalkEntry:
                     parent = self._store[subtree_sha]
             else:
                 # For merge commits, we need to handle multiple parents differently
-                parent = [
-                    cast(Commit, self._store[p]).tree for p in self._get_parents(commit)
-                ]
+                parent_trees = []
+                for p in self._get_parents(commit):
+                    parent_commit = self._store[p]
+                    assert isinstance(parent_commit, Commit)
+                    parent_trees.append(parent_commit.tree)
+                parent = parent_trees
                 # Use a lambda to adapt the signature
                 changes_func = cast(
                     Any,
@@ -200,7 +205,9 @@ class _CommitTimeQueue:
                     # some caching (which DiskObjectStore currently does not).
                     # We could either add caching in this class or pass around
                     # parsed queue entry objects instead of commits.
-                    todo.append(cast(Commit, self._store[parent]))
+                    parent_commit = self._store[parent]
+                    assert isinstance(parent_commit, Commit)
+                    todo.append(parent_commit)
                 excluded.add(parent)
 
     def next(self) -> Optional[WalkEntry]:

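The same three-line narrowing recipe repeats once per parent lookup in the walker. If it grew noisier, a small helper could centralize it — this is a hypothetical refactor, not part of the change:

    from dulwich.objects import Commit, ShaFile

    def as_commit(obj: ShaFile) -> Commit:
        # Fail fast on tags/trees/blobs instead of letting a wrongly typed
        # object propagate out of the walk.
        assert isinstance(obj, Commit), f"expected commit, got {type(obj).__name__}"
        return obj

    # e.g.: todo.append(as_commit(self._store[parent]))
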
+ 141 - 64
dulwich/web.py

@@ -26,9 +26,10 @@ import os
 import re
 import sys
 import time
-from collections.abc import Iterator
+from collections.abc import Iterable, Iterator
 from io import BytesIO
-from typing import BinaryIO, Callable, ClassVar, Optional, cast
+from types import TracebackType
+from typing import Any, BinaryIO, Callable, ClassVar, Optional, Union, cast
 from urllib.parse import parse_qs
 from wsgiref.simple_server import (
     ServerHandler,
@@ -37,6 +38,45 @@ from wsgiref.simple_server import (
     make_server,
 )
 
+# wsgiref.types was added in Python 3.11
+if sys.version_info >= (3, 11):
+    from wsgiref.types import StartResponse, WSGIApplication, WSGIEnvironment
+else:
+    # Fallback type definitions for Python < 3.11
+    from typing import TYPE_CHECKING
+
+    if TYPE_CHECKING:
+        # For type checking, use the _typeshed types if available
+        try:
+            from _typeshed.wsgi import StartResponse, WSGIApplication, WSGIEnvironment
+        except ImportError:
+            # Define our own protocol types for type checking
+            from typing import Protocol
+
+            class StartResponse(Protocol):  # type: ignore[no-redef]
+                """WSGI start_response callable protocol."""
+
+                def __call__(
+                    self,
+                    status: str,
+                    response_headers: list[tuple[str, str]],
+                    exc_info: Optional[
+                        tuple[type, BaseException, TracebackType]
+                    ] = None,
+                ) -> Callable[[bytes], None]:
+                    """Start the response with status and headers."""
+                    ...
+
+            WSGIEnvironment = dict[str, Any]  # type: ignore[misc]
+            WSGIApplication = Callable[  # type: ignore[misc]
+                [WSGIEnvironment, StartResponse], Iterable[bytes]
+            ]
+    else:
+        # At runtime, just use type aliases since these are only for type hints
+        StartResponse = Any
+        WSGIEnvironment = dict[str, Any]
+        WSGIApplication = Callable
+
 from dulwich import log_utils
 
 from .protocol import ReceivableProtocol
@@ -45,6 +85,7 @@ from .server import (
     DEFAULT_HANDLERS,
     Backend,
     DictBackend,
+    Handler,
     generate_info_refs,
     generate_objects_info_packs,
 )
@@ -292,13 +333,21 @@ def get_info_refs(
         yield req.not_found(str(e))
         return
     if service and not req.dumb:
+        if req.handlers is None:
+            yield req.forbidden("No handlers configured")
+            return
         handler_cls = req.handlers.get(service.encode("ascii"), None)
         if handler_cls is None:
             yield req.forbidden("Unsupported service")
             return
         req.nocache()
         write = req.respond(HTTP_OK, f"application/x-{service}-advertisement")
-        proto = ReceivableProtocol(BytesIO().read, write)
+
+        def write_fn(data: bytes) -> Optional[int]:
+            result = write(data)
+            return len(data) if result is not None else None
+
+        proto = ReceivableProtocol(BytesIO().read, write_fn)
         handler = handler_cls(
             backend,
             [url_prefix(mat)],
@@ -425,6 +474,9 @@ def handle_service_request(
     """
     service = mat.group().lstrip("/")
     logger.info("Handling service request for %s", service)
+    if req.handlers is None:
+        yield req.forbidden("No handlers configured")
+        return
     handler_cls = req.handlers.get(service.encode("ascii"), None)
     if handler_cls is None:
         yield req.forbidden("Unsupported service")
@@ -436,11 +488,16 @@ def handle_service_request(
         return
     req.nocache()
     write = req.respond(HTTP_OK, f"application/x-{service}-result")
+
+    def write_fn(data: bytes) -> Optional[int]:
+        result = write(data)
+        return len(data) if result is not None else None
+
     if req.environ.get("HTTP_TRANSFER_ENCODING") == "chunked":
         read = ChunkReader(req.environ["wsgi.input"]).read
     else:
         read = req.environ["wsgi.input"].read
-    proto = ReceivableProtocol(read, write)
+    proto = ReceivableProtocol(read, write_fn)
     # TODO(jelmer): Find a way to pass in repo, rather than having handler_cls
     # reopen.
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
@@ -455,7 +512,11 @@ class HTTPGitRequest:
     """
 
     def __init__(
-        self, environ, start_response, dumb: bool = False, handlers=None
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
     ) -> None:
         """Initialize HTTPGitRequest.
 
@@ -472,7 +533,7 @@ class HTTPGitRequest:
         self._cache_headers: list[tuple[str, str]] = []
         self._headers: list[tuple[str, str]] = []
 
-    def add_header(self, name, value) -> None:
+    def add_header(self, name: str, value: str) -> None:
         """Add a header to the response."""
         self._headers.append((name, value))
 
@@ -481,7 +542,7 @@ class HTTPGitRequest:
         status: str = HTTP_OK,
         content_type: Optional[str] = None,
         headers: Optional[list[tuple[str, str]]] = None,
-    ):
+    ) -> Callable[[bytes], object]:
         """Begin a response with the given status and other headers."""
         if headers:
             self._headers.extend(headers)
@@ -556,7 +617,11 @@ class HTTPGitApplication:
     }
 
     def __init__(
-        self, backend, dumb: bool = False, handlers=None, fallback_app=None
+        self,
+        backend: Backend,
+        dumb: bool = False,
+        handlers: Optional[dict[bytes, Callable]] = None,
+        fallback_app: Optional[WSGIApplication] = None,
     ) -> None:
         """Initialize HTTPGitApplication.
 
@@ -568,12 +633,18 @@ class HTTPGitApplication:
         """
         self.backend = backend
         self.dumb = dumb
-        self.handlers = dict(DEFAULT_HANDLERS)
+        self.handlers: dict[bytes, Union[type[Handler], Callable[..., Any]]] = dict(
+            DEFAULT_HANDLERS
+        )
         self.fallback_app = fallback_app
         if handlers is not None:
             self.handlers.update(handlers)
 
-    def __call__(self, environ, start_response):
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
         """Handle WSGI request."""
         path = environ["PATH_INFO"]
         method = environ["REQUEST_METHOD"]
@@ -582,6 +653,7 @@ class HTTPGitApplication:
         )
         # environ['QUERY_STRING'] has qs args
         handler = None
+        mat = None
         for smethod, spath in self.services.keys():
             if smethod != method:
                 continue
@@ -590,7 +662,7 @@ class HTTPGitApplication:
                 handler = self.services[smethod, spath]
                 break
 
-        if handler is None:
+        if handler is None or mat is None:
             if self.fallback_app is not None:
                 return self.fallback_app(environ, start_response)
             else:
@@ -602,12 +674,16 @@ class HTTPGitApplication:
 class GunzipFilter:
     """WSGI middleware that unzips gzip-encoded requests before passing on to the underlying application."""
 
-    def __init__(self, application) -> None:
-        """Initialize GunzipFilter."""
+    def __init__(self, application: WSGIApplication) -> None:
+        """Initialize GunzipFilter with WSGI application."""
         self.app = application
 
-    def __call__(self, environ, start_response):
-        """Handle WSGI request."""
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
+        """Handle WSGI request with gzip decompression."""
         import gzip
 
         if environ.get("HTTP_CONTENT_ENCODING", "") == "gzip":
@@ -615,8 +691,7 @@ class GunzipFilter:
                 filename=None, fileobj=environ["wsgi.input"], mode="rb"
             )
             del environ["HTTP_CONTENT_ENCODING"]
-            if "CONTENT_LENGTH" in environ:
-                del environ["CONTENT_LENGTH"]
+            environ.pop("CONTENT_LENGTH", None)
 
         return self.app(environ, start_response)
 
@@ -624,12 +699,16 @@ class GunzipFilter:
 class LimitedInputFilter:
     """WSGI middleware that limits the input length of a request to that specified in Content-Length."""
 
-    def __init__(self, application) -> None:
-        """Initialize LimitedInputFilter."""
+    def __init__(self, application: WSGIApplication) -> None:
+        """Initialize LimitedInputFilter with WSGI application."""
         self.app = application
 
-    def __call__(self, environ, start_response):
-        """Handle WSGI request."""
+    def __call__(
+        self,
+        environ: WSGIEnvironment,
+        start_response: StartResponse,
+    ) -> Iterable[bytes]:
+        """Handle WSGI request with input length limiting."""
         # This is not necessary if this app is run from a conforming WSGI
         # server. Unfortunately, there's no way to tell that at this point.
         # TODO: git may used HTTP/1.1 chunked encoding instead of specifying
@@ -642,9 +721,19 @@ class LimitedInputFilter:
         return self.app(environ, start_response)
 
 
-def make_wsgi_chain(*args, **kwargs):
-    """Factory function to create an instance of HTTPGitApplication, correctly wrapped with needed middleware."""
-    app = HTTPGitApplication(*args, **kwargs)
+def make_wsgi_chain(
+    backend: Backend,
+    dumb: bool = False,
+    handlers: Optional[dict[bytes, Callable[..., Any]]] = None,
+    fallback_app: Optional[WSGIApplication] = None,
+) -> WSGIApplication:
+    """Factory function to create an instance of HTTPGitApplication.
+
+    Correctly wrapped with needed middleware.
+    """
+    app = HTTPGitApplication(
+        backend, dumb=dumb, handlers=handlers, fallback_app=fallback_app
+    )
     wrapped_app = LimitedInputFilter(GunzipFilter(app))
     return wrapped_app
 
@@ -652,64 +741,52 @@ def make_wsgi_chain(*args, **kwargs):
 class ServerHandlerLogger(ServerHandler):
     """ServerHandler that uses dulwich's logger for logging exceptions."""
 
-    def log_exception(self, exc_info) -> None:
-        """Log an exception using dulwich's logger.
-
-        Args:
-          exc_info: Exception information tuple
-        """
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
+        """Log exception using dulwich logger."""
         logger.exception(
             "Exception happened during processing of request",
             exc_info=exc_info,
         )
 
-    def log_message(self, format, *args) -> None:
-        """Log a message using dulwich's logger.
-
-        Args:
-          format: Format string for the message
-          *args: Arguments for the format string
-        """
+    def log_message(self, format: str, *args: object) -> None:
+        """Log message using dulwich logger."""
         logger.info(format, *args)
 
-    def log_error(self, *args) -> None:
-        """Log an error using dulwich's logger.
-
-        Args:
-          *args: Error message components
-        """
+    def log_error(self, *args: object) -> None:
+        """Log error using dulwich logger."""
         logger.error(*args)
 
 
 class WSGIRequestHandlerLogger(WSGIRequestHandler):
     """WSGIRequestHandler that uses dulwich's logger for logging exceptions."""
 
-    def log_exception(self, exc_info) -> None:
-        """Log an exception using dulwich's logger.
-
-        Args:
-          exc_info: Exception information tuple
-        """
+    def log_exception(
+        self,
+        exc_info: Union[
+            tuple[type[BaseException], BaseException, TracebackType],
+            tuple[None, None, None],
+            None,
+        ],
+    ) -> None:
+        """Log exception using dulwich logger."""
         logger.exception(
             "Exception happened during processing of request",
             exc_info=exc_info,
         )
 
-    def log_message(self, format, *args) -> None:
-        """Log a message using dulwich's logger.
-
-        Args:
-          format: Format string for the message
-          *args: Arguments for the format string
-        """
+    def log_message(self, format: str, *args: object) -> None:
+        """Log message using dulwich logger."""
         logger.info(format, *args)
 
-    def log_error(self, *args) -> None:
-        """Log an error using dulwich's logger.
-
-        Args:
-          *args: Error message components
-        """
+    def log_error(self, *args: object) -> None:
+        """Log error using dulwich logger."""
         logger.error(*args)
 
     def handle(self) -> None:
@@ -731,14 +808,14 @@ class WSGIRequestHandlerLogger(WSGIRequestHandler):
 class WSGIServerLogger(WSGIServer):
     """WSGIServer that uses dulwich's logger for error handling."""
 
-    def handle_error(self, request, client_address) -> None:
+    def handle_error(self, request: object, client_address: tuple[str, int]) -> None:
         """Handle an error."""
         logger.exception(
             f"Exception happened during processing of request from {client_address!s}"
         )
 
 
-def main(argv=sys.argv) -> None:
+def main(argv: list[str] = sys.argv) -> None:
     """Entry point for starting an HTTP git server."""
     import optparse
 

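The wsgiref.types shim above is a reusable recipe for typing against modules that only exist on newer Pythons: real imports on 3.11+, typeshed stubs for the checker, cheap aliases at runtime. Condensed to its skeleton (a sketch assuming only the standard library):

    import sys
    from typing import TYPE_CHECKING, Any

    if sys.version_info >= (3, 11):
        from wsgiref.types import WSGIEnvironment  # real module on 3.11+
    elif TYPE_CHECKING:
        from _typeshed.wsgi import WSGIEnvironment  # stub-only; checker sees it
    else:
        WSGIEnvironment = dict[str, Any]  # harmless runtime stand-in

    def path_of(environ: WSGIEnvironment) -> bytes:
        # WSGI environ values are latin-1-decodable strings per PEP 3333.
        return environ.get("PATH_INFO", "/").encode("latin-1")
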
+ 41 - 16
dulwich/worktree.py

@@ -34,9 +34,10 @@ import warnings
 from collections.abc import Iterable, Iterator
 from contextlib import contextmanager
 from pathlib import Path
+from typing import Any, Callable, Union
 
 from .errors import CommitError, HookError
-from .objects import Commit, ObjectID, Tag, Tree
+from .objects import Blob, Commit, ObjectID, Tag, Tree
 from .refs import SYMREF, Ref
 from .repo import (
     GITDIR,
@@ -335,7 +336,7 @@ class WorkTree:
 
         index = self._repo.open_index()
         try:
-            tree_id = self._repo[b"HEAD"].tree
+            commit = self._repo[b"HEAD"]
         except KeyError:
             # no head mean no commit in the repo
             for fs_path in fs_paths:
@@ -343,6 +344,9 @@ class WorkTree:
                 del index[tree_path]
             index.write()
             return
+        else:
+            assert isinstance(commit, Commit), "HEAD must be a commit"
+            tree_id = commit.tree
 
         for fs_path in fs_paths:
             tree_path = _fs_to_tree_path(fs_path)
@@ -367,15 +371,19 @@ class WorkTree:
             except FileNotFoundError:
                 pass
 
+            blob_obj = self._repo[tree_entry[1]]
+            assert isinstance(blob_obj, Blob)
+            blob_size = len(blob_obj.data)
+
             index_entry = IndexEntry(
-                ctime=(self._repo[b"HEAD"].commit_time, 0),
-                mtime=(self._repo[b"HEAD"].commit_time, 0),
+                ctime=(commit.commit_time, 0),
+                mtime=(commit.commit_time, 0),
                 dev=st.st_dev if st else 0,
                 ino=st.st_ino if st else 0,
                 mode=tree_entry[0],
                 uid=st.st_uid if st else 0,
                 gid=st.st_gid if st else 0,
-                size=len(self._repo[tree_entry[1]].data),
+                size=blob_size,
                 sha=tree_entry[1],
                 flags=0,
                 extended_flags=0,
@@ -386,7 +394,7 @@ class WorkTree:
 
     def commit(
         self,
-        message: bytes | None = None,
+        message: Union[str, bytes, Callable[[Any, Commit], bytes], None] = None,
         committer: bytes | None = None,
         author: bytes | None = None,
         commit_timestamp: float | None = None,
@@ -541,13 +549,18 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.set_if_equals(
                     ref,
                     old_head,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             except KeyError:
@@ -555,12 +568,17 @@ class WorkTree:
                 if should_sign:
                     c.sign(keyid)
                 self._repo.object_store.add_object(c)
+                message_bytes = (
+                    message.encode() if isinstance(message, str) else message
+                )
                 ok = self._repo.refs.add_if_new(
                     ref,
                     c.id,
-                    message=b"commit: " + message,
+                    message=b"commit: " + message_bytes,
                     committer=committer,
-                    timestamp=commit_timestamp,
+                    timestamp=int(commit_timestamp)
+                    if commit_timestamp is not None
+                    else None,
                     timezone=commit_timezone,
                 )
             if not ok:
@@ -603,6 +621,9 @@ class WorkTree:
             if isinstance(head, Tag):
                 _cls, obj = head.object
                 head = self._repo.get_object(obj)
+            from .objects import Commit
+
+            assert isinstance(head, Commit)
             tree = head.tree
         config = self._repo.get_config()
         honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
@@ -616,11 +637,15 @@ class WorkTree:
             symlink_fn = symlink
         else:
 
-            def symlink_fn(source, target) -> None:  # type: ignore
-                with open(
-                    target, "w" + ("b" if isinstance(source, bytes) else "")
-                ) as f:
-                    f.write(source)
+            def symlink_fn(
+                src: Union[str, bytes, os.PathLike],
+                dst: Union[str, bytes, os.PathLike],
+                target_is_directory: bool = False,
+                *,
+                dir_fd: int | None = None,
+            ) -> None:
+                with open(dst, "w" + ("b" if isinstance(src, bytes) else "")) as f:
+                    f.write(src)
 
         blob_normalizer = self._repo.get_blob_normalizer()
         return build_index_from_tree(
@@ -630,7 +655,7 @@ class WorkTree:
             tree,
             honor_filemode=honor_filemode,
             validate_path_element=validate_path_element,
-            symlink_fn=symlink_fn,
+            symlink_fn=symlink_fn,  # type: ignore[arg-type]
             blob_normalizer=blob_normalizer,
         )
 

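With commit() now accepting str, bytes, or a callable for the message, one normalization step moves into the method; standalone, the str/bytes half reduces to this sketch:

    from typing import Union

    def message_bytes(message: Union[str, bytes]) -> bytes:
        # Reflog text is bytes, so str messages are encoded (UTF-8 default)
        # before the b"commit: " prefix is attached.
        return message.encode() if isinstance(message, str) else message

    assert message_bytes("fix walk") == b"fix walk"
    assert message_bytes(b"fix walk") == b"fix walk"
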
+ 2 - 1
pyproject.toml

@@ -25,7 +25,7 @@ classifiers = [
 requires-python = ">=3.9"
 dependencies = [
     "urllib3>=1.25",
-    'typing_extensions >=4.0 ; python_version < "3.11"',
+    'typing_extensions >=4.0 ; python_version < "3.12"',
 ]
 dynamic = ["version"]
 license-files = ["COPYING"]
@@ -99,6 +99,7 @@ ignore = [
     "ANN205",
     "ANN206",
     "E501",  # line too long
+    "UP007",  # Use X | Y for type annotations (Python 3.10+ syntax, but we support 3.9+)
 ]
 
 [tool.ruff.lint.pydocstyle]

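The UP007 ignore exists because PEP 604's X | Y annotations only evaluate at runtime on Python 3.10+, while the project still supports 3.9; the Union spelling works everywhere. A quick illustration of the difference:

    from typing import Optional, Union

    def f(x: Union[int, None]) -> Optional[int]:  # evaluates fine on 3.9+
        return x

    # def g(x: int | None) -> int | None: ...
    # The annotation above raises TypeError at definition time on Python 3.9,
    # unless the module uses `from __future__ import annotations` so that
    # annotations stay unevaluated strings.
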
+ 18 - 18
tests/__init__.py

@@ -38,7 +38,7 @@ import tempfile
 
 # If Python itself provides an exception, use that
 import unittest
-from typing import ClassVar
+from typing import ClassVar, Optional
 from unittest import SkipTest, expectedFailure, skipIf
 from unittest import TestCase as _TestCase
 
@@ -49,7 +49,7 @@ class TestCase(_TestCase):
         self.overrideEnv("HOME", "/nonexistent")
         self.overrideEnv("GIT_CONFIG_NOSYSTEM", "1")
 
-    def overrideEnv(self, name, value) -> None:
+    def overrideEnv(self, name: str, value: Optional[str]) -> None:
         def restore() -> None:
             if oldval is not None:
                 os.environ[name] = oldval
@@ -74,7 +74,7 @@ class BlackboxTestCase(TestCase):
         "/usr/local/bin",
     ]
 
-    def bin_path(self, name):
+    def bin_path(self, name: str) -> str:
         """Determine the full path of a binary.
 
         Args:
@@ -88,7 +88,7 @@ class BlackboxTestCase(TestCase):
         else:
             raise SkipTest(f"Unable to find binary {name}")
 
-    def run_command(self, name, args):
+    def run_command(self, name: str, args: list[str]) -> subprocess.Popen[bytes]:
         """Run a Dulwich command.
 
         Args:
@@ -113,7 +113,7 @@ class BlackboxTestCase(TestCase):
         )
 
 
-def self_test_suite():
+def self_test_suite() -> unittest.TestSuite:
     names = [
         "annotate",
         "archive",
@@ -181,7 +181,7 @@ def self_test_suite():
     return loader.loadTestsFromNames(module_names)
 
 
-def tutorial_test_suite():
+def tutorial_test_suite() -> unittest.TestSuite:
     tutorial = [
         "introduction",
         "file-format",
@@ -194,7 +194,7 @@ def tutorial_test_suite():
 
     to_restore = []
 
-    def overrideEnv(name, value) -> None:
+    def overrideEnv(name: str, value: Optional[str]) -> None:
         oldval = os.environ.get(name)
         if value is not None:
             os.environ[name] = value
@@ -202,17 +202,17 @@
             del os.environ[name]
         to_restore.append((name, oldval))
 
-    def setup(test) -> None:
-        test.__old_cwd = os.getcwd()
-        test.tempdir = tempfile.mkdtemp()
-        test.globs.update({"tempdir": test.tempdir})
-        os.chdir(test.tempdir)
+    def setup(test: doctest.DocTest) -> None:
+        test.__old_cwd = os.getcwd()  # type: ignore[attr-defined]
+        test.tempdir = tempfile.mkdtemp()  # type: ignore[attr-defined]
+        test.globs.update({"tempdir": test.tempdir})  # type: ignore[attr-defined]
+        os.chdir(test.tempdir)  # type: ignore[attr-defined]
         overrideEnv("HOME", "/nonexistent")
         overrideEnv("GIT_CONFIG_NOSYSTEM", "1")
 
-    def teardown(test) -> None:
-        os.chdir(test.__old_cwd)
-        shutil.rmtree(test.tempdir)
+    def teardown(test: doctest.DocTest) -> None:
+        os.chdir(test.__old_cwd)  # type: ignore[attr-defined]
+        shutil.rmtree(test.tempdir)  # type: ignore[attr-defined]
         for name, oldval in to_restore:
             if oldval is not None:
                 os.environ[name] = oldval
@@ -229,7 +229,7 @@ def tutorial_test_suite():
     )
 
 
-def nocompat_test_suite():
+def nocompat_test_suite() -> unittest.TestSuite:
     result = unittest.TestSuite()
     result.addTests(self_test_suite())
     result.addTests(tutorial_test_suite())
@@ -239,7 +239,7 @@
     return result
 
 
-def compat_test_suite():
+def compat_test_suite() -> unittest.TestSuite:
     result = unittest.TestSuite()
     from .compat import test_suite as compat_test_suite
 
@@ -247,7 +247,7 @@
     result.addTests(compat_test_suite())
     return result
 
 
-def test_suite():
+def test_suite() -> unittest.TestSuite:
     result = unittest.TestSuite()
     result.addTests(self_test_suite())
     if sys.platform != "win32":

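overrideEnv's Optional[str] value makes "unset this variable" explicit in the type rather than a convention. A minimal standalone version of the save-and-restore pattern these suites use (helper name is illustrative):

    import os
    from typing import Optional

    def override_env(name: str, value: Optional[str], to_restore: list) -> None:
        # None means "ensure the variable is unset"; the old value is kept
        # so a teardown loop can put the environment back afterwards.
        old = os.environ.get(name)
        if value is not None:
            os.environ[name] = value
        else:
            os.environ.pop(name, None)
        to_restore.append((name, old))
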
+ 1 - 1
tests/compat/__init__.py

@@ -24,7 +24,7 @@
 import unittest
 
 
-def test_suite():
+def test_suite() -> unittest.TestSuite:
     names = [
         "bundle",
         "check_ignore",

+ 2 - 2
tests/contrib/__init__.py

@@ -19,10 +19,10 @@
 # License, Version 2.0.
 #
 
+import unittest
 
-def test_suite():
-    import unittest
 
+def test_suite() -> unittest.TestSuite:
     names = [
         "diffstat",
         "paramiko_vendor",

+ 51 - 34
tests/test_annotate.py

@@ -21,6 +21,7 @@
 import os
 import tempfile
 import unittest
+from typing import Any, Optional
 from unittest import TestCase
 
 from dulwich.annotate import annotate_lines, update_lines
@@ -32,28 +33,28 @@ from dulwich.repo import Repo
 class UpdateLinesTestCase(TestCase):
     """Tests for update_lines function."""
 
-    def test_update_lines_equal(self):
+    def test_update_lines_equal(self) -> None:
         """Test update_lines when all lines are equal."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line2"),
         ]
         new_blob = b"line1\nline2"
         new_history_data = ("commit3", "entry3")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         self.assertEqual(old_lines, result)
 
-    def test_update_lines_insert(self):
+    def test_update_lines_insert(self) -> None:
         """Test update_lines when new lines are inserted."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line3"),
         ]
         new_blob = b"line1\nline2\nline3"
         new_history_data = ("commit3", "entry3")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit3", "entry3"), b"line2"),
@@ -61,9 +62,9 @@ class UpdateLinesTestCase(TestCase):
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_delete(self):
+    def test_update_lines_delete(self) -> None:
         """Test update_lines when lines are deleted."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line2"),
             (("commit3", "entry3"), b"line3"),
@@ -71,66 +72,70 @@ class UpdateLinesTestCase(TestCase):
         new_blob = b"line1\nline3"
         new_history_data = ("commit4", "entry4")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit3", "entry3"), b"line3"),
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_replace(self):
+    def test_update_lines_replace(self) -> None:
         """Test update_lines when lines are replaced."""
-        old_lines = [
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
             (("commit1", "entry1"), b"line1"),
             (("commit2", "entry2"), b"line2"),
         ]
         new_blob = b"line1\nline2_modified"
         new_history_data = ("commit3", "entry3")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit3", "entry3"), b"line2_modified"),
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_empty_old(self):
+    def test_update_lines_empty_old(self) -> None:
         """Test update_lines with empty old lines."""
-        old_lines = []
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = []
         new_blob = b"line1\nline2"
         new_history_data = ("commit1", "entry1")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         expected = [
             (("commit1", "entry1"), b"line1"),
             (("commit1", "entry1"), b"line2"),
         ]
         self.assertEqual(expected, result)
 
-    def test_update_lines_empty_new(self):
+    def test_update_lines_empty_new(self) -> None:
         """Test update_lines with empty new blob."""
-        old_lines = [(("commit1", "entry1"), b"line1")]
+        old_lines: list[tuple[tuple[Any, Any], bytes]] = [
+            (("commit1", "entry1"), b"line1")
+        ]
         new_blob = b""
         new_history_data = ("commit2", "entry2")
 
-        result = update_lines(old_lines, new_history_data, new_blob)
+        result = update_lines(old_lines, new_history_data, new_blob)  # type: ignore[arg-type]
         self.assertEqual([], result)
 
 
 class AnnotateLinesTestCase(TestCase):
     """Tests for annotate_lines function."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.temp_dir = tempfile.mkdtemp()
         self.repo = Repo.init(self.temp_dir)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.repo.close()
         import shutil
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit(self, blob_content, message, parent=None):
+    def _make_commit(
+        self, blob_content: bytes, message: str, parent: Optional[bytes] = None
+    ) -> bytes:
         """Helper to create a commit with a single file."""
         # Create blob
         blob = Blob()
@@ -159,7 +164,7 @@ class AnnotateLinesTestCase(TestCase):
         self.repo.object_store.add_object(commit)
         return commit.id
 
-    def test_annotate_lines_single_commit(self):
+    def test_annotate_lines_single_commit(self) -> None:
         """Test annotating a file with a single commit."""
         commit_id = self._make_commit(b"line1\nline2\nline3\n", "Initial commit")
 
@@ -170,7 +175,7 @@ class AnnotateLinesTestCase(TestCase):
             self.assertEqual(commit_id, commit.id)
             self.assertIn(line, [b"line1\n", b"line2\n", b"line3\n"])
 
-    def test_annotate_lines_multiple_commits(self):
+    def test_annotate_lines_multiple_commits(self) -> None:
         """Test annotating a file with multiple commits."""
         # First commit
         commit1_id = self._make_commit(b"line1\nline2\n", "Initial commit")
@@ -200,7 +205,7 @@ class AnnotateLinesTestCase(TestCase):
         self.assertEqual(commit1_id, result[2][0][0].id)
         self.assertEqual(b"line2\n", result[2][1])
 
-    def test_annotate_lines_nonexistent_path(self):
+    def test_annotate_lines_nonexistent_path(self) -> None:
         """Test annotating a nonexistent file."""
         commit_id = self._make_commit(b"content\n", "Initial commit")
 
@@ -211,17 +216,23 @@
 class PorcelainAnnotateTestCase(TestCase):
     """Tests for the porcelain annotate function."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.temp_dir = tempfile.mkdtemp()
         self.repo = Repo.init(self.temp_dir)
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.repo.close()
         import shutil
 
         shutil.rmtree(self.temp_dir)
 
-    def _make_commit_with_file(self, filename, content, message, parent=None):
+    def _make_commit_with_file(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
        """Helper to create a commit with a file."""
         """Helper to create a commit with a file."""
         # Create blob
         # Create blob
         blob = Blob()
         blob = Blob()
@@ -254,7 +265,7 @@ class PorcelainAnnotateTestCase(TestCase):
 
 
         return commit.id
         return commit.id
 
 
-    def test_porcelain_annotate(self):
+    def test_porcelain_annotate(self) -> None:
         """Test the porcelain annotate function."""
         """Test the porcelain annotate function."""
         # Create commits
         # Create commits
         commit1_id = self._make_commit_with_file(
         commit1_id = self._make_commit_with_file(
@@ -274,7 +285,7 @@ class PorcelainAnnotateTestCase(TestCase):
             self.assertIsNotNone(entry)
             self.assertIsNotNone(entry)
             self.assertIn(line, [b"line1\n", b"line2\n", b"line3\n"])
             self.assertIn(line, [b"line1\n", b"line2\n", b"line3\n"])
 
 
-    def test_porcelain_annotate_with_committish(self):
+    def test_porcelain_annotate_with_committish(self) -> None:
         """Test porcelain annotate with specific commit."""
         """Test porcelain annotate with specific commit."""
         # Create commits
         # Create commits
         commit1_id = self._make_commit_with_file(
         commit1_id = self._make_commit_with_file(
@@ -296,7 +307,7 @@ class PorcelainAnnotateTestCase(TestCase):
         self.assertEqual(1, len(result))
         self.assertEqual(1, len(result))
         self.assertEqual(b"modified\n", result[0][1])
         self.assertEqual(b"modified\n", result[0][1])
 
 
-    def test_blame_alias(self):
+    def test_blame_alias(self) -> None:
         """Test that blame is an alias for annotate."""
         """Test that blame is an alias for annotate."""
         self.assertIs(blame, annotate)
         self.assertIs(blame, annotate)
 
 
@@ -304,17 +315,23 @@ class PorcelainAnnotateTestCase(TestCase):
 class IntegrationTestCase(TestCase):
 class IntegrationTestCase(TestCase):
     """Integration tests with more complex scenarios."""
     """Integration tests with more complex scenarios."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.temp_dir = tempfile.mkdtemp()
         self.temp_dir = tempfile.mkdtemp()
         self.repo = Repo.init(self.temp_dir)
         self.repo = Repo.init(self.temp_dir)
 
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         self.repo.close()
         self.repo.close()
         import shutil
         import shutil
 
 
         shutil.rmtree(self.temp_dir)
         shutil.rmtree(self.temp_dir)
 
 
-    def _create_file_commit(self, filename, content, message, parent=None):
+    def _create_file_commit(
+        self,
+        filename: str,
+        content: bytes,
+        message: str,
+        parent: Optional[bytes] = None,
+    ) -> bytes:
         """Helper to create a commit with file content."""
         """Helper to create a commit with file content."""
         # Write file to working directory
         # Write file to working directory
         filepath = os.path.join(self.temp_dir, filename)
         filepath = os.path.join(self.temp_dir, filename)
@@ -337,7 +354,7 @@ class IntegrationTestCase(TestCase):
 
 
         return commit_id
         return commit_id
 
 
-    def test_complex_file_history(self):
+    def test_complex_file_history(self) -> None:
         """Test annotating a file with complex history."""
         """Test annotating a file with complex history."""
         # Initial commit with 3 lines
         # Initial commit with 3 lines
         self._create_file_commit(
         self._create_file_commit(

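A note on the pattern running through these test changes: mypy skips the body of any function that carries no annotations at all, so the seemingly cosmetic -> None on each test is what actually opts its body into type checking, and tuple-heavy literals such as old_lines need an explicit annotation for later calls to check. A minimal sketch of both effects (function names here are illustrative, not from dulwich):

    from typing import Any

    def untyped_test():
        x: int = "oops"  # body is skipped: the function has no annotations

    def typed_test() -> None:
        x: int = "oops"  # mypy: Incompatible types in assignment

    # The string placeholders would make mypy infer
    # list[tuple[tuple[str, str], bytes]], too narrow where (Commit, TreeEntry)
    # pairs are expected -- hence the explicit Any annotation plus the
    # ignore[arg-type] comments at the call sites above.
    old_lines: list[tuple[tuple[Any, Any], bytes]] = []
    old_lines.append((("commit1", "entry1"), b"line1"))
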
+ 7 - 10
tests/test_archive.py

@@ -24,7 +24,8 @@
 import struct
 import struct
 import tarfile
 import tarfile
 from io import BytesIO
 from io import BytesIO
-from unittest import skipUnless
+from typing import Optional
+from unittest.mock import patch
 
 
 from dulwich.archive import tar_stream
 from dulwich.archive import tar_stream
 from dulwich.object_store import MemoryObjectStore
 from dulwich.object_store import MemoryObjectStore
@@ -33,11 +34,6 @@ from dulwich.tests.utils import build_commit_graph
 
 
 from . import TestCase
 from . import TestCase
 
 
-try:
-    from unittest.mock import patch
-except ImportError:
-    patch = None
-
 
 
 class ArchiveTests(TestCase):
 class ArchiveTests(TestCase):
     def test_empty(self) -> None:
     def test_empty(self) -> None:
@@ -50,14 +46,16 @@ class ArchiveTests(TestCase):
         self.addCleanup(tf.close)
         self.addCleanup(tf.close)
         self.assertEqual([], tf.getnames())
         self.assertEqual([], tf.getnames())
 
 
-    def _get_example_tar_stream(self, *tar_stream_args, **tar_stream_kwargs):
+    def _get_example_tar_stream(
+        self, mtime: int, prefix: bytes = b"", format: str = ""
+    ) -> BytesIO:
         store = MemoryObjectStore()
         store = MemoryObjectStore()
         b1 = Blob.from_string(b"somedata")
         b1 = Blob.from_string(b"somedata")
         store.add_object(b1)
         store.add_object(b1)
         t1 = Tree()
         t1 = Tree()
         t1.add(b"somename", 0o100644, b1.id)
         t1.add(b"somename", 0o100644, b1.id)
         store.add_object(t1)
         store.add_object(t1)
-        stream = b"".join(tar_stream(store, t1, *tar_stream_args, **tar_stream_kwargs))
+        stream = b"".join(tar_stream(store, t1, mtime, prefix, format))
         return BytesIO(stream)
         return BytesIO(stream)
 
 
     def test_simple(self) -> None:
     def test_simple(self) -> None:
@@ -89,9 +87,8 @@ class ArchiveTests(TestCase):
         expected_mtime = struct.pack("<L", 1234)
         expected_mtime = struct.pack("<L", 1234)
         self.assertEqual(stream.getvalue()[4:8], expected_mtime)
         self.assertEqual(stream.getvalue()[4:8], expected_mtime)
 
 
-    @skipUnless(patch, "Required mock.patch")
     def test_same_file(self) -> None:
     def test_same_file(self) -> None:
-        contents = [None, None]
+        contents: list[Optional[bytes]] = [None, None]
         for format in ["", "gz", "bz2"]:
         for format in ["", "gz", "bz2"]:
             for i in [0, 1]:
             for i in [0, 1]:
                 with patch("time.time", return_value=i):
                 with patch("time.time", return_value=i):

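Dropping the *tar_stream_args/**tar_stream_kwargs passthrough in _get_example_tar_stream is a common typing refactor: a (*args, **kwargs) wrapper erases the wrapped function's signature, so mypy cannot check its call sites. A standalone before/after sketch (the generate producer is invented for illustration):

    from collections.abc import Iterator
    from io import BytesIO

    def generate(mtime: int, prefix: bytes = b"", format: str = "") -> Iterator[bytes]:
        """Stand-in for a stream producer such as tar_stream (illustrative)."""
        yield prefix + str(mtime).encode("ascii") + format.encode("ascii")

    # Before: the signature is erased; mypy accepts make_stream_loose(mtime="soon").
    def make_stream_loose(*args, **kwargs) -> BytesIO:
        return BytesIO(b"".join(generate(*args, **kwargs)))

    # After: explicit parameters are checked at every call site.
    def make_stream(mtime: int, prefix: bytes = b"", format: str = "") -> BytesIO:
        return BytesIO(b"".join(generate(mtime, prefix, format)))

    assert make_stream(1234, b"dir/").getvalue() == b"dir/1234"
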
+ 16 - 16
tests/test_bisect.py

@@ -36,19 +36,19 @@ from . import TestCase
 class BisectStateTests(TestCase):
 class BisectStateTests(TestCase):
     """Tests for BisectState class."""
     """Tests for BisectState class."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.test_dir = tempfile.mkdtemp()
         self.test_dir = tempfile.mkdtemp()
         self.repo = porcelain.init(self.test_dir)
         self.repo = porcelain.init(self.test_dir)
 
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         shutil.rmtree(self.test_dir)
         shutil.rmtree(self.test_dir)
 
 
-    def test_is_active_false(self):
+    def test_is_active_false(self) -> None:
         """Test is_active when no bisect session is active."""
         """Test is_active when no bisect session is active."""
         state = BisectState(self.repo)
         state = BisectState(self.repo)
         self.assertFalse(state.is_active)
         self.assertFalse(state.is_active)
 
 
-    def test_start_bisect(self):
+    def test_start_bisect(self) -> None:
         """Test starting a bisect session."""
         """Test starting a bisect session."""
         # Create at least one commit so HEAD exists
         # Create at least one commit so HEAD exists
         c1 = make_commit(id=b"1" * 40, message=b"initial commit")
         c1 = make_commit(id=b"1" * 40, message=b"initial commit")
@@ -73,7 +73,7 @@ class BisectStateTests(TestCase):
             os.path.exists(os.path.join(self.repo.controldir(), "BISECT_LOG"))
             os.path.exists(os.path.join(self.repo.controldir(), "BISECT_LOG"))
         )
         )
 
 
-    def test_start_bisect_no_head(self):
+    def test_start_bisect_no_head(self) -> None:
         """Test starting a bisect session when repository has no HEAD."""
         """Test starting a bisect session when repository has no HEAD."""
         state = BisectState(self.repo)
         state = BisectState(self.repo)
 
 
@@ -81,7 +81,7 @@ class BisectStateTests(TestCase):
             state.start()
             state.start()
         self.assertIn("Cannot start bisect: repository has no HEAD", str(cm.exception))
         self.assertIn("Cannot start bisect: repository has no HEAD", str(cm.exception))
 
 
-    def test_start_bisect_already_active(self):
+    def test_start_bisect_already_active(self) -> None:
         """Test starting a bisect session when one is already active."""
         """Test starting a bisect session when one is already active."""
         # Create at least one commit so HEAD exists
         # Create at least one commit so HEAD exists
         c1 = make_commit(id=b"1" * 40, message=b"initial commit")
         c1 = make_commit(id=b"1" * 40, message=b"initial commit")
@@ -94,28 +94,28 @@ class BisectStateTests(TestCase):
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             state.start()
             state.start()
 
 
-    def test_mark_bad_no_session(self):
+    def test_mark_bad_no_session(self) -> None:
         """Test marking bad commit when no session is active."""
         """Test marking bad commit when no session is active."""
         state = BisectState(self.repo)
         state = BisectState(self.repo)
 
 
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             state.mark_bad()
             state.mark_bad()
 
 
-    def test_mark_good_no_session(self):
+    def test_mark_good_no_session(self) -> None:
         """Test marking good commit when no session is active."""
         """Test marking good commit when no session is active."""
         state = BisectState(self.repo)
         state = BisectState(self.repo)
 
 
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             state.mark_good()
             state.mark_good()
 
 
-    def test_reset_no_session(self):
+    def test_reset_no_session(self) -> None:
         """Test resetting when no session is active."""
         """Test resetting when no session is active."""
         state = BisectState(self.repo)
         state = BisectState(self.repo)
 
 
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             state.reset()
             state.reset()
 
 
-    def test_bisect_workflow(self):
+    def test_bisect_workflow(self) -> None:
         """Test a complete bisect workflow."""
         """Test a complete bisect workflow."""
         # Create some commits
         # Create some commits
         c1 = make_commit(id=b"1" * 40, message=b"good commit 1")
         c1 = make_commit(id=b"1" * 40, message=b"good commit 1")
@@ -163,7 +163,7 @@ class BisectStateTests(TestCase):
 class BisectPorcelainTests(TestCase):
 class BisectPorcelainTests(TestCase):
     """Tests for porcelain bisect functions."""
     """Tests for porcelain bisect functions."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.test_dir = tempfile.mkdtemp()
         self.test_dir = tempfile.mkdtemp()
         self.repo = porcelain.init(self.test_dir)
         self.repo = porcelain.init(self.test_dir)
 
 
@@ -191,10 +191,10 @@ class BisectPorcelainTests(TestCase):
         self.repo.refs[b"HEAD"] = self.c4.id
         self.repo.refs[b"HEAD"] = self.c4.id
         self.repo.refs[b"refs/heads/master"] = self.c4.id
         self.repo.refs[b"refs/heads/master"] = self.c4.id
 
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         shutil.rmtree(self.test_dir)
         shutil.rmtree(self.test_dir)
 
 
-    def test_bisect_start(self):
+    def test_bisect_start(self) -> None:
         """Test bisect_start porcelain function."""
         """Test bisect_start porcelain function."""
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_start(self.test_dir)
 
 
@@ -203,7 +203,7 @@ class BisectPorcelainTests(TestCase):
             os.path.exists(os.path.join(self.repo.controldir(), "BISECT_START"))
             os.path.exists(os.path.join(self.repo.controldir(), "BISECT_START"))
         )
         )
 
 
-    def test_bisect_bad_good(self):
+    def test_bisect_bad_good(self) -> None:
         """Test marking commits as bad and good."""
         """Test marking commits as bad and good."""
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_bad(self.test_dir, self.c4.id.decode("ascii"))
         porcelain.bisect_bad(self.test_dir, self.c4.id.decode("ascii"))
@@ -226,7 +226,7 @@ class BisectPorcelainTests(TestCase):
             )
             )
         )
         )
 
 
-    def test_bisect_log(self):
+    def test_bisect_log(self) -> None:
         """Test getting bisect log."""
         """Test getting bisect log."""
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_bad(self.test_dir, self.c4.id.decode("ascii"))
         porcelain.bisect_bad(self.test_dir, self.c4.id.decode("ascii"))
@@ -238,7 +238,7 @@ class BisectPorcelainTests(TestCase):
         self.assertIn("git bisect bad", log)
         self.assertIn("git bisect bad", log)
         self.assertIn("git bisect good", log)
         self.assertIn("git bisect good", log)
 
 
-    def test_bisect_reset(self):
+    def test_bisect_reset(self) -> None:
         """Test resetting bisect state."""
         """Test resetting bisect state."""
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_start(self.test_dir)
         porcelain.bisect_bad(self.test_dir)
         porcelain.bisect_bad(self.test_dir)

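For reference, the porcelain calls exercised above compose into the usual bisect loop; a usage sketch with a placeholder path and commit id, and assuming bisect_good sits alongside the functions these tests call:

    from dulwich import porcelain

    repo_path = "/tmp/example-repo"  # placeholder; must be an initialised repo

    porcelain.bisect_start(repo_path)
    porcelain.bisect_bad(repo_path)              # current HEAD is broken
    porcelain.bisect_good(repo_path, "0123abc")  # placeholder good commit id
    print(porcelain.bisect_log(repo_path))       # "git bisect bad/good" lines
    porcelain.bisect_reset(repo_path)            # restore the original HEAD
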
+ 2 - 2
tests/test_cloud_gcs.py

@@ -39,8 +39,8 @@ class GcsObjectStoreTests(unittest.TestCase):
         self.assertIn("git", repr(self.store))
         self.assertIn("git", repr(self.store))
 
 
     def test_remove_pack(self):
     def test_remove_pack(self):
-        """Test _remove_pack method."""
-        self.store._remove_pack("pack-1234")
+        """Test _remove_pack_by_name method."""
+        self.store._remove_pack_by_name("pack-1234")
         self.mock_bucket.delete_blobs.assert_called_once()
         self.mock_bucket.delete_blobs.assert_called_once()
         args = self.mock_bucket.delete_blobs.call_args[0][0]
         args = self.mock_bucket.delete_blobs.call_args[0][0]
         self.assertEqual(
         self.assertEqual(

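The renamed test keeps the same mock-introspection pattern: assert_called_once() checks the call count, and call_args[0][0] retrieves the first positional argument of that call. Distilled:

    from unittest.mock import Mock

    bucket = Mock()
    # The code under test would issue something like:
    bucket.delete_blobs(["pack-1234.pack", "pack-1234.idx"])

    bucket.delete_blobs.assert_called_once()
    first_positional = bucket.delete_blobs.call_args[0][0]
    assert first_positional == ["pack-1234.pack", "pack-1234.idx"]
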
+ 51 - 49
tests/test_commit_graph.py

@@ -35,7 +35,7 @@ from dulwich.commit_graph import (
 class CommitGraphEntryTests(unittest.TestCase):
 class CommitGraphEntryTests(unittest.TestCase):
     """Tests for CommitGraphEntry."""
     """Tests for CommitGraphEntry."""
 
 
-    def test_init(self):
+    def test_init(self) -> None:
         commit_id = b"a" * 40
         commit_id = b"a" * 40
         tree_id = b"b" * 40
         tree_id = b"b" * 40
         parents = [b"c" * 40, b"d" * 40]
         parents = [b"c" * 40, b"d" * 40]
@@ -50,7 +50,7 @@ class CommitGraphEntryTests(unittest.TestCase):
         self.assertEqual(entry.generation, generation)
         self.assertEqual(entry.generation, generation)
         self.assertEqual(entry.commit_time, commit_time)
         self.assertEqual(entry.commit_time, commit_time)
 
 
-    def test_repr(self):
+    def test_repr(self) -> None:
         entry = CommitGraphEntry(b"a" * 40, b"b" * 40, [], 1, 1000)
         entry = CommitGraphEntry(b"a" * 40, b"b" * 40, [], 1, 1000)
         repr_str = repr(entry)
         repr_str = repr(entry)
         self.assertIn("CommitGraphEntry", repr_str)
         self.assertIn("CommitGraphEntry", repr_str)
@@ -60,12 +60,12 @@ class CommitGraphEntryTests(unittest.TestCase):
 class CommitGraphChunkTests(unittest.TestCase):
 class CommitGraphChunkTests(unittest.TestCase):
     """Tests for CommitGraphChunk."""
     """Tests for CommitGraphChunk."""
 
 
-    def test_init(self):
+    def test_init(self) -> None:
         chunk = CommitGraphChunk(b"TEST", b"test data")
         chunk = CommitGraphChunk(b"TEST", b"test data")
         self.assertEqual(chunk.chunk_id, b"TEST")
         self.assertEqual(chunk.chunk_id, b"TEST")
         self.assertEqual(chunk.data, b"test data")
         self.assertEqual(chunk.data, b"test data")
 
 
-    def test_repr(self):
+    def test_repr(self) -> None:
         chunk = CommitGraphChunk(b"TEST", b"x" * 100)
         chunk = CommitGraphChunk(b"TEST", b"x" * 100)
         repr_str = repr(chunk)
         repr_str = repr(chunk)
         self.assertIn("CommitGraphChunk", repr_str)
         self.assertIn("CommitGraphChunk", repr_str)
@@ -75,13 +75,13 @@ class CommitGraphChunkTests(unittest.TestCase):
 class CommitGraphTests(unittest.TestCase):
 class CommitGraphTests(unittest.TestCase):
     """Tests for CommitGraph."""
     """Tests for CommitGraph."""
 
 
-    def test_init(self):
+    def test_init(self) -> None:
         graph = CommitGraph()
         graph = CommitGraph()
         self.assertEqual(graph.hash_version, HASH_VERSION_SHA1)
         self.assertEqual(graph.hash_version, HASH_VERSION_SHA1)
         self.assertEqual(len(graph.entries), 0)
         self.assertEqual(len(graph.entries), 0)
         self.assertEqual(len(graph.chunks), 0)
         self.assertEqual(len(graph.chunks), 0)
 
 
-    def test_len(self):
+    def test_len(self) -> None:
         graph = CommitGraph()
         graph = CommitGraph()
         self.assertEqual(len(graph), 0)
         self.assertEqual(len(graph), 0)
 
 
@@ -90,7 +90,7 @@ class CommitGraphTests(unittest.TestCase):
         graph.entries.append(entry)
         graph.entries.append(entry)
         self.assertEqual(len(graph), 1)
         self.assertEqual(len(graph), 1)
 
 
-    def test_iter(self):
+    def test_iter(self) -> None:
         graph = CommitGraph()
         graph = CommitGraph()
         entry1 = CommitGraphEntry(b"a" * 40, b"b" * 40, [], 1, 1000)
         entry1 = CommitGraphEntry(b"a" * 40, b"b" * 40, [], 1, 1000)
         entry2 = CommitGraphEntry(b"c" * 40, b"d" * 40, [], 2, 2000)
         entry2 = CommitGraphEntry(b"c" * 40, b"d" * 40, [], 2, 2000)
@@ -101,22 +101,22 @@ class CommitGraphTests(unittest.TestCase):
         self.assertEqual(entries[0], entry1)
         self.assertEqual(entries[0], entry1)
         self.assertEqual(entries[1], entry2)
         self.assertEqual(entries[1], entry2)
 
 
-    def test_get_entry_by_oid_missing(self):
+    def test_get_entry_by_oid_missing(self) -> None:
         graph = CommitGraph()
         graph = CommitGraph()
         result = graph.get_entry_by_oid(b"f" * 40)
         result = graph.get_entry_by_oid(b"f" * 40)
         self.assertIsNone(result)
         self.assertIsNone(result)
 
 
-    def test_get_generation_number_missing(self):
+    def test_get_generation_number_missing(self) -> None:
         graph = CommitGraph()
         graph = CommitGraph()
         result = graph.get_generation_number(b"f" * 40)
         result = graph.get_generation_number(b"f" * 40)
         self.assertIsNone(result)
         self.assertIsNone(result)
 
 
-    def test_get_parents_missing(self):
+    def test_get_parents_missing(self) -> None:
         graph = CommitGraph()
         graph = CommitGraph()
         result = graph.get_parents(b"f" * 40)
         result = graph.get_parents(b"f" * 40)
         self.assertIsNone(result)
         self.assertIsNone(result)
 
 
-    def test_from_invalid_signature(self):
+    def test_from_invalid_signature(self) -> None:
         data = b"XXXX" + b"\\x00" * 100
         data = b"XXXX" + b"\x00" * 100
         data = b"XXXX" + b"\x00" * 100
         f = io.BytesIO(data)
 
 
@@ -124,7 +124,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
             CommitGraph.from_file(f)
         self.assertIn("Invalid commit graph signature", str(cm.exception))
         self.assertIn("Invalid commit graph signature", str(cm.exception))
 
 
-    def test_from_invalid_version(self):
+    def test_from_invalid_version(self) -> None:
         data = COMMIT_GRAPH_SIGNATURE + struct.pack(">B", 99) + b"\x00" * 100
         data = COMMIT_GRAPH_SIGNATURE + struct.pack(">B", 99) + b"\x00" * 100
         f = io.BytesIO(data)
         f = io.BytesIO(data)
 
 
@@ -132,7 +132,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
             CommitGraph.from_file(f)
         self.assertIn("Unsupported commit graph version", str(cm.exception))
         self.assertIn("Unsupported commit graph version", str(cm.exception))
 
 
-    def test_from_invalid_hash_version(self):
+    def test_from_invalid_hash_version(self) -> None:
         data = (
         data = (
             COMMIT_GRAPH_SIGNATURE
             COMMIT_GRAPH_SIGNATURE
             + struct.pack(">B", COMMIT_GRAPH_VERSION)
             + struct.pack(">B", COMMIT_GRAPH_VERSION)
@@ -145,7 +145,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
             CommitGraph.from_file(f)
         self.assertIn("Unsupported hash version", str(cm.exception))
         self.assertIn("Unsupported hash version", str(cm.exception))
 
 
-    def create_minimal_commit_graph_data(self):
+    def create_minimal_commit_graph_data(self) -> bytes:
         """Create minimal valid commit graph data for testing."""
         """Create minimal valid commit graph data for testing."""
         # Create the data in order and calculate offsets properly
         # Create the data in order and calculate offsets properly
 
 
@@ -209,7 +209,7 @@ class CommitGraphTests(unittest.TestCase):
 
 
         return header + toc + fanout + oid_lookup + commit_data
         return header + toc + fanout + oid_lookup + commit_data
 
 
-    def test_from_minimal_valid_file(self):
+    def test_from_minimal_valid_file(self) -> None:
         """Test parsing a minimal but valid commit graph file."""
         """Test parsing a minimal but valid commit graph file."""
         data = self.create_minimal_commit_graph_data()
         data = self.create_minimal_commit_graph_data()
         f = io.BytesIO(data)
         f = io.BytesIO(data)
@@ -233,7 +233,7 @@ class CommitGraphTests(unittest.TestCase):
         self.assertEqual(graph.get_parents(commit_oid), [])
         self.assertEqual(graph.get_parents(commit_oid), [])
         self.assertIsNotNone(graph.get_entry_by_oid(commit_oid))
         self.assertIsNotNone(graph.get_entry_by_oid(commit_oid))
 
 
-    def test_missing_required_chunks(self):
+    def test_missing_required_chunks(self) -> None:
         """Test error handling for missing required chunks."""
         """Test error handling for missing required chunks."""
         # Create data with header but no chunks
         # Create data with header but no chunks
         header = (
         header = (
@@ -254,7 +254,7 @@ class CommitGraphTests(unittest.TestCase):
             CommitGraph.from_file(f)
             CommitGraph.from_file(f)
         self.assertIn("Missing required OID lookup chunk", str(cm.exception))
         self.assertIn("Missing required OID lookup chunk", str(cm.exception))
 
 
-    def test_write_empty_graph_raises(self):
+    def test_write_empty_graph_raises(self) -> None:
         """Test that writing empty graph raises ValueError."""
         """Test that writing empty graph raises ValueError."""
         graph = CommitGraph()
         graph = CommitGraph()
         f = io.BytesIO()
         f = io.BytesIO()
@@ -262,7 +262,7 @@ class CommitGraphTests(unittest.TestCase):
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             graph.write_to_file(f)
             graph.write_to_file(f)
 
 
-    def test_write_and_read_round_trip(self):
+    def test_write_and_read_round_trip(self) -> None:
         """Test writing and reading a commit graph."""
         """Test writing and reading a commit graph."""
         # Create a simple commit graph
         # Create a simple commit graph
         graph = CommitGraph()
         graph = CommitGraph()
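
The round-trip test here uses the surface already seen in this file: build a CommitGraph, append CommitGraphEntry objects, write, re-read. A sketch of that flow, assuming write_to_file accepts entries constructed this way, as the test suggests:

    import io

    from dulwich.commit_graph import CommitGraph, CommitGraphEntry

    graph = CommitGraph()
    graph.entries.append(CommitGraphEntry(b"a" * 40, b"b" * 40, [], 1, 1000))

    buf = io.BytesIO()
    graph.write_to_file(buf)

    buf.seek(0)
    again = CommitGraph.from_file(buf)
    assert len(again) == 1
    assert again.entries[0].commit_id == b"a" * 40
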
@@ -297,21 +297,21 @@ class CommitGraphTests(unittest.TestCase):
 class CommitGraphFileOperationsTests(unittest.TestCase):
 class CommitGraphFileOperationsTests(unittest.TestCase):
     """Tests for commit graph file operations."""
     """Tests for commit graph file operations."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.tempdir = tempfile.mkdtemp()
         self.tempdir = tempfile.mkdtemp()
 
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         import shutil
         import shutil
 
 
         shutil.rmtree(self.tempdir, ignore_errors=True)
         shutil.rmtree(self.tempdir, ignore_errors=True)
 
 
-    def test_read_commit_graph_missing_file(self):
+    def test_read_commit_graph_missing_file(self) -> None:
         """Test reading from non-existent file."""
         """Test reading from non-existent file."""
         missing_path = os.path.join(self.tempdir, "missing.graph")
         missing_path = os.path.join(self.tempdir, "missing.graph")
         result = read_commit_graph(missing_path)
         result = read_commit_graph(missing_path)
         self.assertIsNone(result)
         self.assertIsNone(result)
 
 
-    def test_read_commit_graph_invalid_file(self):
+    def test_read_commit_graph_invalid_file(self) -> None:
         """Test reading from invalid file."""
         """Test reading from invalid file."""
         invalid_path = os.path.join(self.tempdir, "invalid.graph")
         invalid_path = os.path.join(self.tempdir, "invalid.graph")
         with open(invalid_path, "wb") as f:
         with open(invalid_path, "wb") as f:
@@ -320,12 +320,12 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
         with self.assertRaises(ValueError):
         with self.assertRaises(ValueError):
             read_commit_graph(invalid_path)
             read_commit_graph(invalid_path)
 
 
-    def test_find_commit_graph_file_missing(self):
+    def test_find_commit_graph_file_missing(self) -> None:
         """Test finding commit graph file when it doesn't exist."""
         """Test finding commit graph file when it doesn't exist."""
         result = find_commit_graph_file(self.tempdir)
         result = find_commit_graph_file(self.tempdir)
         self.assertIsNone(result)
         self.assertIsNone(result)
 
 
-    def test_find_commit_graph_file_standard_location(self):
+    def test_find_commit_graph_file_standard_location(self) -> None:
         """Test finding commit graph file in standard location."""
         """Test finding commit graph file in standard location."""
         # Create .git/objects/info/commit-graph
         # Create .git/objects/info/commit-graph
         objects_dir = os.path.join(self.tempdir, "objects")
         objects_dir = os.path.join(self.tempdir, "objects")
@@ -339,7 +339,7 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
         result = find_commit_graph_file(self.tempdir)
         result = find_commit_graph_file(self.tempdir)
         self.assertEqual(result, graph_path.encode())
         self.assertEqual(result, graph_path.encode())
 
 
-    def test_find_commit_graph_file_chain_location(self):
+    def test_find_commit_graph_file_chain_location(self) -> None:
         """Test finding commit graph file in chain location."""
         """Test finding commit graph file in chain location."""
         # Create .git/objects/info/commit-graphs/graph-{hash}.graph
         # Create .git/objects/info/commit-graphs/graph-{hash}.graph
         objects_dir = os.path.join(self.tempdir, "objects")
         objects_dir = os.path.join(self.tempdir, "objects")
@@ -354,7 +354,7 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
         result = find_commit_graph_file(self.tempdir)
         result = find_commit_graph_file(self.tempdir)
         self.assertEqual(result, graph_path.encode())
         self.assertEqual(result, graph_path.encode())
 
 
-    def test_find_commit_graph_file_prefers_standard(self):
+    def test_find_commit_graph_file_prefers_standard(self) -> None:
         """Test that standard location is preferred over chain location."""
         """Test that standard location is preferred over chain location."""
         # Create both locations
         # Create both locations
         objects_dir = os.path.join(self.tempdir, "objects")
         objects_dir = os.path.join(self.tempdir, "objects")
@@ -380,15 +380,15 @@ class CommitGraphFileOperationsTests(unittest.TestCase):
 class CommitGraphGenerationTests(unittest.TestCase):
 class CommitGraphGenerationTests(unittest.TestCase):
     """Tests for commit graph generation functionality."""
     """Tests for commit graph generation functionality."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.tempdir = tempfile.mkdtemp()
         self.tempdir = tempfile.mkdtemp()
 
 
-    def tearDown(self):
+    def tearDown(self) -> None:
         import shutil
         import shutil
 
 
         shutil.rmtree(self.tempdir, ignore_errors=True)
         shutil.rmtree(self.tempdir, ignore_errors=True)
 
 
-    def test_generate_commit_graph_empty(self):
+    def test_generate_commit_graph_empty(self) -> None:
         """Test generating commit graph with no commits."""
         """Test generating commit graph with no commits."""
         from dulwich.object_store import MemoryObjectStore
         from dulwich.object_store import MemoryObjectStore
 
 
@@ -397,7 +397,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
 
 
         self.assertEqual(len(graph), 0)
         self.assertEqual(len(graph), 0)
 
 
-    def test_generate_commit_graph_single_commit(self):
+    def test_generate_commit_graph_single_commit(self) -> None:
         """Test generating commit graph with single commit."""
         """Test generating commit graph with single commit."""
         from dulwich.object_store import MemoryObjectStore
         from dulwich.object_store import MemoryObjectStore
         from dulwich.objects import Commit, Tree
         from dulwich.objects import Commit, Tree
@@ -428,7 +428,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertEqual(entry.generation, 1)
         self.assertEqual(entry.generation, 1)
         self.assertEqual(entry.commit_time, 1234567890)
         self.assertEqual(entry.commit_time, 1234567890)
 
 
-    def test_get_reachable_commits(self):
+    def test_get_reachable_commits(self) -> None:
         """Test getting reachable commits."""
         """Test getting reachable commits."""
         from dulwich.object_store import MemoryObjectStore
         from dulwich.object_store import MemoryObjectStore
         from dulwich.objects import Commit, Tree
         from dulwich.objects import Commit, Tree
@@ -465,7 +465,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertIn(commit1.id, reachable)
         self.assertIn(commit1.id, reachable)
         self.assertIn(commit2.id, reachable)
         self.assertIn(commit2.id, reachable)
 
 
-    def test_write_commit_graph_to_file(self):
+    def test_write_commit_graph_to_file(self) -> None:
         """Test writing commit graph to file."""
         """Test writing commit graph to file."""
         from dulwich.object_store import DiskObjectStore
         from dulwich.object_store import DiskObjectStore
         from dulwich.objects import Commit, Tree
         from dulwich.objects import Commit, Tree
@@ -498,13 +498,14 @@ class CommitGraphGenerationTests(unittest.TestCase):
         # Read back and verify
         # Read back and verify
         graph = read_commit_graph(graph_path)
         graph = read_commit_graph(graph_path)
         self.assertIsNotNone(graph)
         self.assertIsNotNone(graph)
+        assert graph is not None  # For mypy
         self.assertEqual(len(graph), 1)
         self.assertEqual(len(graph), 1)
 
 
         entry = graph.entries[0]
         entry = graph.entries[0]
         self.assertEqual(entry.commit_id, commit.id)
         self.assertEqual(entry.commit_id, commit.id)
         self.assertEqual(entry.tree_id, commit.tree)
         self.assertEqual(entry.tree_id, commit.tree)
 
 
-    def test_object_store_commit_graph_methods(self):
+    def test_object_store_commit_graph_methods(self) -> None:
         """Test ObjectStore commit graph methods."""
         """Test ObjectStore commit graph methods."""
         from dulwich.object_store import DiskObjectStore
         from dulwich.object_store import DiskObjectStore
         from dulwich.objects import Commit, Tree
         from dulwich.objects import Commit, Tree
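
The assert graph is not None line added above is the standard way to narrow an Optional result: assertIsNotNone is invisible to the type checker, while a plain assert narrows the type for the lines that follow. Distilled, with a stand-in reader function:

    from typing import Optional

    def read_graph(path: str) -> Optional[list[str]]:
        """Stand-in for read_commit_graph: returns None when the file is missing."""
        return ["entry"] if path else None

    graph = read_graph("objects/info/commit-graph")
    # assertIsNotNone(graph) would not narrow Optional for mypy; a bare
    # assert does, so the len() call below type-checks without an ignore.
    assert graph is not None
    print(len(graph))
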
@@ -515,7 +516,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store = DiskObjectStore(object_store_path)
         object_store = DiskObjectStore(object_store_path)
 
 
         # Initially no commit graph
         # Initially no commit graph
-        self.assertIsNone(object_store.get_commit_graph())
+        self.assertIsNone(object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
 
         # Create a tree and commit
         # Create a tree and commit
         tree = Tree()
         tree = Tree()
@@ -534,13 +535,13 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store.write_commit_graph([commit.id], reachable=False)
         object_store.write_commit_graph([commit.id], reachable=False)
 
 
         # Now should have commit graph
         # Now should have commit graph
-        self.assertIsNotNone(object_store.get_commit_graph())
+        self.assertIsNotNone(object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
 
         # Test update (should still have commit graph)
         # Test update (should still have commit graph)
         object_store.write_commit_graph()
         object_store.write_commit_graph()
-        self.assertIsNot(None, object_store.get_commit_graph())
+        self.assertIsNot(None, object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
 
-    def test_parents_provider_commit_graph_integration(self):
+    def test_parents_provider_commit_graph_integration(self) -> None:
         """Test that ParentsProvider uses commit graph when available."""
         """Test that ParentsProvider uses commit graph when available."""
         from dulwich.object_store import DiskObjectStore
         from dulwich.object_store import DiskObjectStore
         from dulwich.objects import Commit, Tree
         from dulwich.objects import Commit, Tree
@@ -584,10 +585,10 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertIsNotNone(provider.commit_graph)
         self.assertIsNotNone(provider.commit_graph)
 
 
         # Test parent lookups
         # Test parent lookups
-        parents1 = provider.get_parents(commit1.id)
+        parents1 = provider.get_parents(commit1.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents1, [])
         self.assertEqual(parents1, [])
 
 
-        parents2 = provider.get_parents(commit2.id)
+        parents2 = provider.get_parents(commit2.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents2, [commit1.id])
         self.assertEqual(parents2, [commit1.id])
 
 
         # Test fallback behavior by creating provider without commit graph
         # Test fallback behavior by creating provider without commit graph
@@ -602,13 +603,13 @@ class CommitGraphGenerationTests(unittest.TestCase):
         self.assertIsNone(provider_no_graph.commit_graph)
         self.assertIsNone(provider_no_graph.commit_graph)
 
 
         # Should still work via commit object fallback
         # Should still work via commit object fallback
-        parents1_fallback = provider_no_graph.get_parents(commit1.id)
+        parents1_fallback = provider_no_graph.get_parents(commit1.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents1_fallback, [])
         self.assertEqual(parents1_fallback, [])
 
 
-        parents2_fallback = provider_no_graph.get_parents(commit2.id)
+        parents2_fallback = provider_no_graph.get_parents(commit2.id)  # type: ignore[no-untyped-call]
         self.assertEqual(parents2_fallback, [commit1.id])
         self.assertEqual(parents2_fallback, [commit1.id])
 
 
-    def test_graph_operations_use_commit_graph(self):
+    def test_graph_operations_use_commit_graph(self) -> None:
         """Test that graph operations use commit graph when available."""
         """Test that graph operations use commit graph when available."""
         from dulwich.graph import can_fast_forward, find_merge_base
         from dulwich.graph import can_fast_forward, find_merge_base
         from dulwich.object_store import DiskObjectStore
         from dulwich.object_store import DiskObjectStore
@@ -695,7 +696,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         repo2.object_store = object_store
         repo2.object_store = object_store
 
 
         # Verify commit graph is available
         # Verify commit graph is available
-        commit_graph = repo2.object_store.get_commit_graph()
+        commit_graph = repo2.object_store.get_commit_graph()  # type: ignore[no-untyped-call]
         self.assertIsNotNone(commit_graph)
         self.assertIsNotNone(commit_graph)
 
 
         # Test graph operations WITH commit graph
         # Test graph operations WITH commit graph
@@ -732,13 +733,14 @@ class CommitGraphGenerationTests(unittest.TestCase):
         )
         )
 
 
         # Verify parent lookups work through the provider
         # Verify parent lookups work through the provider
-        self.assertEqual(parents_provider.get_parents(commit1.id), [])
-        self.assertEqual(parents_provider.get_parents(commit2.id), [commit1.id])
+        self.assertEqual(parents_provider.get_parents(commit1.id), [])  # type: ignore[no-untyped-call]
+        self.assertEqual(parents_provider.get_parents(commit2.id), [commit1.id])  # type: ignore[no-untyped-call]
         self.assertEqual(
         self.assertEqual(
-            parents_provider.get_parents(commit5.id), [commit3.id, commit4.id]
+            parents_provider.get_parents(commit5.id),  # type: ignore[no-untyped-call]
+            [commit3.id, commit4.id],
         )
         )
 
 
-    def test_performance_with_commit_graph(self):
+    def test_performance_with_commit_graph(self) -> None:
         """Test that using commit graph provides performance benefits."""
         """Test that using commit graph provides performance benefits."""
         from dulwich.graph import find_merge_base
         from dulwich.graph import find_merge_base
         from dulwich.object_store import DiskObjectStore
         from dulwich.object_store import DiskObjectStore
@@ -754,7 +756,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         object_store.add_object(tree)
         object_store.add_object(tree)
 
 
         # Create a chain of 20 commits
         # Create a chain of 20 commits
-        commits = []
+        commits: list[Commit] = []
         for i in range(20):
         for i in range(20):
             commit = Commit()
             commit = Commit()
             commit.tree = tree.id
             commit.tree = tree.id
@@ -785,7 +787,7 @@ class CommitGraphGenerationTests(unittest.TestCase):
         repo2.object_store = object_store
         repo2.object_store = object_store
 
 
         # Verify commit graph is loaded
         # Verify commit graph is loaded
-        self.assertIsNotNone(repo2.object_store.get_commit_graph())
+        self.assertIsNotNone(repo2.object_store.get_commit_graph())  # type: ignore[no-untyped-call]
 
 
         # Time operations with commit graph
         # Time operations with commit graph
         for _ in range(10):  # Run multiple times for better measurement
         for _ in range(10):  # Run multiple times for better measurement

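The # type: ignore[no-untyped-call] comments through this file are scoped to a single mypy error code, which only fires when mypy runs with disallow_untyped_calls; scoping the ignore keeps any other diagnostic on those lines visible. A minimal sketch of the idiom:

    def legacy_helper(x):  # not annotated yet
        return x

    def caller() -> None:
        # With disallow_untyped_calls, mypy flags this call; the scoped ignore
        # suppresses only that code, and warn_unused_ignores will report the
        # comment as stale once legacy_helper gains annotations.
        value = legacy_helper(42)  # type: ignore[no-untyped-call]
        print(value)

    caller()
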
+ 46 - 35
tests/test_dumb.py

@@ -22,45 +22,53 @@
 """Tests for dumb HTTP git repositories."""
 """Tests for dumb HTTP git repositories."""
 
 
 import zlib
 import zlib
+from typing import Callable, Optional, Union
 from unittest import TestCase
 from unittest import TestCase
 from unittest.mock import Mock
 from unittest.mock import Mock
 
 
 from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteHTTPRepo
 from dulwich.dumb import DumbHTTPObjectStore, DumbRemoteHTTPRepo
 from dulwich.errors import NotGitRepository
 from dulwich.errors import NotGitRepository
-from dulwich.objects import Blob, Commit, Tag, Tree, sha_to_hex
+from dulwich.objects import Blob, Commit, ShaFile, Tag, Tree, sha_to_hex
 
 
 
 
 class MockResponse:
 class MockResponse:
-    def __init__(self, status=200, content=b"", headers=None):
+    def __init__(
+        self,
+        status: int = 200,
+        content: bytes = b"",
+        headers: Optional[dict[str, str]] = None,
+    ) -> None:
         self.status = status
         self.status = status
         self.content = content
         self.content = content
         self.headers = headers or {}
         self.headers = headers or {}
         self.closed = False
         self.closed = False
 
 
-    def close(self):
+    def close(self) -> None:
         self.closed = True
         self.closed = True
 
 
 
 
 class DumbHTTPObjectStoreTests(TestCase):
 class DumbHTTPObjectStoreTests(TestCase):
     """Tests for DumbHTTPObjectStore."""
     """Tests for DumbHTTPObjectStore."""
 
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
         self.store = DumbHTTPObjectStore(self.base_url, self._mock_http_request)
 
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         """Mock HTTP request function."""
         if url in self.responses:
         if url in self.responses:
             resp_data = self.responses[url]
             resp_data = self.responses[url]
             resp = MockResponse(
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             )
             # Create a mock read function that behaves like urllib3's read
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
             offset = [0]  # Use list to make it mutable in closure
 
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                 if offset[0] >= len(content):
                     return b""
                     return b""
                 if size is None:
                 if size is None:
@@ -76,12 +84,12 @@ class DumbHTTPObjectStoreTests(TestCase):
             resp = MockResponse(404)
             resp = MockResponse(404)
             return resp, lambda size: b""
             return resp, lambda size: b""
 
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         """Add a mock response for a given path."""
         url = self.base_url + path
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
         self.responses[url] = {"status": status, "content": content}
 
 
-    def _make_object(self, obj):
+    def _make_object(self, obj: ShaFile) -> bytes:
         """Create compressed git object data."""
         """Create compressed git object data."""
         type_name = {
         type_name = {
             Blob.type_num: b"blob",
             Blob.type_num: b"blob",
@@ -94,7 +102,7 @@ class DumbHTTPObjectStoreTests(TestCase):
         header = type_name + b" " + str(len(content)).encode() + b"\x00"
         header = type_name + b" " + str(len(content)).encode() + b"\x00"
         return zlib.compress(header + content)
         return zlib.compress(header + content)
 
 
-    def test_fetch_loose_object_blob(self):
+    def test_fetch_loose_object_blob(self) -> None:
         # Create a blob object
         # Create a blob object
         blob = Blob()
         blob = Blob()
         blob.data = b"Hello, world!"
         blob.data = b"Hello, world!"
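
_make_object above reproduces git's loose-object encoding: an ASCII header of the type name, a space, the decimal body length, and a NUL byte, prepended to the content and zlib-compressed. A standalone round-trip sketch of that format, not using dulwich:

    import zlib

    def encode_loose(type_name: bytes, content: bytes) -> bytes:
        header = type_name + b" " + str(len(content)).encode("ascii") + b"\x00"
        return zlib.compress(header + content)

    def decode_loose(data: bytes) -> tuple[bytes, bytes]:
        raw = zlib.decompress(data)
        header, content = raw.split(b"\x00", 1)  # body may contain NULs; split once
        type_name, size = header.rsplit(b" ", 1)
        assert int(size) == len(content)
        return type_name, content

    assert decode_loose(encode_loose(b"blob", b"Hello!")) == (b"blob", b"Hello!")
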
@@ -109,32 +117,33 @@ class DumbHTTPObjectStoreTests(TestCase):
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"Hello, world!", content)
         self.assertEqual(b"Hello, world!", content)
 
 
-    def test_fetch_loose_object_not_found(self):
+    def test_fetch_loose_object_not_found(self) -> None:
         hex_sha = b"1" * 40
         hex_sha = b"1" * 40
         self.assertRaises(KeyError, self.store._fetch_loose_object, hex_sha)
         self.assertRaises(KeyError, self.store._fetch_loose_object, hex_sha)
 
 
-    def test_fetch_loose_object_invalid_format(self):
+    def test_fetch_loose_object_invalid_format(self) -> None:
         sha = b"1" * 20
         sha = b"1" * 20
         hex_sha = sha_to_hex(sha)
         hex_sha = sha_to_hex(sha)
-        path = f"objects/{hex_sha[:2]}/{hex_sha[2:]}"
+        path = f"objects/{hex_sha[:2].decode('ascii')}/{hex_sha[2:].decode('ascii')}"
 
 
         # Add invalid compressed data
         # Add invalid compressed data
         self._add_response(path, b"invalid data")
         self._add_response(path, b"invalid data")
 
 
         self.assertRaises(Exception, self.store._fetch_loose_object, sha)
         self.assertRaises(Exception, self.store._fetch_loose_object, sha)
 
 
-    def test_load_packs_empty(self):
+    def test_load_packs_empty(self) -> None:
         # No packs file
         # No packs file
         self.store._load_packs()
         self.store._load_packs()
         self.assertEqual([], self.store._packs)
         self.assertEqual([], self.store._packs)
 
 
-    def test_load_packs_with_entries(self):
+    def test_load_packs_with_entries(self) -> None:
         packs_content = b"""P pack-1234567890abcdef1234567890abcdef12345678.pack
         packs_content = b"""P pack-1234567890abcdef1234567890abcdef12345678.pack
 P pack-abcdef1234567890abcdef1234567890abcdef12.pack
 P pack-abcdef1234567890abcdef1234567890abcdef12.pack
 """
 """
         self._add_response("objects/info/packs", packs_content)
         self._add_response("objects/info/packs", packs_content)
 
 
         self.store._load_packs()
         self.store._load_packs()
+        assert self.store._packs is not None
         self.assertEqual(2, len(self.store._packs))
         self.assertEqual(2, len(self.store._packs))
         self.assertEqual(
         self.assertEqual(
             "pack-1234567890abcdef1234567890abcdef12345678", self.store._packs[0][0]
             "pack-1234567890abcdef1234567890abcdef12345678", self.store._packs[0][0]
@@ -143,7 +152,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
             "pack-abcdef1234567890abcdef1234567890abcdef12", self.store._packs[1][0]
             "pack-abcdef1234567890abcdef1234567890abcdef12", self.store._packs[1][0]
         )
         )
 
 
-    def test_get_raw_from_cache(self):
+    def test_get_raw_from_cache(self) -> None:
         sha = b"1" * 40
         sha = b"1" * 40
         self.store._cached_objects[sha] = (Blob.type_num, b"cached content")
         self.store._cached_objects[sha] = (Blob.type_num, b"cached content")
 
 
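The path fix in test_fetch_loose_object_invalid_format above works around a classic pitfall: interpolating bytes into an f-string goes through repr(), leaving a literal b'..' in the middle of the path. In short:

    hex_sha = b"12" + b"0" * 38  # a 40-byte hex sha, as dulwich uses

    # bytes in an f-string are rendered via repr(), which is rarely intended:
    assert f"objects/{hex_sha[:2]}" == "objects/b'12'"

    # Decoding first yields the real path, matching the fixed test:
    path = f"objects/{hex_sha[:2].decode('ascii')}/{hex_sha[2:].decode('ascii')}"
    assert path == "objects/12/" + "0" * 38
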
@@ -151,7 +160,7 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(Blob.type_num, type_num)
         self.assertEqual(b"cached content", content)
         self.assertEqual(b"cached content", content)
 
 
-    def test_contains_loose(self):
+    def test_contains_loose(self) -> None:
         # Create a blob object
         # Create a blob object
         blob = Blob()
         blob = Blob()
         blob.data = b"Test blob"
         blob.data = b"Test blob"
@@ -164,35 +173,37 @@ P pack-abcdef1234567890abcdef1234567890abcdef12.pack
         self.assertTrue(self.store.contains_loose(hex_sha))
         self.assertTrue(self.store.contains_loose(hex_sha))
         self.assertFalse(self.store.contains_loose(b"0" * 40))
         self.assertFalse(self.store.contains_loose(b"0" * 40))
 
 
-    def test_add_object_not_implemented(self):
+    def test_add_object_not_implemented(self) -> None:
         blob = Blob()
         blob = Blob()
         blob.data = b"test"
         blob.data = b"test"
         self.assertRaises(NotImplementedError, self.store.add_object, blob)
         self.assertRaises(NotImplementedError, self.store.add_object, blob)
 
 
-    def test_add_objects_not_implemented(self):
+    def test_add_objects_not_implemented(self) -> None:
         self.assertRaises(NotImplementedError, self.store.add_objects, [])
         self.assertRaises(NotImplementedError, self.store.add_objects, [])
 
 
 
 
 class DumbRemoteHTTPRepoTests(TestCase):
     """Tests for DumbRemoteHTTPRepo."""
 
-    def setUp(self):
+    def setUp(self) -> None:
         self.base_url = "https://example.com/repo.git/"
-        self.responses = {}
+        self.responses: dict[str, dict[str, Union[int, bytes]]] = {}
         self.repo = DumbRemoteHTTPRepo(self.base_url, self._mock_http_request)
 
-    def _mock_http_request(self, url, headers):
+    def _mock_http_request(
+        self, url: str, headers: dict[str, str]
+    ) -> tuple[MockResponse, Callable[[Optional[int]], bytes]]:
         """Mock HTTP request function."""
         if url in self.responses:
             resp_data = self.responses[url]
             resp = MockResponse(
-                resp_data.get("status", 200), resp_data.get("content", b"")
+                int(resp_data.get("status", 200)), bytes(resp_data.get("content", b""))
             )
             # Create a mock read function that behaves like urllib3's read
             content = resp.content
             offset = [0]  # Use list to make it mutable in closure
 
-            def read_func(size=None):
+            def read_func(size: Optional[int] = None) -> bytes:
                 if offset[0] >= len(content):
                     return b""
                 if size is None:
@@ -208,12 +219,12 @@ class DumbRemoteHTTPRepoTests(TestCase):
             resp = MockResponse(404)
             return resp, lambda size: b""
 
-    def _add_response(self, path, content, status=200):
+    def _add_response(self, path: str, content: bytes, status: int = 200) -> None:
         """Add a mock response for a given path."""
         url = self.base_url + path
         self.responses[url] = {"status": status, "content": content}
 
-    def test_get_refs(self):
+    def test_get_refs(self) -> None:
         refs_content = b"""0123456789abcdef0123456789abcdef01234567\trefs/heads/master
 abcdef0123456789abcdef0123456789abcdef01\trefs/heads/develop
 fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
@@ -235,10 +246,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
             refs[b"refs/tags/v1.0"],
         )
 
-    def test_get_refs_not_found(self):
+    def test_get_refs_not_found(self) -> None:
         self.assertRaises(NotGitRepository, self.repo.get_refs)
 
-    def test_get_peeled(self):
+    def test_get_peeled(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
 
@@ -246,19 +257,19 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         peeled = self.repo.get_peeled(b"refs/heads/master")
         self.assertEqual(b"0123456789abcdef0123456789abcdef01234567", peeled)
 
-    def test_fetch_pack_data_no_wants(self):
+    def test_fetch_pack_data_no_wants(self) -> None:
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
 
         graph_walker = Mock()
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return []
 
         result = list(self.repo.fetch_pack_data(graph_walker, determine_wants))
         self.assertEqual([], result)
 
-    def test_fetch_pack_data_with_blob(self):
+    def test_fetch_pack_data_with_blob(self) -> None:
         # Set up refs
         refs_content = b"0123456789abcdef0123456789abcdef01234567\trefs/heads/master\n"
         self._add_response("info/refs", refs_content)
@@ -277,10 +288,10 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         graph_walker = Mock()
         graph_walker.ack.return_value = []  # No existing objects
 
-        def determine_wants(refs):
+        def determine_wants(refs: dict[bytes, bytes]) -> list[bytes]:
             return [blob_sha]
 
-        def progress(msg):
+        def progress(msg: bytes) -> None:
             assert isinstance(msg, bytes)
 
         result = list(
@@ -290,6 +301,6 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         self.assertEqual(Blob.type_num, result[0].pack_type_num)
         self.assertEqual([blob.as_raw_string()], result[0].obj_chunks)
 
-    def test_object_store_property(self):
+    def test_object_store_property(self) -> None:
         self.assertIsInstance(self.repo.object_store, DumbHTTPObjectStore)
         self.assertEqual(self.base_url, self.repo.object_store.base_url)
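The trickiest typed piece above is read_func: it must emulate urllib3's incremental read(size) while remembering its position between calls, which is why the test keeps the cursor in a one-element list (a plain local could not be rebound from inside the closure without nonlocal). A minimal standalone sketch of the same pattern, with make_reader as our own illustrative name rather than anything in dulwich:

    from typing import Callable, Optional

    def make_reader(content: bytes) -> Callable[[Optional[int]], bytes]:
        """Build a read(size) callable that serves `content` in chunks."""
        offset = [0]  # one-element list so the nested function can mutate it

        def read(size: Optional[int] = None) -> bytes:
            if offset[0] >= len(content):
                return b""  # exhausted, like a fully drained HTTP body
            if size is None:
                chunk = content[offset[0] :]
            else:
                chunk = content[offset[0] : offset[0] + size]
            offset[0] += len(chunk)
            return chunk

        return read

Repeated calls then behave like a real response object: read(4) hands back successive 4-byte chunks, read() drains the remainder, and b"" signals EOF.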

+ 5 - 5
tests/test_pack.py

@@ -1465,7 +1465,7 @@ class DeltaChainIteratorTests(TestCase):
         entries = build_pack(f, [(REF_DELTA, (blob.id, b"blob1"))], store=self.store)
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_ext_ref_chain(self) -> None:
         (blob,) = self.store_blobs([b"blob"])
@@ -1480,7 +1480,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([1, 0], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_ext_ref_chain_degenerate(self) -> None:
         # Test a degenerate case where the sender is sending a REF_DELTA
@@ -1500,7 +1500,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0, 1], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_ext_ref_multiple_times(self) -> None:
         (blob,) = self.store_blobs([b"blob"])
@@ -1515,7 +1515,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0, 1], entries, pack_iter)
-        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(blob.id)], pack_iter.ext_refs)
 
     def test_multiple_ext_refs(self) -> None:
         b1, b2 = self.store_blobs([b"foo", b"bar"])
@@ -1530,7 +1530,7 @@ class DeltaChainIteratorTests(TestCase):
         )
         pack_iter = self.make_pack_iter(f)
         self.assertEntriesMatch([0, 1], entries, pack_iter)
-        self.assertEqual([hex_to_sha(b1.id), hex_to_sha(b2.id)], pack_iter.ext_refs())
+        self.assertEqual([hex_to_sha(b1.id), hex_to_sha(b2.id)], pack_iter.ext_refs)
 
     def test_bad_ext_ref_non_thin_pack(self) -> None:
         (blob,) = self.store_blobs([b"blob"])
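Every hunk in this file records the same mechanical change: ext_refs on the delta-chain iterator, the SHAs that a thin pack refers to but does not itself contain, is now read as a property instead of called as a method, so the parentheses disappear at each call site. A sketch of the shape of that migration, using an illustrative toy class rather than dulwich's real one:

    class ThinPackIterSketch:
        """Toy stand-in showing a method-to-property migration."""

        def __init__(self) -> None:
            self._ext_refs: list[bytes] = []

        @property
        def ext_refs(self) -> list[bytes]:
            # plain attribute access for callers: pack_iter.ext_refs
            return self._ext_refs

Callers that previously wrote pack_iter.ext_refs() now write pack_iter.ext_refs; leaving the old parentheses in place would attempt to call the returned list and raise TypeError.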

+ 1 - 1
tests/test_protocol.py

@@ -36,10 +36,10 @@ from dulwich.protocol import (
     ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
-    filter_ref_prefix,
     pkt_line,
     pkt_seq,
 )
+from dulwich.refs import filter_ref_prefix
 
 from . import TestCase
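The only substantive change here is where filter_ref_prefix lives: it moves out of dulwich.protocol and into dulwich.refs, alongside the other ref-handling helpers, so downstream code needs only its import line updated. A hedged usage sketch, assuming the helper narrows a refs mapping to entries under the given name prefixes (check the dulwich.refs docstring for the exact signature):

    from dulwich.refs import filter_ref_prefix  # was: from dulwich.protocol import ...

    refs = {
        b"refs/heads/master": b"0123456789abcdef0123456789abcdef01234567",
        b"refs/tags/v1.0": b"fedcba9876543210fedcba9876543210fedcba98",
    }
    # assumed behaviour: keep only refs whose name starts with a given prefix
    heads = filter_ref_prefix(refs, [b"refs/heads/"])

Call sites themselves are unchanged; only the import path differs.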
 
 

Not all files are shown, because too many files changed in this diff