Jelmer Vernooij 5 months ago
Parent
Commit
ced3f56a0a

+ 32 - 10
dulwich/annotate.py

@@ -27,12 +27,18 @@ Python's difflib.
 """
 
 import difflib
+from typing import TYPE_CHECKING, Optional, cast
 
 from dulwich.walk import (
     ORDER_DATE,
     Walker,
 )
 
+if TYPE_CHECKING:
+    from dulwich.diff_tree import TreeChange, TreeEntry
+    from dulwich.object_store import BaseObjectStore
+    from dulwich.objects import Blob, Commit
+
 # Walk over ancestry graph breadth-first
 # When checking each revision, find lines that according to difflib.Differ()
 # are common between versions.
@@ -41,9 +47,13 @@ from dulwich.walk import (
 # graph.
 
 
-def update_lines(annotated_lines, new_history_data, new_blob):
+def update_lines(
+    annotated_lines: list[tuple[tuple["Commit", "TreeEntry"], bytes]],
+    new_history_data: tuple["Commit", "TreeEntry"],
+    new_blob: "Blob",
+) -> list[tuple[tuple["Commit", "TreeEntry"], bytes]]:
     """Update annotation lines with old blob lines."""
-    ret = []
+    ret: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     new_lines = new_blob.splitlines()
     matcher = difflib.SequenceMatcher(
         a=[line for (h, line) in annotated_lines], b=new_lines
@@ -60,7 +70,14 @@ def update_lines(annotated_lines, new_history_data, new_blob):
     return ret
 
 
-def annotate_lines(store, commit_id, path, order=ORDER_DATE, lines=None, follow=True):
+def annotate_lines(
+    store: "BaseObjectStore",
+    commit_id: bytes,
+    path: bytes,
+    order: str = ORDER_DATE,
+    lines: Optional[list[tuple[tuple["Commit", "TreeEntry"], bytes]]] = None,
+    follow: bool = True,
+) -> list[tuple[tuple["Commit", "TreeEntry"], bytes]]:
     """Annotate the lines of a blob.
 
     :param store: Object store to retrieve objects from
