Просмотр исходного кода

Fix ruff ANN401 errors in dulwich/file.py

Jelmer Vernooij 5 месяцев назад
Родитель
Сommit
8f716bf8ee
13 измененных файлов с 171 добавлено и 75 удалено
  1. 3 1
      dulwich/annotate.py
  2. 3 1
      dulwich/commit_graph.py
  3. 3 2
      dulwich/config.py
  4. 16 6
      dulwich/errors.py
  5. 27 7
      dulwich/file.py
  6. 3 1
      dulwich/filters.py
  7. 2 2
      dulwich/hooks.py
  8. 3 1
      dulwich/index.py
  9. 2 2
      dulwich/lfs.py
  10. 4 1
      dulwich/object_store.py
  11. 89 48
      dulwich/pack.py
  12. 5 2
      dulwich/reflog.py
  13. 11 1
      dulwich/refs.py

+ 3 - 1
dulwich/annotate.py

@@ -108,5 +108,7 @@ def annotate_lines(
 
 
     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
     for commit, entry in reversed(revs):
-        lines_annotated = update_lines(lines_annotated, (commit, entry), cast("Blob", store[entry.sha]))
+        lines_annotated = update_lines(
+            lines_annotated, (commit, entry), cast("Blob", store[entry.sha])
+        )
     return lines_annotated
     return lines_annotated

+ 3 - 1
dulwich/commit_graph.py

@@ -21,6 +21,8 @@ import struct
 from collections.abc import Iterator
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, BinaryIO, Optional, Union
 from typing import TYPE_CHECKING, BinaryIO, Optional, Union
 
 
+from .file import _GitFile
+
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
     from .object_store import BaseObjectStore
 
 
@@ -269,7 +271,7 @@ class CommitGraph:
         entry = self.get_entry_by_oid(oid)
         entry = self.get_entry_by_oid(oid)
         return entry.parents if entry else None
         return entry.parents if entry else None
 
 
-    def write_to_file(self, f: BinaryIO) -> None:
+    def write_to_file(self, f: Union[BinaryIO, _GitFile]) -> None:
         """Write commit graph to file."""
         """Write commit graph to file."""
         if not self.entries:
         if not self.entries:
             raise ValueError("Cannot write empty commit graph")
             raise ValueError("Cannot write empty commit graph")

+ 3 - 2
dulwich/config.py

@@ -40,6 +40,7 @@ from collections.abc import (
 from contextlib import suppress
 from contextlib import suppress
 from pathlib import Path
 from pathlib import Path
 from typing import (
 from typing import (
+    IO,
     Any,
     Any,
     BinaryIO,
     BinaryIO,
     Callable,
     Callable,
@@ -50,7 +51,7 @@ from typing import (
     overload,
     overload,
 )
 )
 
 
-from .file import GitFile
+from .file import GitFile, _GitFile
 
 
 ConfigKey = Union[str, bytes, tuple[Union[str, bytes], ...]]
 ConfigKey = Union[str, bytes, tuple[Union[str, bytes], ...]]
 ConfigValue = Union[str, bytes, bool, int]
 ConfigValue = Union[str, bytes, bool, int]
@@ -1112,7 +1113,7 @@ class ConfigFile(ConfigDict):
         with GitFile(path_to_use, "wb") as f:
         with GitFile(path_to_use, "wb") as f:
             self.write_to_file(f)
             self.write_to_file(f)
 
 
-    def write_to_file(self, f: BinaryIO) -> None:
+    def write_to_file(self, f: Union[IO[bytes], _GitFile]) -> None:
         """Write configuration to a file-like object."""
         """Write configuration to a file-like object."""
         for section, values in self._values.items():
         for section, values in self._values.items():
             try:
             try:

+ 16 - 6
dulwich/errors.py

@@ -33,15 +33,22 @@ from typing import Optional, Union
 class ChecksumMismatch(Exception):
 class ChecksumMismatch(Exception):
     """A checksum didn't match the expected contents."""
     """A checksum didn't match the expected contents."""
 
 
-    def __init__(self, expected: Union[bytes, str], got: Union[bytes, str], extra: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        expected: Union[bytes, str],
+        got: Union[bytes, str],
+        extra: Optional[str] = None,
+    ) -> None:
         if isinstance(expected, bytes) and len(expected) == 20:
         if isinstance(expected, bytes) and len(expected) == 20:
-            expected_str = binascii.hexlify(expected).decode('ascii')
+            expected_str = binascii.hexlify(expected).decode("ascii")
         else:
         else:
-            expected_str = expected if isinstance(expected, str) else expected.decode('ascii')
+            expected_str = (
+                expected if isinstance(expected, str) else expected.decode("ascii")
+            )
         if isinstance(got, bytes) and len(got) == 20:
         if isinstance(got, bytes) and len(got) == 20:
-            got_str = binascii.hexlify(got).decode('ascii')
+            got_str = binascii.hexlify(got).decode("ascii")
         else:
         else:
-            got_str = got if isinstance(got, str) else got.decode('ascii')
+            got_str = got if isinstance(got, str) else got.decode("ascii")
         self.expected = expected_str
         self.expected = expected_str
         self.got = got_str
         self.got = got_str
         self.extra = extra
         self.extra = extra
@@ -148,7 +155,10 @@ class HangupException(GitProtocolError):
         self.stderr_lines = stderr_lines
         self.stderr_lines = stderr_lines
 
 
     def __eq__(self, other: object) -> bool:
     def __eq__(self, other: object) -> bool:
-        return isinstance(other, HangupException) and self.stderr_lines == other.stderr_lines
+        return (
+            isinstance(other, HangupException)
+            and self.stderr_lines == other.stderr_lines
+        )
 
 
 
 
 class UnexpectedCommandError(GitProtocolError):
 class UnexpectedCommandError(GitProtocolError):

+ 27 - 7
dulwich/file.py

@@ -24,7 +24,9 @@
 import os
 import os
 import sys
 import sys
 import warnings
 import warnings
-from typing import Any, ClassVar, Union
+from collections.abc import Iterator
+from types import TracebackType
+from typing import IO, Any, ClassVar, Optional, Union
 
 
 
 
 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
@@ -58,8 +60,11 @@ def _fancy_rename(oldname: Union[str, bytes], newname: Union[str, bytes]) -> Non
 
 
 
 
 def GitFile(
 def GitFile(
-    filename: Union[str, bytes, os.PathLike], mode: str = "rb", bufsize: int = -1, mask: int = 0o644
-) -> Any:
+    filename: Union[str, bytes, os.PathLike],
+    mode: str = "rb",
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> Union[IO[bytes], "_GitFile"]:
     """Create a file object that obeys the git file locking protocol.
     """Create a file object that obeys the git file locking protocol.
 
 
     Returns: a builtin file object or a _GitFile object
     Returns: a builtin file object or a _GitFile object
@@ -90,7 +95,9 @@ def GitFile(
 class FileLocked(Exception):
 class FileLocked(Exception):
     """File is already locked."""
     """File is already locked."""
 
 
-    def __init__(self, filename: Union[str, bytes, os.PathLike], lockfilename: Union[str, bytes]) -> None:
+    def __init__(
+        self, filename: Union[str, bytes, os.PathLike], lockfilename: Union[str, bytes]
+    ) -> None:
         self.filename = filename
         self.filename = filename
         self.lockfilename = lockfilename
         self.lockfilename = lockfilename
         super().__init__(filename, lockfilename)
         super().__init__(filename, lockfilename)
@@ -132,7 +139,11 @@ class _GitFile:
     }
     }
 
 
     def __init__(
     def __init__(
-        self, filename: Union[str, bytes, os.PathLike], mode: str, bufsize: int, mask: int
+        self,
+        filename: Union[str, bytes, os.PathLike],
+        mode: str,
+        bufsize: int,
+        mask: int,
     ) -> None:
     ) -> None:
         # Convert PathLike to str/bytes for our internal use
         # Convert PathLike to str/bytes for our internal use
         self._filename: Union[str, bytes] = os.fspath(filename)
         self._filename: Union[str, bytes] = os.fspath(filename)
@@ -154,6 +165,10 @@ class _GitFile:
         for method in self.PROXY_METHODS:
         for method in self.PROXY_METHODS:
             setattr(self, method, getattr(self._file, method))
             setattr(self, method, getattr(self._file, method))
 
 
+    def __iter__(self) -> Iterator[bytes]:
+        """Iterate over lines in the file."""
+        return iter(self._file)
+
     def abort(self) -> None:
     def abort(self) -> None:
         """Close and discard the lockfile without overwriting the target.
         """Close and discard the lockfile without overwriting the target.
 
 
@@ -208,13 +223,18 @@ class _GitFile:
     def __enter__(self) -> "_GitFile":
     def __enter__(self) -> "_GitFile":
         return self
         return self
 
 
-    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         if exc_type is not None:
         if exc_type is not None:
             self.abort()
             self.abort()
         else:
         else:
             self.close()
             self.close()
 
 
-    def __getattr__(self, name: str) -> Any:
+    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Proxy property calls to the underlying file."""
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)
             return getattr(self._file, name)

+ 3 - 1
dulwich/filters.py

@@ -119,7 +119,9 @@ class ProcessFilterDriver:
 class FilterRegistry:
 class FilterRegistry:
     """Registry for filter drivers."""
     """Registry for filter drivers."""
 
 
-    def __init__(self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None) -> None:
+    def __init__(
+        self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None
+    ) -> None:
         self.config = config
         self.config = config
         self.repo = repo
         self.repo = repo
         self._drivers: dict[str, FilterDriver] = {}
         self._drivers: dict[str, FilterDriver] = {}

+ 2 - 2
dulwich/hooks.py

@@ -31,7 +31,7 @@ from .errors import HookError
 class Hook:
 class Hook:
     """Generic hook object."""
     """Generic hook object."""
 
 
-    def execute(self, *args: Any) -> Any:
+    def execute(self, *args: Any) -> Any:  # noqa: ANN401
         """Execute the hook with the given args.
         """Execute the hook with the given args.
 
 
         Args:
         Args:
@@ -86,7 +86,7 @@ class ShellHook(Hook):
 
 
         self.cwd = cwd
         self.cwd = cwd
 
 
-    def execute(self, *args: Any) -> Any:
+    def execute(self, *args: Any) -> Any:  # noqa: ANN401
         """Execute the hook with given args."""
         """Execute the hook with given args."""
         if len(args) != self.numparam:
         if len(args) != self.numparam:
             raise HookError(
             raise HookError(

+ 3 - 1
dulwich/index.py

@@ -2448,7 +2448,9 @@ class locked_index:
         self._path = path
         self._path = path
 
 
     def __enter__(self) -> Index:
     def __enter__(self) -> Index:
-        self._file = GitFile(self._path, "wb")
+        f = GitFile(self._path, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
+        self._file = f
         self._index = Index(self._path)
         self._index = Index(self._path)
         return self._index
         return self._index
 
 

+ 2 - 2
dulwich/lfs.py

@@ -32,7 +32,7 @@ from urllib.request import Request, urlopen
 
 
 if TYPE_CHECKING:
 if TYPE_CHECKING:
     import urllib3
     import urllib3
-    
+
     from .config import Config
     from .config import Config
     from .repo import Repo
     from .repo import Repo
 
 
@@ -319,7 +319,7 @@ class LFSClient:
         """
         """
         self._base_url = url.rstrip("/") + "/"  # Ensure trailing slash for urljoin
         self._base_url = url.rstrip("/") + "/"  # Ensure trailing slash for urljoin
         self.config = config
         self.config = config
-        self._pool_manager = None
+        self._pool_manager: Optional[urllib3.PoolManager] = None
 
 
     @classmethod
     @classmethod
     def from_config(cls, config: "Config") -> Optional["LFSClient"]:
     def from_config(cls, config: "Config") -> Optional["LFSClient"]:

+ 4 - 1
dulwich/object_store.py

@@ -41,7 +41,7 @@ from typing import (
 )
 )
 
 
 from .errors import NotTreeError
 from .errors import NotTreeError
-from .file import GitFile
+from .file import GitFile, _GitFile
 from .objects import (
 from .objects import (
     S_ISGITLINK,
     S_ISGITLINK,
     ZERO_SHA,
     ZERO_SHA,
@@ -1505,6 +1505,9 @@ class DiskObjectStore(PackBasedObjectStore):
                 # Write using GitFile for atomic operation
                 # Write using GitFile for atomic operation
                 graph_path = os.path.join(info_dir, "commit-graph")
                 graph_path = os.path.join(info_dir, "commit-graph")
                 with GitFile(graph_path, "wb") as f:
                 with GitFile(graph_path, "wb") as f:
+                    assert isinstance(
+                        f, _GitFile
+                    )  # GitFile in write mode always returns _GitFile
                     graph.write_to_file(f)
                     graph.write_to_file(f)
 
 
             # Clear cached commit graph so it gets reloaded
             # Clear cached commit graph so it gets reloaded

+ 89 - 48
dulwich/pack.py

@@ -54,6 +54,9 @@ from itertools import chain
 from os import SEEK_CUR, SEEK_END
 from os import SEEK_CUR, SEEK_END
 from struct import unpack_from
 from struct import unpack_from
 from typing import (
 from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
     BinaryIO,
     BinaryIO,
     Callable,
     Callable,
     Generic,
     Generic,
@@ -70,13 +73,16 @@ except ImportError:
 else:
 else:
     has_mmap = True
     has_mmap = True
 
 
+if TYPE_CHECKING:
+    from .commit_graph import CommitGraph
+
 # For some reason the above try, except fails to set has_mmap = False for plan9
 # For some reason the above try, except fails to set has_mmap = False for plan9
 if sys.platform == "Plan9":
 if sys.platform == "Plan9":
     has_mmap = False
     has_mmap = False
 
 
 from . import replace_me
 from . import replace_me
 from .errors import ApplyDeltaError, ChecksumMismatch
 from .errors import ApplyDeltaError, ChecksumMismatch
-from .file import GitFile
+from .file import GitFile, _GitFile
 from .lru_cache import LRUSizeCache
 from .lru_cache import LRUSizeCache
 from .objects import ObjectID, ShaFile, hex_to_sha, object_header, sha_to_hex
 from .objects import ObjectID, ShaFile, hex_to_sha, object_header, sha_to_hex
 
 
@@ -104,7 +110,7 @@ PackHint = tuple[int, Optional[bytes]]
 class UnresolvedDeltas(Exception):
 class UnresolvedDeltas(Exception):
     """Delta objects could not be resolved."""
     """Delta objects could not be resolved."""
 
 
-    def __init__(self, shas) -> None:
+    def __init__(self, shas: list[bytes]) -> None:
         self.shas = shas
         self.shas = shas
 
 
 
 
@@ -129,7 +135,7 @@ class ObjectContainer(Protocol):
     def __getitem__(self, sha1: bytes) -> ShaFile:
     def __getitem__(self, sha1: bytes) -> ShaFile:
         """Retrieve an object."""
         """Retrieve an object."""
 
 
-    def get_commit_graph(self):
+    def get_commit_graph(self) -> Optional["CommitGraph"]:
         """Get the commit graph for this object store.
         """Get the commit graph for this object store.
 
 
         Returns:
         Returns:
@@ -186,7 +192,7 @@ def take_msb_bytes(
 
 
 
 
 class PackFileDisappeared(Exception):
 class PackFileDisappeared(Exception):
-    def __init__(self, obj) -> None:
+    def __init__(self, obj: object) -> None:
         self.obj = obj
         self.obj = obj
 
 
 
 
@@ -219,19 +225,24 @@ class UnpackedObject:
     delta_base: Union[None, bytes, int]
     delta_base: Union[None, bytes, int]
     decomp_chunks: list[bytes]
     decomp_chunks: list[bytes]
     comp_chunks: Optional[list[bytes]]
     comp_chunks: Optional[list[bytes]]
+    decomp_len: Optional[int]
+    crc32: Optional[int]
+    offset: Optional[int]
+    pack_type_num: int
+    _sha: Optional[bytes]
 
 
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # methods of this object.
     # methods of this object.
     def __init__(
     def __init__(
         self,
         self,
-        pack_type_num,
+        pack_type_num: int,
         *,
         *,
-        delta_base=None,
-        decomp_len=None,
-        crc32=None,
-        sha=None,
-        decomp_chunks=None,
-        offset=None,
+        delta_base: Union[None, bytes, int] = None,
+        decomp_len: Optional[int] = None,
+        crc32: Optional[int] = None,
+        sha: Optional[bytes] = None,
+        decomp_chunks: Optional[list[bytes]] = None,
+        offset: Optional[int] = None,
     ) -> None:
     ) -> None:
         self.offset = offset
         self.offset = offset
         self._sha = sha
         self._sha = sha
@@ -253,13 +264,13 @@ class UnpackedObject:
             self.obj_chunks = self.decomp_chunks
             self.obj_chunks = self.decomp_chunks
             self.delta_base = delta_base
             self.delta_base = delta_base
 
 
-    def sha(self):
+    def sha(self) -> bytes:
         """Return the binary SHA of this object."""
         """Return the binary SHA of this object."""
         if self._sha is None:
         if self._sha is None:
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
         return self._sha
         return self._sha
 
 
-    def sha_file(self):
+    def sha_file(self) -> ShaFile:
         """Return a ShaFile from this object."""
         """Return a ShaFile from this object."""
         assert self.obj_type_num is not None and self.obj_chunks is not None
         assert self.obj_type_num is not None and self.obj_chunks is not None
         return ShaFile.from_raw_chunks(self.obj_type_num, self.obj_chunks)
         return ShaFile.from_raw_chunks(self.obj_type_num, self.obj_chunks)
@@ -274,7 +285,7 @@ class UnpackedObject:
         else:
         else:
             return self.decomp_chunks
             return self.decomp_chunks
 
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, UnpackedObject):
         if not isinstance(other, UnpackedObject):
             return False
             return False
         for slot in self.__slots__:
         for slot in self.__slots__:
@@ -282,7 +293,7 @@ class UnpackedObject:
                 return False
                 return False
         return True
         return True
 
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not (self == other)
         return not (self == other)
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
@@ -322,7 +333,7 @@ def read_zlib_chunks(
     Raises:
     Raises:
       zlib.error: if a decompression error occurred.
       zlib.error: if a decompression error occurred.
     """
     """
-    if unpacked.decomp_len <= -1:
+    if unpacked.decomp_len is None or unpacked.decomp_len <= -1:
         raise ValueError("non-negative zlib data stream size expected")
         raise ValueError("non-negative zlib data stream size expected")
     decomp_obj = zlib.decompressobj()
     decomp_obj = zlib.decompressobj()
 
 
@@ -361,7 +372,7 @@ def read_zlib_chunks(
     return unused
     return unused
 
 
 
 
-def iter_sha1(iter):
+def iter_sha1(iter: Iterable[bytes]) -> bytes:
     """Return the hexdigest of the SHA1 over a set of names.
     """Return the hexdigest of the SHA1 over a set of names.
 
 
     Args:
     Args:
@@ -374,7 +385,7 @@ def iter_sha1(iter):
     return sha.hexdigest().encode("ascii")
     return sha.hexdigest().encode("ascii")
 
 
 
 
-def load_pack_index(path: Union[str, os.PathLike]):
+def load_pack_index(path: Union[str, os.PathLike]) -> "PackIndex":
     """Load an index file by path.
     """Load an index file by path.
 
 
     Args:
     Args:
@@ -385,7 +396,9 @@ def load_pack_index(path: Union[str, os.PathLike]):
         return load_pack_index_file(path, f)
         return load_pack_index_file(path, f)
 
 
 
 
-def _load_file_contents(f, size=None):
+def _load_file_contents(
+    f: Union[IO[bytes], _GitFile], size: Optional[int] = None
+) -> tuple[Union[bytes, Any], int]:
     try:
     try:
         fd = f.fileno()
         fd = f.fileno()
     except (UnsupportedOperation, AttributeError):
     except (UnsupportedOperation, AttributeError):
@@ -402,12 +415,14 @@ def _load_file_contents(f, size=None):
                 pass
                 pass
             else:
             else:
                 return contents, size
                 return contents, size
-    contents = f.read()
-    size = len(contents)
-    return contents, size
+    contents_bytes = f.read()
+    size = len(contents_bytes)
+    return contents_bytes, size
 
 
 
 
-def load_pack_index_file(path: Union[str, os.PathLike], f):
+def load_pack_index_file(
+    path: Union[str, os.PathLike], f: Union[IO[bytes], _GitFile]
+) -> "PackIndex":
     """Load an index file from a file-like object.
     """Load an index file from a file-like object.
 
 
     Args:
     Args:
@@ -428,7 +443,9 @@ def load_pack_index_file(path: Union[str, os.PathLike], f):
         return PackIndex1(path, file=f, contents=contents, size=size)
         return PackIndex1(path, file=f, contents=contents, size=size)
 
 
 
 
-def bisect_find_sha(start, end, sha, unpack_name):
+def bisect_find_sha(
+    start: int, end: int, sha: bytes, unpack_name: Callable[[int], bytes]
+) -> Optional[int]:
     """Find a SHA in a data blob with sorted SHAs.
     """Find a SHA in a data blob with sorted SHAs.
 
 
     Args:
     Args:
@@ -465,7 +482,7 @@ class PackIndex:
     hash_algorithm = 1
     hash_algorithm = 1
     hash_size = 20
     hash_size = 20
 
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, PackIndex):
         if not isinstance(other, PackIndex):
             return False
             return False
 
 
@@ -476,7 +493,7 @@ class PackIndex:
                 return False
                 return False
         return True
         return True
 
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
         return not self.__eq__(other)
 
 
     def __len__(self) -> int:
     def __len__(self) -> int:
@@ -495,10 +512,10 @@ class PackIndex:
         """
         """
         raise NotImplementedError(self.iterentries)
         raise NotImplementedError(self.iterentries)
 
 
-    def get_pack_checksum(self) -> bytes:
+    def get_pack_checksum(self) -> Optional[bytes]:
         """Return the SHA1 checksum stored for the corresponding packfile.
         """Return the SHA1 checksum stored for the corresponding packfile.
 
 
-        Returns: 20-byte binary digest
+        Returns: 20-byte binary digest, or None if not available
         """
         """
         raise NotImplementedError(self.get_pack_checksum)
         raise NotImplementedError(self.get_pack_checksum)
 
 
@@ -552,7 +569,11 @@ class PackIndex:
 class MemoryPackIndex(PackIndex):
 class MemoryPackIndex(PackIndex):
     """Pack index that is stored entirely in memory."""
     """Pack index that is stored entirely in memory."""
 
 
-    def __init__(self, entries, pack_checksum=None) -> None:
+    def __init__(
+        self,
+        entries: list[tuple[bytes, int, Optional[int]]],
+        pack_checksum: Optional[bytes] = None,
+    ) -> None:
         """Create a new MemoryPackIndex.
         """Create a new MemoryPackIndex.
 
 
         Args:
         Args:
@@ -567,33 +588,35 @@ class MemoryPackIndex(PackIndex):
         self._entries = entries
         self._entries = entries
         self._pack_checksum = pack_checksum
         self._pack_checksum = pack_checksum
 
 
-    def get_pack_checksum(self):
+    def get_pack_checksum(self) -> Optional[bytes]:
         return self._pack_checksum
         return self._pack_checksum
 
 
     def __len__(self) -> int:
     def __len__(self) -> int:
         return len(self._entries)
         return len(self._entries)
 
 
-    def object_offset(self, sha):
+    def object_offset(self, sha: bytes) -> int:
         if len(sha) == 40:
         if len(sha) == 40:
             sha = hex_to_sha(sha)
             sha = hex_to_sha(sha)
         return self._by_sha[sha]
         return self._by_sha[sha]
 
 
-    def object_sha1(self, offset):
+    def object_sha1(self, offset: int) -> bytes:
         return self._by_offset[offset]
         return self._by_offset[offset]
 
 
-    def _itersha(self):
+    def _itersha(self) -> Iterator[bytes]:
         return iter(self._by_sha)
         return iter(self._by_sha)
 
 
-    def iterentries(self):
+    def iterentries(self) -> Iterator[PackIndexEntry]:
         return iter(self._entries)
         return iter(self._entries)
 
 
     @classmethod
     @classmethod
-    def for_pack(cls, pack):
-        return MemoryPackIndex(pack.sorted_entries(), pack.calculate_checksum())
+    def for_pack(cls, pack: "Pack") -> "MemoryPackIndex":
+        return MemoryPackIndex(
+            list(pack.sorted_entries()), pack.data.get_stored_checksum()
+        )
 
 
     @classmethod
     @classmethod
-    def clone(cls, other_index):
-        return cls(other_index.iterentries(), other_index.get_pack_checksum())
+    def clone(cls, other_index: "PackIndex") -> "MemoryPackIndex":
+        return cls(list(other_index.iterentries()), other_index.get_pack_checksum())
 
 
 
 
 class FilePackIndex(PackIndex):
 class FilePackIndex(PackIndex):
@@ -610,7 +633,13 @@ class FilePackIndex(PackIndex):
 
 
     _fan_out_table: list[int]
     _fan_out_table: list[int]
 
 
-    def __init__(self, filename, file=None, contents=None, size=None) -> None:
+    def __init__(
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[BinaryIO] = None,
+        contents: Optional[Union[bytes, "mmap.mmap"]] = None,
+        size: Optional[int] = None,
+    ) -> None:
         """Create a pack index object.
         """Create a pack index object.
 
 
         Provide it with the name of the index file to consider, and it will map
         Provide it with the name of the index file to consider, and it will map
@@ -626,13 +655,14 @@ class FilePackIndex(PackIndex):
         if contents is None:
         if contents is None:
             self._contents, self._size = _load_file_contents(self._file, size)
             self._contents, self._size = _load_file_contents(self._file, size)
         else:
         else:
-            self._contents, self._size = (contents, size)
+            self._contents = contents
+            self._size = size if size is not None else len(contents)
 
 
     @property
     @property
     def path(self) -> str:
     def path(self) -> str:
-        return self._filename
+        return os.fspath(self._filename)
 
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         # Quick optimization:
         # Quick optimization:
         if (
         if (
             isinstance(other, FilePackIndex)
             isinstance(other, FilePackIndex)
@@ -644,8 +674,9 @@ class FilePackIndex(PackIndex):
 
 
     def close(self) -> None:
     def close(self) -> None:
         self._file.close()
         self._file.close()
-        if getattr(self._contents, "close", None) is not None:
-            self._contents.close()
+        close_fn = getattr(self._contents, "close", None)
+        if close_fn is not None:
+            close_fn()
 
 
     def __len__(self) -> int:
     def __len__(self) -> int:
         """Return the number of entries in this pack index."""
         """Return the number of entries in this pack index."""
@@ -684,7 +715,7 @@ class FilePackIndex(PackIndex):
         for i in range(len(self)):
         for i in range(len(self)):
             yield self._unpack_entry(i)
             yield self._unpack_entry(i)
 
 
-    def _read_fan_out_table(self, start_offset: int):
+    def _read_fan_out_table(self, start_offset: int) -> list[int]:
         ret = []
         ret = []
         for i in range(0x100):
         for i in range(0x100):
             fanout_entry = self._contents[
             fanout_entry = self._contents[
@@ -1604,6 +1635,9 @@ class DeltaChainIterator(Generic[T]):
             done.add(off)
             done.add(off)
             base_ofs = None
             base_ofs = None
             if unpacked.pack_type_num == OFS_DELTA:
             if unpacked.pack_type_num == OFS_DELTA:
+                assert unpacked.offset is not None
+                assert unpacked.delta_base is not None
+                assert isinstance(unpacked.delta_base, int)
                 base_ofs = unpacked.offset - unpacked.delta_base
                 base_ofs = unpacked.offset - unpacked.delta_base
             elif unpacked.pack_type_num == REF_DELTA:
             elif unpacked.pack_type_num == REF_DELTA:
                 with suppress(KeyError):
                 with suppress(KeyError):
@@ -1616,7 +1650,10 @@ class DeltaChainIterator(Generic[T]):
     def record(self, unpacked: UnpackedObject) -> None:
     def record(self, unpacked: UnpackedObject) -> None:
         type_num = unpacked.pack_type_num
         type_num = unpacked.pack_type_num
         offset = unpacked.offset
         offset = unpacked.offset
+        assert offset is not None
         if type_num == OFS_DELTA:
         if type_num == OFS_DELTA:
+            assert unpacked.delta_base is not None
+            assert isinstance(unpacked.delta_base, int)
             base_offset = offset - unpacked.delta_base
             base_offset = offset - unpacked.delta_base
             self._pending_ofs[base_offset].append(offset)
             self._pending_ofs[base_offset].append(offset)
         elif type_num == REF_DELTA:
         elif type_num == REF_DELTA:
@@ -1690,6 +1727,7 @@ class DeltaChainIterator(Generic[T]):
             unpacked = self._resolve_object(offset, obj_type_num, base_chunks)
             unpacked = self._resolve_object(offset, obj_type_num, base_chunks)
             yield self._result(unpacked)
             yield self._result(unpacked)
 
 
+            assert unpacked.offset is not None
             unblocked = chain(
             unblocked = chain(
                 self._pending_ofs.pop(unpacked.offset, []),
                 self._pending_ofs.pop(unpacked.offset, []),
                 self._pending_ref.pop(unpacked.sha(), []),
                 self._pending_ref.pop(unpacked.sha(), []),
@@ -2858,7 +2896,10 @@ class Pack:
         )
         )
         idx_stored_checksum = self.index.get_pack_checksum()
         idx_stored_checksum = self.index.get_pack_checksum()
         data_stored_checksum = self.data.get_stored_checksum()
         data_stored_checksum = self.data.get_stored_checksum()
-        if idx_stored_checksum != data_stored_checksum:
+        if (
+            idx_stored_checksum is not None
+            and idx_stored_checksum != data_stored_checksum
+        ):
             raise ChecksumMismatch(
             raise ChecksumMismatch(
                 sha_to_hex(idx_stored_checksum),
                 sha_to_hex(idx_stored_checksum),
                 sha_to_hex(data_stored_checksum),
                 sha_to_hex(data_stored_checksum),
@@ -2955,7 +2996,7 @@ class Pack:
                 yield child
                 yield child
         assert not ofs_pending
         assert not ofs_pending
         if not allow_missing and todo:
         if not allow_missing and todo:
-            raise UnresolvedDeltas(todo)
+            raise UnresolvedDeltas(list(todo))
 
 
     def iter_unpacked(self, include_comp=False):
     def iter_unpacked(self, include_comp=False):
         ofs_to_entries = {
         ofs_to_entries = {
@@ -3026,7 +3067,7 @@ class Pack:
                 base_offset, base_type, base_obj = get_ref(basename)
                 base_offset, base_type, base_obj = get_ref(basename)
                 assert isinstance(base_type, int)
                 assert isinstance(base_type, int)
                 if base_offset == prev_offset:  # object is based on itself
                 if base_offset == prev_offset:  # object is based on itself
-                    raise UnresolvedDeltas(sha_to_hex(basename))
+                    raise UnresolvedDeltas([basename])
             delta_stack.append((prev_offset, base_type, delta))
             delta_stack.append((prev_offset, base_type, delta))
 
 
         # Now grab the base object (mustn't be a delta) and apply the
         # Now grab the base object (mustn't be a delta) and apply the

+ 5 - 2
dulwich/reflog.py

@@ -23,8 +23,9 @@
 
 
 import collections
 import collections
 from collections.abc import Generator
 from collections.abc import Generator
-from typing import BinaryIO, Optional, Union
+from typing import IO, BinaryIO, Optional, Union
 
 
+from .file import _GitFile
 from .objects import ZERO_SHA, format_timezone, parse_timezone
 from .objects import ZERO_SHA, format_timezone, parse_timezone
 
 
 Entry = collections.namedtuple(
 Entry = collections.namedtuple(
@@ -89,7 +90,9 @@ def parse_reflog_line(line: bytes) -> Entry:
     )
     )
 
 
 
 
-def read_reflog(f: BinaryIO) -> Generator[Entry, None, None]:
+def read_reflog(
+    f: Union[BinaryIO, IO[bytes], _GitFile],
+) -> Generator[Entry, None, None]:
     """Read reflog.
     """Read reflog.
 
 
     Args:
     Args:

+ 11 - 1
dulwich/refs.py

@@ -876,6 +876,7 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(other)
         self._check_refname(other)
         filename = self.refpath(name)
         filename = self.refpath(name)
         f = GitFile(filename, "wb")
         f = GitFile(filename, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         try:
         try:
             f.write(SYMREF + other + b"\n")
             f.write(SYMREF + other + b"\n")
             sha = self.follow(name)[-1]
             sha = self.follow(name)[-1]
@@ -935,6 +936,9 @@ class DiskRefsContainer(RefsContainer):
 
 
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
         with GitFile(filename, "wb") as f:
         with GitFile(filename, "wb") as f:
+            assert isinstance(
+                f, _GitFile
+            )  # GitFile in write mode always returns _GitFile
             if old_ref is not None:
             if old_ref is not None:
                 try:
                 try:
                     # read again while holding the lock to handle race conditions
                     # read again while holding the lock to handle race conditions
@@ -1006,6 +1010,9 @@ class DiskRefsContainer(RefsContainer):
         filename = self.refpath(realname)
         filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
         with GitFile(filename, "wb") as f:
         with GitFile(filename, "wb") as f:
+            assert isinstance(
+                f, _GitFile
+            )  # GitFile in write mode always returns _GitFile
             if os.path.exists(filename) or name in self.get_packed_refs():
             if os.path.exists(filename) or name in self.get_packed_refs():
                 f.abort()
                 f.abort()
                 return False
                 return False
@@ -1051,6 +1058,7 @@ class DiskRefsContainer(RefsContainer):
         filename = self.refpath(name)
         filename = self.refpath(name)
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
         f = GitFile(filename, "wb")
         f = GitFile(filename, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         try:
         try:
             if old_ref is not None:
             if old_ref is not None:
                 orig_ref = self.read_loose_ref(name)
                 orig_ref = self.read_loose_ref(name)
@@ -1401,7 +1409,9 @@ class locked_ref:
 
 
         filename = self._refs_container.refpath(self._realname)
         filename = self._refs_container.refpath(self._realname)
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
-        self._file = GitFile(filename, "wb")
+        f = GitFile(filename, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
+        self._file = f
         return self
         return self
 
 
     def __exit__(
     def __exit__(