@@ -75,18 +92,23 @@ def annotate_lines(store, commit_id, path, order=ORDER_DATE, lines=None, follow=
     walker = Walker(
         store, include=[commit_id], paths=[path], order=order, follow=follow
     )
-    revs = []
+    revs: list[tuple[Commit, TreeEntry]] = []
     for log_entry in walker:
         for tree_change in log_entry.changes():
-            if type(tree_change) is not list:
-                tree_change = [tree_change]
-            for change in tree_change:
+            changes: list[TreeChange]
+            if isinstance(tree_change, list):
+                changes = tree_change
+            else:
+                changes = [tree_change]
+            for change in changes:
                 if change.new.path == path:
                     path = change.old.path
                     revs.append((log_entry.commit, change.new))
                     break
 
-    lines = []
+    lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
-        lines = update_lines(lines, (commit, entry), store[entry.sha])
-    return lines
+        lines_annotated = update_lines(
+            lines_annotated, (commit, entry), cast("Blob", store[entry.sha])
+        )
+    return lines_annotated

+ 3 - 3
dulwich/attrs.py

@@ -23,7 +23,7 @@
 
 import os
 import re
-from collections.abc import Generator, Mapping
+from collections.abc import Generator, Iterator, Mapping
 from typing import (
     IO,
     Optional,
@@ -168,7 +168,7 @@ class Pattern:
         self._regex: Optional[re.Pattern[bytes]] = None
         self._compile()
 
-    def _compile(self):
+    def _compile(self) -> None:
         """Compile the pattern to a regular expression."""
         regex_pattern = _translate_pattern(self.pattern)
         # Add anchors
@@ -305,7 +305,7 @@ class GitAttributes:
         """Return the number of patterns."""
         return len(self._patterns)
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[tuple["Pattern", Mapping[bytes, AttributeValue]]]:
         """Iterate over patterns."""
         return iter(self._patterns)
 

+ 2 - 1
dulwich/bisect.py

@@ -408,7 +408,8 @@ class BisectState:
         obj = self.repo.object_store[sha]
         if isinstance(obj, Commit):
             message = obj.message.decode("utf-8", errors="replace")
-            return message.split("\n")[0]
+            lines = message.split("\n")
+            return lines[0] if lines else ""
         return ""
 
     def _append_to_log(self, line: str) -> None:

+ 3 - 1
dulwich/commit_graph.py

@@ -21,6 +21,8 @@ import struct
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, BinaryIO, Optional, Union
 
+from .file import _GitFile
+
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
 
@@ -269,7 +271,7 @@ class CommitGraph:
         entry = self.get_entry_by_oid(oid)
         return entry.parents if entry else None
 
-    def write_to_file(self, f: BinaryIO) -> None:
+    def write_to_file(self, f: Union[BinaryIO, _GitFile]) -> None:
         """Write commit graph to file."""
         if not self.entries:
             raise ValueError("Cannot write empty commit graph")

+ 3 - 2
dulwich/config.py

@@ -40,6 +40,7 @@ from collections.abc import (
 from contextlib import suppress
 from pathlib import Path
 from typing import (
+    IO,
     Any,
     BinaryIO,
     Callable,
@@ -50,7 +51,7 @@ from typing import (
     overload,
 )
 
-from .file import GitFile
+from .file import GitFile, _GitFile
 
 ConfigKey = Union[str, bytes, tuple[Union[str, bytes], ...]]
 ConfigValue = Union[str, bytes, bool, int]
@@ -1112,7 +1113,7 @@ class ConfigFile(ConfigDict):
         with GitFile(path_to_use, "wb") as f:
             self.write_to_file(f)
 
-    def write_to_file(self, f: BinaryIO) -> None:
+    def write_to_file(self, f: Union[IO[bytes], _GitFile]) -> None:
         """Write configuration to a file-like object."""
         for section, values in self._values.items():
             try:

+ 38 - 23
dulwich/errors.py

@@ -27,20 +27,32 @@
 # that raises the error.
 
 import binascii
+from typing import Optional, Union
 
 
 class ChecksumMismatch(Exception):
     """A checksum didn't match the expected contents."""
 
-    def __init__(self, expected, got, extra=None) -> None:
-        if len(expected) == 20:
-            expected = binascii.hexlify(expected)
-        if len(got) == 20:
-            got = binascii.hexlify(got)
-        self.expected = expected
-        self.got = got
+    def __init__(
+        self,
+        expected: Union[bytes, str],
+        got: Union[bytes, str],
+        extra: Optional[str] = None,
+    ) -> None:
+        if isinstance(expected, bytes) and len(expected) == 20:
+            expected_str = binascii.hexlify(expected).decode("ascii")
+        else:
+            expected_str = (
+                expected if isinstance(expected, str) else expected.decode("ascii")
+            )
+        if isinstance(got, bytes) and len(got) == 20:
+            got_str = binascii.hexlify(got).decode("ascii")
+        else:
+            got_str = got if isinstance(got, str) else got.decode("ascii")
+        self.expected = expected_str
+        self.got = got_str
         self.extra = extra
-        message = f"Checksum mismatch: Expected {expected}, got {got}"
+        message = f"Checksum mismatch: Expected {expected_str}, got {got_str}"
         if self.extra is not None:
             message += f"; {extra}"
         Exception.__init__(self, message)
@@ -57,8 +69,8 @@ class WrongObjectException(Exception):
 
     type_name: str
 
-    def __init__(self, sha, *args, **kwargs) -> None:
-        Exception.__init__(self, f"{sha} is not a {self.type_name}")
+    def __init__(self, sha: bytes, *args: object, **kwargs: object) -> None:
+        Exception.__init__(self, f"{sha.decode('ascii')} is not a {self.type_name}")
 
 
 class NotCommitError(WrongObjectException):
@@ -88,40 +100,40 @@ class NotBlobError(WrongObjectException):
 class MissingCommitError(Exception):
     """Indicates that a commit was not found in the repository."""
 
-    def __init__(self, sha, *args, **kwargs) -> None:
+    def __init__(self, sha: bytes, *args: object, **kwargs: object) -> None:
         self.sha = sha
-        Exception.__init__(self, f"{sha} is not in the revision store")
+        Exception.__init__(self, f"{sha.decode('ascii')} is not in the revision store")
 
 
 class ObjectMissing(Exception):
     """Indicates that a requested object is missing."""
 
-    def __init__(self, sha, *args, **kwargs) -> None:
-        Exception.__init__(self, f"{sha} is not in the pack")
+    def __init__(self, sha: bytes, *args: object, **kwargs: object) -> None:
+        Exception.__init__(self, f"{sha.decode('ascii')} is not in the pack")
 
 
 class ApplyDeltaError(Exception):
     """Indicates that applying a delta failed."""
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: object, **kwargs: object) -> None:
         Exception.__init__(self, *args, **kwargs)
 
 
 class NotGitRepository(Exception):
     """Indicates that no Git repository was found."""
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: object, **kwargs: object) -> None:
         Exception.__init__(self, *args, **kwargs)
 
 
 class GitProtocolError(Exception):
     """Git protocol exception."""
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: object, **kwargs: object) -> None:
         Exception.__init__(self, *args, **kwargs)
 
-    def __eq__(self, other):
-        return isinstance(self, type(other)) and self.args == other.args
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, GitProtocolError) and self.args == other.args
 
 
 class SendPackError(GitProtocolError):
@@ -131,7 +143,7 @@ class SendPackError(GitProtocolError):
 class HangupException(GitProtocolError):
     """Hangup exception."""
 
-    def __init__(self, stderr_lines=None) -> None:
+    def __init__(self, stderr_lines: Optional[list[bytes]] = None) -> None:
         if stderr_lines:
             super().__init__(
                 "\n".join(
@@ -142,14 +154,17 @@ class HangupException(GitProtocolError):
             super().__init__("The remote server unexpectedly closed the connection.")
         self.stderr_lines = stderr_lines
 
-    def __eq__(self, other):
-        return isinstance(self, type(other)) and self.stderr_lines == other.stderr_lines
+    def __eq__(self, other: object) -> bool:
+        return (
+            isinstance(other, HangupException)
+            and self.stderr_lines == other.stderr_lines
+        )
 
 
 class UnexpectedCommandError(GitProtocolError):
     """Unexpected command received in a proto line."""
 
-    def __init__(self, command) -> None:
+    def __init__(self, command: Optional[str]) -> None:
         command_str = "flush-pkt" if command is None else f"command {command}"
         super().__init__(f"Protocol got unexpected {command_str}")
 

+ 58 - 11
dulwich/file.py

@@ -24,10 +24,12 @@
 import os
 import sys
 import warnings
-from typing import ClassVar, Union
+from collections.abc import Iterator
+from types import TracebackType
+from typing import IO, Any, ClassVar, Literal, Optional, Union, overload
 
 
-def ensure_dir_exists(dirname) -> None:
+def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
     """Ensure a directory exists, creating if necessary."""
     try:
         os.makedirs(dirname)
@@ -35,7 +37,7 @@ def ensure_dir_exists(dirname) -> None:
         pass
 
 
-def _fancy_rename(oldname, newname) -> None:
+def _fancy_rename(oldname: Union[str, bytes], newname: Union[str, bytes]) -> None:
     """Rename file with temporary backup file to rollback if rename fails."""
     if not os.path.exists(newname):
         os.rename(oldname, newname)
@@ -45,7 +47,7 @@ def _fancy_rename(oldname, newname) -> None:
     import tempfile
 
     # destination file exists
-    (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname, dir=".")
+    (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=str(oldname), dir=".")
     os.close(fd)
     os.remove(tmpfile)
     os.rename(newname, tmpfile)
@@ -57,9 +59,39 @@ def _fancy_rename(oldname, newname) -> None:
     os.remove(tmpfile)
 
 
+@overload
 def GitFile(
-    filename: Union[str, bytes, os.PathLike], mode="rb", bufsize=-1, mask=0o644
-):
+    filename: Union[str, bytes, os.PathLike],
+    mode: Literal["wb"],
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> "_GitFile": ...
+
+
+@overload
+def GitFile(
+    filename: Union[str, bytes, os.PathLike],
+    mode: Literal["rb"] = "rb",
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> IO[bytes]: ...
+
+
+@overload
+def GitFile(
+    filename: Union[str, bytes, os.PathLike],
+    mode: str = "rb",
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> Union[IO[bytes], "_GitFile"]: ...
+
+
+def GitFile(
+    filename: Union[str, bytes, os.PathLike],
+    mode: str = "rb",
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> Union[IO[bytes], "_GitFile"]:
     """Create a file object that obeys the git file locking protocol.
 
     Returns: a builtin file object or a _GitFile object
@@ -90,7 +122,9 @@ def GitFile(
 class FileLocked(Exception):
     """File is already locked."""
 
-    def __init__(self, filename, lockfilename) -> None:
+    def __init__(
+        self, filename: Union[str, bytes, os.PathLike], lockfilename: Union[str, bytes]
+    ) -> None:
         self.filename = filename
         self.lockfilename = lockfilename
         super().__init__(filename, lockfilename)
@@ -132,7 +166,11 @@ class _GitFile:
     }
 
     def __init__(
-        self, filename: Union[str, bytes, os.PathLike], mode, bufsize, mask
+        self,
+        filename: Union[str, bytes, os.PathLike],
+        mode: str,
+        bufsize: int,
+        mask: int,
     ) -> None:
         # Convert PathLike to str/bytes for our internal use
         self._filename: Union[str, bytes] = os.fspath(filename)
@@ -154,6 +192,10 @@ class _GitFile:
         for method in self.PROXY_METHODS:
             setattr(self, method, getattr(self._file, method))
 
+    def __iter__(self) -> Iterator[bytes]:
+        """Iterate over lines in the file."""
+        return iter(self._file)
+
     def abort(self) -> None:
         """Close and discard the lockfile without overwriting the target.
 
@@ -205,16 +247,21 @@ class _GitFile:
             warnings.warn(f"unclosed {self!r}", ResourceWarning, stacklevel=2)
             self.abort()
 
-    def __enter__(self):
+    def __enter__(self) -> "_GitFile":
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         if exc_type is not None:
             self.abort()
         else:
             self.close()
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)

+ 5 - 2
dulwich/filters.py

@@ -30,6 +30,7 @@ from .objects import Blob
 
 if TYPE_CHECKING:
     from .config import StackedConfig
+    from .repo import Repo
 
 
 class FilterError(Exception):
@@ -118,7 +119,9 @@ class ProcessFilterDriver:
 class FilterRegistry:
     """Registry for filter drivers."""
 
-    def __init__(self, config: Optional["StackedConfig"] = None, repo=None) -> None:
+    def __init__(
+        self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None
+    ) -> None:
         self.config = config
         self.repo = repo
         self._drivers: dict[str, FilterDriver] = {}
@@ -372,7 +375,7 @@ class FilterBlobNormalizer:
         config_stack: Optional["StackedConfig"],
         gitattributes: GitAttributes,
         filter_registry: Optional[FilterRegistry] = None,
-        repo=None,
+        repo: Optional["Repo"] = None,
     ) -> None:
         self.config_stack = config_stack
         self.gitattributes = gitattributes

+ 17 - 15
dulwich/hooks.py

@@ -23,6 +23,7 @@
 
 import os
 import subprocess
+from typing import Any, Callable, Optional
 
 from .errors import HookError
 
@@ -30,7 +31,7 @@ from .errors import HookError
 class Hook:
     """Generic hook object."""
 
-    def execute(self, *args):
+    def execute(self, *args: Any) -> Any:  # noqa: ANN401
         """Execute the hook with the given args.
 
         Args:
@@ -53,12 +54,12 @@ class ShellHook(Hook):
 
     def __init__(
         self,
-        name,
-        path,
-        numparam,
-        pre_exec_callback=None,
-        post_exec_callback=None,
-        cwd=None,
+        name: str,
+        path: str,
+        numparam: int,
+        pre_exec_callback: Optional[Callable[..., Any]] = None,
+        post_exec_callback: Optional[Callable[..., Any]] = None,
+        cwd: Optional[str] = None,
     ) -> None:
         """Setup shell hook definition.
 
@@ -85,7 +86,7 @@ class ShellHook(Hook):
 
         self.cwd = cwd
 
-    def execute(self, *args):
+    def execute(self, *args: Any) -> Any:  # noqa: ANN401
         """Execute the hook with given args."""
         if len(args) != self.numparam:
             raise HookError(
@@ -113,7 +114,7 @@ class ShellHook(Hook):
 class PreCommitShellHook(ShellHook):
     """pre-commit shell hook."""
 
-    def __init__(self, cwd, controldir) -> None:
+    def __init__(self, cwd: str, controldir: str) -> None:
         filepath = os.path.join(controldir, "hooks", "pre-commit")
 
         ShellHook.__init__(self, "pre-commit", filepath, 0, cwd=cwd)
@@ -122,7 +123,7 @@ class PreCommitShellHook(ShellHook):
 class PostCommitShellHook(ShellHook):
     """post-commit shell hook."""
 
-    def __init__(self, controldir) -> None:
+    def __init__(self, controldir: str) -> None:
         filepath = os.path.join(controldir, "hooks", "post-commit")
 
         ShellHook.__init__(self, "post-commit", filepath, 0, cwd=controldir)
@@ -131,10 +132,10 @@ class PostCommitShellHook(ShellHook):
 class CommitMsgShellHook(ShellHook):
     """commit-msg shell hook."""
 
-    def __init__(self, controldir) -> None:
+    def __init__(self, controldir: str) -> None:
         filepath = os.path.join(controldir, "hooks", "commit-msg")
 
-        def prepare_msg(*args):
+        def prepare_msg(*args: bytes) -> tuple[str, ...]:
             import tempfile
 
             (fd, path) = tempfile.mkstemp()
@@ -144,13 +145,14 @@ class CommitMsgShellHook(ShellHook):
 
             return (path,)
 
-        def clean_msg(success, *args):
+        def clean_msg(success: int, *args: str) -> Optional[bytes]:
             if success:
                 with open(args[0], "rb") as f:
                     new_msg = f.read()
                 os.unlink(args[0])
                 return new_msg
             os.unlink(args[0])
+            return None
 
         ShellHook.__init__(
             self, "commit-msg", filepath, 1, prepare_msg, clean_msg, controldir
@@ -160,12 +162,12 @@ class CommitMsgShellHook(ShellHook):
 class PostReceiveShellHook(ShellHook):
     """post-receive shell hook."""
 
-    def __init__(self, controldir) -> None:
+    def __init__(self, controldir: str) -> None:
         self.controldir = controldir
         filepath = os.path.join(controldir, "hooks", "post-receive")
         ShellHook.__init__(self, "post-receive", path=filepath, numparam=0)
 
-    def execute(self, client_refs):
+    def execute(self, client_refs: list[tuple[bytes, bytes, bytes]]) -> Optional[bytes]:
         # do nothing if the script doesn't exist
         if not os.path.exists(self.filepath):
             return None

+ 3 - 1
dulwich/index.py

@@ -2448,7 +2448,9 @@ class locked_index:
         self._path = path
 
     def __enter__(self) -> Index:
-        self._file = GitFile(self._path, "wb")
+        f = GitFile(self._path, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
+        self._file = f
         self._index = Index(self._path)
         return self._index
 

+ 8 - 6
dulwich/lfs.py

@@ -31,6 +31,8 @@ from urllib.parse import urljoin, urlparse
 from urllib.request import Request, urlopen
 
 if TYPE_CHECKING:
+    import urllib3
+
     from .config import Config
     from .repo import Repo
 
@@ -289,7 +291,7 @@ class LFSFilterDriver:
         return content
 
 
-def _get_lfs_user_agent(config):
+def _get_lfs_user_agent(config: Optional["Config"]) -> str:
     """Get User-Agent string for LFS requests, respecting git config."""
     try:
         if config:
@@ -317,7 +319,7 @@ class LFSClient:
         """
         self._base_url = url.rstrip("/") + "/"  # Ensure trailing slash for urljoin
         self.config = config
-        self._pool_manager = None
+        self._pool_manager: Optional[urllib3.PoolManager] = None
 
     @classmethod
     def from_config(cls, config: "Config") -> Optional["LFSClient"]:
@@ -365,12 +367,12 @@ class LFSClient:
         """Get the LFS server URL without trailing slash."""
         return self._base_url.rstrip("/")
 
-    def _get_pool_manager(self):
+    def _get_pool_manager(self) -> "urllib3.PoolManager":
         """Get urllib3 pool manager with git config applied."""
         if self._pool_manager is None:
             from dulwich.client import default_urllib3_manager
 
-            self._pool_manager = default_urllib3_manager(self.config)
+            self._pool_manager = default_urllib3_manager(self.config)  # type: ignore[assignment]
         return self._pool_manager
 
     def _make_request(
@@ -397,7 +399,7 @@ class LFSClient:
             raise ValueError(
                 f"HTTP {response.status}: {response.data.decode('utf-8', errors='ignore')}"
             )
-        return response.data
+        return response.data  # type: ignore[return-value]
 
     def batch(
         self,
@@ -513,7 +515,7 @@ class LFSClient:
         if actual_oid != oid:
             raise LFSError(f"Downloaded OID {actual_oid} != expected {oid}")
 
-        return content
+        return content  # type: ignore[return-value]
 
     def upload(
         self, oid: str, size: int, content: bytes, ref: Optional[str] = None

+ 4 - 1
dulwich/object_store.py

@@ -41,7 +41,7 @@ from typing import (
 )
 
 from .errors import NotTreeError
-from .file import GitFile
+from .file import GitFile, _GitFile
 from .objects import (
     S_ISGITLINK,
     ZERO_SHA,
@@ -1505,6 +1505,9 @@ class DiskObjectStore(PackBasedObjectStore):
                 # Write using GitFile for atomic operation
                 graph_path = os.path.join(info_dir, "commit-graph")
                 with GitFile(graph_path, "wb") as f:
+                    assert isinstance(
+                        f, _GitFile
+                    )  # GitFile in write mode always returns _GitFile
                     graph.write_to_file(f)
 
             # Clear cached commit graph so it gets reloaded

+ 15 - 8
dulwich/objects.py

@@ -45,7 +45,7 @@ else:
     from typing_extensions import Self
 
 if sys.version_info >= (3, 10):
-    from typing import TypeGuard  # type: ignore
+    from typing import TypeGuard
 else:
     from typing_extensions import TypeGuard
 
@@ -244,6 +244,13 @@ def check_identity(identity: Optional[bytes], error_msg: str) -> None:
         raise ObjectFormatException(error_msg)
 
 
+def _path_to_bytes(path: Union[str, bytes]) -> bytes:
+    """Convert a path to bytes for use in error messages."""
+    if isinstance(path, str):
+        return path.encode("utf-8", "surrogateescape")
+    return path
+
+
 def check_time(time_seconds: int) -> None:
     """Check if the specified time is not prone to overflow error.
 
@@ -270,7 +277,7 @@ class FixedSha:
 
     def __init__(self, hexsha: Union[str, bytes]) -> None:
         if isinstance(hexsha, str):
-            hexsha = hexsha.encode("ascii")  # type: ignore
+            hexsha = hexsha.encode("ascii")
         if not isinstance(hexsha, bytes):
             raise TypeError(f"Expected bytes for hexsha, got {hexsha!r}")
         self._hexsha = hexsha
@@ -432,7 +439,7 @@ class ShaFile:
         if sha is None:
             self._sha = None
         else:
-            self._sha = FixedSha(sha)  # type: ignore
+            self._sha = FixedSha(sha)
         self._needs_serialization = False
 
     @staticmethod
@@ -686,7 +693,7 @@ class Blob(ShaFile):
     def from_path(cls, path: Union[str, bytes]) -> "Blob":
         blob = ShaFile.from_path(path)
         if not isinstance(blob, cls):
-            raise NotBlobError(path)
+            raise NotBlobError(_path_to_bytes(path))
         return blob
 
     def check(self) -> None:
@@ -706,7 +713,7 @@ class Blob(ShaFile):
         if not chunks:
             return []
         if len(chunks) == 1:
-            return chunks[0].splitlines(True)
+            return chunks[0].splitlines(True)  # type: ignore[no-any-return]
         remaining = None
         ret = []
         for chunk in chunks:
@@ -834,7 +841,7 @@ class Tag(ShaFile):
     def from_path(cls, filename: Union[str, bytes]) -> "Tag":
         tag = ShaFile.from_path(filename)
         if not isinstance(tag, cls):
-            raise NotTagError(filename)
+            raise NotTagError(_path_to_bytes(filename))
         return tag
 
     def check(self) -> None:
@@ -1186,7 +1193,7 @@ class Tree(ShaFile):
     def from_path(cls, filename: Union[str, bytes]) -> "Tree":
         tree = ShaFile.from_path(filename)
         if not isinstance(tree, cls):
-            raise NotTreeError(filename)
+            raise NotTreeError(_path_to_bytes(filename))
         return tree
 
     def __contains__(self, name: bytes) -> bool:
@@ -1536,7 +1543,7 @@ class Commit(ShaFile):
     def from_path(cls, path: Union[str, bytes]) -> "Commit":
         commit = ShaFile.from_path(path)
         if not isinstance(commit, cls):
-            raise NotCommitError(path)
+            raise NotCommitError(_path_to_bytes(path))
         return commit
 
     def _deserialize(self, chunks: list[bytes]) -> None:

+ 89 - 48
dulwich/pack.py

@@ -54,6 +54,9 @@ from itertools import chain
 from os import SEEK_CUR, SEEK_END
 from struct import unpack_from
 from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
     BinaryIO,
     Callable,
     Generic,
@@ -70,13 +73,16 @@ except ImportError:
 else:
     has_mmap = True
 
+if TYPE_CHECKING:
+    from .commit_graph import CommitGraph
+
 # For some reason the above try, except fails to set has_mmap = False for plan9
 if sys.platform == "Plan9":
     has_mmap = False
 
 from . import replace_me
 from .errors import ApplyDeltaError, ChecksumMismatch
-from .file import GitFile
+from .file import GitFile, _GitFile
 from .lru_cache import LRUSizeCache
 from .objects import ObjectID, ShaFile, hex_to_sha, object_header, sha_to_hex
 
@@ -104,7 +110,7 @@ PackHint = tuple[int, Optional[bytes]]
 class UnresolvedDeltas(Exception):
     """Delta objects could not be resolved."""
 
-    def __init__(self, shas) -> None:
+    def __init__(self, shas: list[bytes]) -> None:
         self.shas = shas
 
 
@@ -129,7 +135,7 @@ class ObjectContainer(Protocol):
     def __getitem__(self, sha1: bytes) -> ShaFile:
         """Retrieve an object."""
 
-    def get_commit_graph(self):
+    def get_commit_graph(self) -> Optional["CommitGraph"]:
         """Get the commit graph for this object store.
 
         Returns:
@@ -186,7 +192,7 @@ def take_msb_bytes(
 
 
 class PackFileDisappeared(Exception):
-    def __init__(self, obj) -> None:
+    def __init__(self, obj: object) -> None:
         self.obj = obj
 
 
@@ -219,19 +225,24 @@ class UnpackedObject:
     delta_base: Union[None, bytes, int]
     decomp_chunks: list[bytes]
     comp_chunks: Optional[list[bytes]]
+    decomp_len: Optional[int]
+    crc32: Optional[int]
+    offset: Optional[int]
+    pack_type_num: int
+    _sha: Optional[bytes]
 
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # methods of this object.
     def __init__(
         self,
-        pack_type_num,
+        pack_type_num: int,
         *,
-        delta_base=None,
-        decomp_len=None,
-        crc32=None,
-        sha=None,
-        decomp_chunks=None,
-        offset=None,
+        delta_base: Union[None, bytes, int] = None,
+        decomp_len: Optional[int] = None,
+        crc32: Optional[int] = None,
+        sha: Optional[bytes] = None,
+        decomp_chunks: Optional[list[bytes]] = None,
+        offset: Optional[int] = None,
     ) -> None:
         self.offset = offset
         self._sha = sha
@@ -253,13 +264,13 @@ class UnpackedObject:
             self.obj_chunks = self.decomp_chunks
             self.delta_base = delta_base
 
-    def sha(self):
+    def sha(self) -> bytes:
         """Return the binary SHA of this object."""
         if self._sha is None:
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
         return self._sha
 
-    def sha_file(self):
+    def sha_file(self) -> ShaFile:
         """Return a ShaFile from this object."""
         assert self.obj_type_num is not None and self.obj_chunks is not None
         return ShaFile.from_raw_chunks(self.obj_type_num, self.obj_chunks)
@@ -274,7 +285,7 @@ class UnpackedObject:
         else:
             return self.decomp_chunks
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, UnpackedObject):
             return False
         for slot in self.__slots__:
@@ -282,7 +293,7 @@ class UnpackedObject:
                 return False
         return True
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not (self == other)
 
     def __repr__(self) -> str:
@@ -322,7 +333,7 @@ def read_zlib_chunks(
     Raises:
       zlib.error: if a decompression error occurred.
     """
-    if unpacked.decomp_len <= -1:
+    if unpacked.decomp_len is None or unpacked.decomp_len <= -1:
         raise ValueError("non-negative zlib data stream size expected")
     decomp_obj = zlib.decompressobj()
 
@@ -361,7 +372,7 @@ def read_zlib_chunks(
     return unused
 
 
-def iter_sha1(iter):
+def iter_sha1(iter: Iterable[bytes]) -> bytes:
     """Return the hexdigest of the SHA1 over a set of names.
 
     Args:
@@ -374,7 +385,7 @@ def iter_sha1(iter):
     return sha.hexdigest().encode("ascii")
 
 
-def load_pack_index(path: Union[str, os.PathLike]):
+def load_pack_index(path: Union[str, os.PathLike]) -> "PackIndex":
     """Load an index file by path.
 
     Args:
@@ -385,7 +396,9 @@ def load_pack_index(path: Union[str, os.PathLike]):
         return load_pack_index_file(path, f)
 
 
-def _load_file_contents(f, size=None):
+def _load_file_contents(
+    f: Union[IO[bytes], _GitFile], size: Optional[int] = None
+) -> tuple[Union[bytes, Any], int]:
     try:
         fd = f.fileno()
     except (UnsupportedOperation, AttributeError):
@@ -402,12 +415,14 @@ def _load_file_contents(f, size=None):
                 pass
             else:
                 return contents, size
-    contents = f.read()
-    size = len(contents)
-    return contents, size
+    contents_bytes = f.read()
+    size = len(contents_bytes)
+    return contents_bytes, size
 
 
-def load_pack_index_file(path: Union[str, os.PathLike], f):
+def load_pack_index_file(
+    path: Union[str, os.PathLike], f: Union[IO[bytes], _GitFile]
+) -> "PackIndex":
     """Load an index file from a file-like object.
 
     Args:
@@ -428,7 +443,9 @@ def load_pack_index_file(path: Union[str, os.PathLike], f):
         return PackIndex1(path, file=f, contents=contents, size=size)
 
 
-def bisect_find_sha(start, end, sha, unpack_name):
+def bisect_find_sha(
+    start: int, end: int, sha: bytes, unpack_name: Callable[[int], bytes]
+) -> Optional[int]:
     """Find a SHA in a data blob with sorted SHAs.
 
     Args:
@@ -465,7 +482,7 @@ class PackIndex:
     hash_algorithm = 1
     hash_size = 20
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, PackIndex):
             return False
 
@@ -476,7 +493,7 @@ class PackIndex:
                 return False
         return True
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
 
     def __len__(self) -> int:
@@ -495,10 +512,10 @@ class PackIndex:
         """
         raise NotImplementedError(self.iterentries)
 
-    def get_pack_checksum(self) -> bytes:
+    def get_pack_checksum(self) -> Optional[bytes]:
         """Return the SHA1 checksum stored for the corresponding packfile.
 
-        Returns: 20-byte binary digest
+        Returns: 20-byte binary digest, or None if not available
         """
         raise NotImplementedError(self.get_pack_checksum)
 
@@ -552,7 +569,11 @@ class PackIndex:
 class MemoryPackIndex(PackIndex):
     """Pack index that is stored entirely in memory."""
 
-    def __init__(self, entries, pack_checksum=None) -> None:
+    def __init__(
+        self,
+        entries: list[tuple[bytes, int, Optional[int]]],
+        pack_checksum: Optional[bytes] = None,
+    ) -> None:
         """Create a new MemoryPackIndex.
 
         Args:
@@ -567,33 +588,35 @@ class MemoryPackIndex(PackIndex):
         self._entries = entries
         self._pack_checksum = pack_checksum
 
-    def get_pack_checksum(self):
+    def get_pack_checksum(self) -> Optional[bytes]:
         return self._pack_checksum
 
     def __len__(self) -> int:
         return len(self._entries)
 
-    def object_offset(self, sha):
+    def object_offset(self, sha: bytes) -> int:
         if len(sha) == 40:
             sha = hex_to_sha(sha)
         return self._by_sha[sha]
 
-    def object_sha1(self, offset):
+    def object_sha1(self, offset: int) -> bytes:
         return self._by_offset[offset]
 
-    def _itersha(self):
+    def _itersha(self) -> Iterator[bytes]:
         return iter(self._by_sha)
 
-    def iterentries(self):
+    def iterentries(self) -> Iterator[PackIndexEntry]:
         return iter(self._entries)
 
     @classmethod
-    def for_pack(cls, pack):
-        return MemoryPackIndex(pack.sorted_entries(), pack.calculate_checksum())
+    def for_pack(cls, pack_data: "PackData") -> "MemoryPackIndex":
+        return MemoryPackIndex(
+            list(pack_data.sorted_entries()), pack_data.get_stored_checksum()
+        )
 
     @classmethod
-    def clone(cls, other_index):
-        return cls(other_index.iterentries(), other_index.get_pack_checksum())
+    def clone(cls, other_index: "PackIndex") -> "MemoryPackIndex":
+        return cls(list(other_index.iterentries()), other_index.get_pack_checksum())
 
 
 class FilePackIndex(PackIndex):
@@ -610,7 +633,13 @@ class FilePackIndex(PackIndex):
 
     _fan_out_table: list[int]
 
-    def __init__(self, filename, file=None, contents=None, size=None) -> None:
+    def __init__(
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[BinaryIO] = None,
+        contents: Optional[Union[bytes, "mmap.mmap"]] = None,
+        size: Optional[int] = None,
+    ) -> None:
         """Create a pack index object.
 
         Provide it with the name of the index file to consider, and it will map
@@ -626,13 +655,14 @@ class FilePackIndex(PackIndex):
         if contents is None:
             self._contents, self._size = _load_file_contents(self._file, size)
         else:
-            self._contents, self._size = (contents, size)
+            self._contents = contents
+            self._size = size if size is not None else len(contents)
 
     @property
     def path(self) -> str:
-        return self._filename
+        return os.fspath(self._filename)
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         # Quick optimization:
         if (
             isinstance(other, FilePackIndex)
@@ -644,8 +674,9 @@ class FilePackIndex(PackIndex):
 
     def close(self) -> None:
         self._file.close()
-        if getattr(self._contents, "close", None) is not None:
-            self._contents.close()
+        close_fn = getattr(self._contents, "close", None)
+        if close_fn is not None:
+            close_fn()
 
     def __len__(self) -> int:
         """Return the number of entries in this pack index."""
@@ -684,7 +715,7 @@ class FilePackIndex(PackIndex):
         for i in range(len(self)):
             yield self._unpack_entry(i)
 
-    def _read_fan_out_table(self, start_offset: int):
+    def _read_fan_out_table(self, start_offset: int) -> list[int]:
         ret = []
         for i in range(0x100):
             fanout_entry = self._contents[
@@ -1604,6 +1635,9 @@ class DeltaChainIterator(Generic[T]):
             done.add(off)
             base_ofs = None
             if unpacked.pack_type_num == OFS_DELTA:
+                assert unpacked.offset is not None
+                assert unpacked.delta_base is not None
+                assert isinstance(unpacked.delta_base, int)
                 base_ofs = unpacked.offset - unpacked.delta_base
             elif unpacked.pack_type_num == REF_DELTA:
                 with suppress(KeyError):
@@ -1616,7 +1650,10 @@ class DeltaChainIterator(Generic[T]):
     def record(self, unpacked: UnpackedObject) -> None:
         type_num = unpacked.pack_type_num
         offset = unpacked.offset
+        assert offset is not None
         if type_num == OFS_DELTA:
+            assert unpacked.delta_base is not None
+            assert isinstance(unpacked.delta_base, int)
             base_offset = offset - unpacked.delta_base
             self._pending_ofs[base_offset].append(offset)
         elif type_num == REF_DELTA:
@@ -1690,6 +1727,7 @@ class DeltaChainIterator(Generic[T]):
             unpacked = self._resolve_object(offset, obj_type_num, base_chunks)
             yield self._result(unpacked)
 
+            assert unpacked.offset is not None
             unblocked = chain(
                 self._pending_ofs.pop(unpacked.offset, []),
                 self._pending_ref.pop(unpacked.sha(), []),
@@ -2858,7 +2896,10 @@ class Pack:
         )
         idx_stored_checksum = self.index.get_pack_checksum()
         data_stored_checksum = self.data.get_stored_checksum()
-        if idx_stored_checksum != data_stored_checksum:
+        if (
+            idx_stored_checksum is not None
+            and idx_stored_checksum != data_stored_checksum
+        ):
             raise ChecksumMismatch(
                 sha_to_hex(idx_stored_checksum),
                 sha_to_hex(data_stored_checksum),
@@ -2955,7 +2996,7 @@ class Pack:
                 yield child
         assert not ofs_pending
         if not allow_missing and todo:
-            raise UnresolvedDeltas(todo)
+            raise UnresolvedDeltas(list(todo))
 
     def iter_unpacked(self, include_comp=False):
         ofs_to_entries = {
@@ -3026,7 +3067,7 @@ class Pack:
                 base_offset, base_type, base_obj = get_ref(basename)
                 assert isinstance(base_type, int)
                 if base_offset == prev_offset:  # object is based on itself
-                    raise UnresolvedDeltas(sha_to_hex(basename))
+                    raise UnresolvedDeltas([basename])
             delta_stack.append((prev_offset, base_type, delta))
 
         # Now grab the base object (mustn't be a delta) and apply the

+ 2 - 1
dulwich/patch.py

@@ -118,7 +118,8 @@ def get_summary(commit: "Commit") -> str:
     Returns: Summary string
     """
     decoded = commit.message.decode(errors="replace")
-    return decoded.splitlines()[0].replace(" ", "-")
+    lines = decoded.splitlines()
+    return lines[0].replace(" ", "-") if lines else ""
 
 
 #  Unified Diff

+ 1 - 1
dulwich/rebase.py

@@ -458,7 +458,7 @@ class Rebaser:
 
     def is_in_progress(self) -> bool:
         """Check if a rebase is currently in progress."""
-        return self._state_manager.exists()
+        return bool(self._state_manager.exists())
 
     def abort(self) -> None:
         """Abort an in-progress rebase and restore original state."""

+ 5 - 2
dulwich/reflog.py

@@ -23,8 +23,9 @@
 
 import collections
 from collections.abc import Generator
-from typing import BinaryIO, Optional, Union
+from typing import IO, BinaryIO, Optional, Union
 
+from .file import _GitFile
 from .objects import ZERO_SHA, format_timezone, parse_timezone
 
 Entry = collections.namedtuple(
@@ -89,7 +90,9 @@ def parse_reflog_line(line: bytes) -> Entry:
     )
 
 
-def read_reflog(f: BinaryIO) -> Generator[Entry, None, None]:
+def read_reflog(
+    f: Union[BinaryIO, IO[bytes], _GitFile],
+) -> Generator[Entry, None, None]:
     """Read reflog.
 
     Args:

+ 2 - 1
dulwich/refs.py

@@ -1401,7 +1401,8 @@ class locked_ref:
 
         filename = self._refs_container.refpath(self._realname)
         ensure_dir_exists(os.path.dirname(filename))
-        self._file = GitFile(filename, "wb")
+        f = GitFile(filename, "wb")
+        self._file = f
         return self
 
     def __exit__(

+ 1 - 1
dulwich/sparse_patterns.py

@@ -168,7 +168,7 @@ def apply_included_paths(
         norm_data = normalizer.checkin_normalize(disk_data, full_path)
         if not isinstance(blob_obj, Blob):
             return True
-        return norm_data != blob_obj.data
+        return bool(norm_data != blob_obj.data)
 
     # 1) Update skip-worktree bits
 

+ 1 - 1
dulwich/stash.py

@@ -299,7 +299,7 @@ class Stash:
         # TODO(jelmer): Just pass parents into do_commit()?
         self._repo.refs[self._ref] = self._repo.head()
 
-        cid = self._repo.get_worktree().commit(
+        cid: ObjectID = self._repo.get_worktree().commit(
             ref=self._ref,
             tree=stash_tree_id,
             message=message,

+ 1 - 1
tests/test_pack.py

@@ -1594,7 +1594,7 @@ class DeltaChainIteratorTests(TestCase):
             # Attempting to open this REF_DELTA object would loop forever
             pack[b1.id]
         except UnresolvedDeltas as e:
-            self.assertEqual((b1.id), e.shas)
+            self.assertEqual([hex_to_sha(b1.id)], e.shas)
 
 
 class DeltaEncodeSizeTests(TestCase):