
Fix ruff ANN401 errors in dulwich/file.py

Jelmer Vernooij 5 months ago
commit 8f716bf8ee

+ 3 - 1
dulwich/annotate.py

@@ -108,5 +108,7 @@ def annotate_lines(
 
     lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
-        lines_annotated = update_lines(lines_annotated, (commit, entry), cast("Blob", store[entry.sha]))
+        lines_annotated = update_lines(
+            lines_annotated, (commit, entry), cast("Blob", store[entry.sha])
+        )
     return lines_annotated

+ 3 - 1
dulwich/commit_graph.py

@@ -21,6 +21,8 @@ import struct
 from collections.abc import Iterator
 from typing import TYPE_CHECKING, BinaryIO, Optional, Union
 
+from .file import _GitFile
+
 if TYPE_CHECKING:
     from .object_store import BaseObjectStore
 
@@ -269,7 +271,7 @@ class CommitGraph:
         entry = self.get_entry_by_oid(oid)
         return entry.parents if entry else None
 
-    def write_to_file(self, f: BinaryIO) -> None:
+    def write_to_file(self, f: Union[BinaryIO, _GitFile]) -> None:
         """Write commit graph to file."""
         if not self.entries:
             raise ValueError("Cannot write empty commit graph")

+ 3 - 2
dulwich/config.py

@@ -40,6 +40,7 @@ from collections.abc import (
 from contextlib import suppress
 from pathlib import Path
 from typing import (
+    IO,
     Any,
     BinaryIO,
     Callable,
@@ -50,7 +51,7 @@ from typing import (
     overload,
 )
 
-from .file import GitFile
+from .file import GitFile, _GitFile
 
 ConfigKey = Union[str, bytes, tuple[Union[str, bytes], ...]]
 ConfigValue = Union[str, bytes, bool, int]
@@ -1112,7 +1113,7 @@ class ConfigFile(ConfigDict):
         with GitFile(path_to_use, "wb") as f:
             self.write_to_file(f)
 
-    def write_to_file(self, f: BinaryIO) -> None:
+    def write_to_file(self, f: Union[IO[bytes], _GitFile]) -> None:
         """Write configuration to a file-like object."""
         for section, values in self._values.items():
             try:

+ 16 - 6
dulwich/errors.py

@@ -33,15 +33,22 @@ from typing import Optional, Union
 class ChecksumMismatch(Exception):
     """A checksum didn't match the expected contents."""
 
-    def __init__(self, expected: Union[bytes, str], got: Union[bytes, str], extra: Optional[str] = None) -> None:
+    def __init__(
+        self,
+        expected: Union[bytes, str],
+        got: Union[bytes, str],
+        extra: Optional[str] = None,
+    ) -> None:
         if isinstance(expected, bytes) and len(expected) == 20:
-            expected_str = binascii.hexlify(expected).decode('ascii')
+            expected_str = binascii.hexlify(expected).decode("ascii")
         else:
-            expected_str = expected if isinstance(expected, str) else expected.decode('ascii')
+            expected_str = (
+                expected if isinstance(expected, str) else expected.decode("ascii")
+            )
         if isinstance(got, bytes) and len(got) == 20:
-            got_str = binascii.hexlify(got).decode('ascii')
+            got_str = binascii.hexlify(got).decode("ascii")
         else:
-            got_str = got if isinstance(got, str) else got.decode('ascii')
+            got_str = got if isinstance(got, str) else got.decode("ascii")
         self.expected = expected_str
         self.got = got_str
         self.extra = extra
@@ -148,7 +155,10 @@ class HangupException(GitProtocolError):
         self.stderr_lines = stderr_lines
 
     def __eq__(self, other: object) -> bool:
-        return isinstance(other, HangupException) and self.stderr_lines == other.stderr_lines
+        return (
+            isinstance(other, HangupException)
+            and self.stderr_lines == other.stderr_lines
+        )
 
 
 class UnexpectedCommandError(GitProtocolError):
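
The normalization above means ChecksumMismatch always ends up storing hex strings, whatever mix of raw 20-byte digests and text it is handed. A minimal sketch of the resulting behaviour (the digest values are made up):

    import binascii

    from dulwich.errors import ChecksumMismatch

    raw = binascii.unhexlify(b"aa" * 20)  # 20-byte binary digest
    e = ChecksumMismatch(raw, "bb" * 20)  # bytes for expected, str for got
    assert e.expected == "aa" * 20        # binary digest hexlified to a str
    assert e.got == "bb" * 20             # str input passed through unchanged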

+ 27 - 7
dulwich/file.py

@@ -24,7 +24,9 @@
 import os
 import sys
 import warnings
-from typing import Any, ClassVar, Union
+from collections.abc import Iterator
+from types import TracebackType
+from typing import IO, Any, ClassVar, Optional, Union
 
 
 def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
@@ -58,8 +60,11 @@ def _fancy_rename(oldname: Union[str, bytes], newname: Union[str, bytes]) -> Non
 
 
 def GitFile(
-    filename: Union[str, bytes, os.PathLike], mode: str = "rb", bufsize: int = -1, mask: int = 0o644
-) -> Any:
+    filename: Union[str, bytes, os.PathLike],
+    mode: str = "rb",
+    bufsize: int = -1,
+    mask: int = 0o644,
+) -> Union[IO[bytes], "_GitFile"]:
     """Create a file object that obeys the git file locking protocol.
 
     Returns: a builtin file object or a _GitFile object
@@ -90,7 +95,9 @@ def GitFile(
 class FileLocked(Exception):
     """File is already locked."""
 
-    def __init__(self, filename: Union[str, bytes, os.PathLike], lockfilename: Union[str, bytes]) -> None:
+    def __init__(
+        self, filename: Union[str, bytes, os.PathLike], lockfilename: Union[str, bytes]
+    ) -> None:
         self.filename = filename
         self.lockfilename = lockfilename
         super().__init__(filename, lockfilename)
@@ -132,7 +139,11 @@ class _GitFile:
     }
 
     def __init__(
-        self, filename: Union[str, bytes, os.PathLike], mode: str, bufsize: int, mask: int
+        self,
+        filename: Union[str, bytes, os.PathLike],
+        mode: str,
+        bufsize: int,
+        mask: int,
     ) -> None:
         # Convert PathLike to str/bytes for our internal use
         self._filename: Union[str, bytes] = os.fspath(filename)
@@ -154,6 +165,10 @@ class _GitFile:
         for method in self.PROXY_METHODS:
             setattr(self, method, getattr(self._file, method))
 
+    def __iter__(self) -> Iterator[bytes]:
+        """Iterate over lines in the file."""
+        return iter(self._file)
+
     def abort(self) -> None:
         """Close and discard the lockfile without overwriting the target.
 
@@ -208,13 +223,18 @@ class _GitFile:
     def __enter__(self) -> "_GitFile":
         return self
 
-    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[TracebackType],
+    ) -> None:
         if exc_type is not None:
             self.abort()
         else:
             self.close()
 
-    def __getattr__(self, name: str) -> Any:
+    def __getattr__(self, name: str) -> Any:  # noqa: ANN401
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)
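
Since GitFile now advertises Union[IO[bytes], "_GitFile"], callers that need the lock-specific API (abort, for instance) narrow the union explicitly; the later hunks in this commit all use the same isinstance assert. A usage sketch, assuming a writable path named example.txt:

    from dulwich.file import GitFile, _GitFile

    # Write modes go through the .lock file, so a _GitFile wrapper comes back;
    # the isinstance check narrows the union for the type checker.
    with GitFile("example.txt", "wb") as f:
        assert isinstance(f, _GitFile)
        f.write(b"locked, atomic write\n")

    # Read mode skips locking and returns a plain builtin file object.
    with GitFile("example.txt", "rb") as f:
        data = f.read()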

+ 3 - 1
dulwich/filters.py

@@ -119,7 +119,9 @@ class ProcessFilterDriver:
 class FilterRegistry:
     """Registry for filter drivers."""
 
-    def __init__(self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None) -> None:
+    def __init__(
+        self, config: Optional["StackedConfig"] = None, repo: Optional["Repo"] = None
+    ) -> None:
         self.config = config
         self.repo = repo
         self._drivers: dict[str, FilterDriver] = {}

+ 2 - 2
dulwich/hooks.py

@@ -31,7 +31,7 @@ from .errors import HookError
 class Hook:
     """Generic hook object."""
 
-    def execute(self, *args: Any) -> Any:
+    def execute(self, *args: Any) -> Any:  # noqa: ANN401
         """Execute the hook with the given args.
 
         Args:
@@ -86,7 +86,7 @@ class ShellHook(Hook):
 
         self.cwd = cwd
 
-    def execute(self, *args: Any) -> Any:
+    def execute(self, *args: Any) -> Any:  # noqa: ANN401
         """Execute the hook with given args."""
         if len(args) != self.numparam:
             raise HookError(
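
ANN401 is ruff's rule against explicit typing.Any in annotations. The commit resolves it in two ways: supply the precise type where one exists, or keep Any behind a targeted noqa where the interface is genuinely dynamic, as with these variadic hook arguments and the __getattr__ proxy in file.py. A hypothetical minimal pair illustrating both outcomes (the names here are not dulwich API):

    from typing import Any

    def numparam(count: int) -> int:  # preferred fix: a concrete type, no Any
        return count

    class Proxy:
        def __init__(self, inner: object) -> None:
            self._inner = inner

        def __getattr__(self, name: str) -> Any:  # noqa: ANN401 - return type is truly dynamic
            return getattr(self._inner, name)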

+ 3 - 1
dulwich/index.py

@@ -2448,7 +2448,9 @@ class locked_index:
         self._path = path
 
     def __enter__(self) -> Index:
-        self._file = GitFile(self._path, "wb")
+        f = GitFile(self._path, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
+        self._file = f
         self._index = Index(self._path)
         return self._index
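
locked_index pairs the two objects: a _GitFile holding the index lock and an Index handed back for mutation, with changes presumably committed through the lock when the with-block exits cleanly. A usage sketch (the path is hypothetical):

    from dulwich.index import locked_index

    with locked_index("repo/.git/index") as idx:  # hypothetical path
        ...  # mutate idx; the lock is held for the whole block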
 

+ 2 - 2
dulwich/lfs.py

@@ -32,7 +32,7 @@ from urllib.request import Request, urlopen
 
 if TYPE_CHECKING:
     import urllib3
-    
+
     from .config import Config
     from .repo import Repo
 
@@ -319,7 +319,7 @@ class LFSClient:
         """
         self._base_url = url.rstrip("/") + "/"  # Ensure trailing slash for urljoin
         self.config = config
-        self._pool_manager = None
+        self._pool_manager: Optional[urllib3.PoolManager] = None
 
     @classmethod
     def from_config(cls, config: "Config") -> Optional["LFSClient"]:

+ 4 - 1
dulwich/object_store.py

@@ -41,7 +41,7 @@ from typing import (
 )
 
 from .errors import NotTreeError
-from .file import GitFile
+from .file import GitFile, _GitFile
 from .objects import (
     S_ISGITLINK,
     ZERO_SHA,
@@ -1505,6 +1505,9 @@ class DiskObjectStore(PackBasedObjectStore):
                 # Write using GitFile for atomic operation
                 graph_path = os.path.join(info_dir, "commit-graph")
                 with GitFile(graph_path, "wb") as f:
+                    assert isinstance(
+                        f, _GitFile
+                    )  # GitFile in write mode always returns _GitFile
                     graph.write_to_file(f)
 
             # Clear cached commit graph so it gets reloaded

+ 89 - 48
dulwich/pack.py

@@ -54,6 +54,9 @@ from itertools import chain
 from os import SEEK_CUR, SEEK_END
 from struct import unpack_from
 from typing import (
+    IO,
+    TYPE_CHECKING,
+    Any,
     BinaryIO,
     Callable,
     Generic,
@@ -70,13 +73,16 @@ except ImportError:
 else:
     has_mmap = True
 
+if TYPE_CHECKING:
+    from .commit_graph import CommitGraph
+
 # For some reason the above try, except fails to set has_mmap = False for plan9
 if sys.platform == "Plan9":
     has_mmap = False
 
 from . import replace_me
 from .errors import ApplyDeltaError, ChecksumMismatch
-from .file import GitFile
+from .file import GitFile, _GitFile
 from .lru_cache import LRUSizeCache
 from .objects import ObjectID, ShaFile, hex_to_sha, object_header, sha_to_hex
 
@@ -104,7 +110,7 @@ PackHint = tuple[int, Optional[bytes]]
 class UnresolvedDeltas(Exception):
     """Delta objects could not be resolved."""
 
-    def __init__(self, shas) -> None:
+    def __init__(self, shas: list[bytes]) -> None:
         self.shas = shas
 
 
@@ -129,7 +135,7 @@ class ObjectContainer(Protocol):
     def __getitem__(self, sha1: bytes) -> ShaFile:
         """Retrieve an object."""
 
-    def get_commit_graph(self):
+    def get_commit_graph(self) -> Optional["CommitGraph"]:
         """Get the commit graph for this object store.
 
         Returns:
@@ -186,7 +192,7 @@ def take_msb_bytes(
 
 
 class PackFileDisappeared(Exception):
-    def __init__(self, obj) -> None:
+    def __init__(self, obj: object) -> None:
         self.obj = obj
 
 
@@ -219,19 +225,24 @@ class UnpackedObject:
     delta_base: Union[None, bytes, int]
     decomp_chunks: list[bytes]
     comp_chunks: Optional[list[bytes]]
+    decomp_len: Optional[int]
+    crc32: Optional[int]
+    offset: Optional[int]
+    pack_type_num: int
+    _sha: Optional[bytes]
 
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # methods of this object.
     def __init__(
         self,
-        pack_type_num,
+        pack_type_num: int,
         *,
-        delta_base=None,
-        decomp_len=None,
-        crc32=None,
-        sha=None,
-        decomp_chunks=None,
-        offset=None,
+        delta_base: Union[None, bytes, int] = None,
+        decomp_len: Optional[int] = None,
+        crc32: Optional[int] = None,
+        sha: Optional[bytes] = None,
+        decomp_chunks: Optional[list[bytes]] = None,
+        offset: Optional[int] = None,
     ) -> None:
         self.offset = offset
         self._sha = sha
@@ -253,13 +264,13 @@ class UnpackedObject:
             self.obj_chunks = self.decomp_chunks
             self.delta_base = delta_base
 
-    def sha(self):
+    def sha(self) -> bytes:
         """Return the binary SHA of this object."""
         if self._sha is None:
             self._sha = obj_sha(self.obj_type_num, self.obj_chunks)
         return self._sha
 
-    def sha_file(self):
+    def sha_file(self) -> ShaFile:
         """Return a ShaFile from this object."""
         assert self.obj_type_num is not None and self.obj_chunks is not None
         return ShaFile.from_raw_chunks(self.obj_type_num, self.obj_chunks)
@@ -274,7 +285,7 @@ class UnpackedObject:
         else:
             return self.decomp_chunks
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, UnpackedObject):
             return False
         for slot in self.__slots__:
@@ -282,7 +293,7 @@ class UnpackedObject:
                 return False
         return True
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not (self == other)
 
     def __repr__(self) -> str:
@@ -322,7 +333,7 @@ def read_zlib_chunks(
     Raises:
       zlib.error: if a decompression error occurred.
     """
-    if unpacked.decomp_len <= -1:
+    if unpacked.decomp_len is None or unpacked.decomp_len <= -1:
         raise ValueError("non-negative zlib data stream size expected")
     decomp_obj = zlib.decompressobj()
 
@@ -361,7 +372,7 @@ def read_zlib_chunks(
     return unused
 
 
-def iter_sha1(iter):
+def iter_sha1(iter: Iterable[bytes]) -> bytes:
     """Return the hexdigest of the SHA1 over a set of names.
 
     Args:
@@ -374,7 +385,7 @@ def iter_sha1(iter):
     return sha.hexdigest().encode("ascii")
 
 
-def load_pack_index(path: Union[str, os.PathLike]):
+def load_pack_index(path: Union[str, os.PathLike]) -> "PackIndex":
     """Load an index file by path.
 
     Args:
@@ -385,7 +396,9 @@ def load_pack_index(path: Union[str, os.PathLike]):
         return load_pack_index_file(path, f)
 
 
-def _load_file_contents(f, size=None):
+def _load_file_contents(
+    f: Union[IO[bytes], _GitFile], size: Optional[int] = None
+) -> tuple[Union[bytes, Any], int]:
     try:
         fd = f.fileno()
     except (UnsupportedOperation, AttributeError):
@@ -402,12 +415,14 @@ def _load_file_contents(f, size=None):
                 pass
             else:
                 return contents, size
-    contents = f.read()
-    size = len(contents)
-    return contents, size
+    contents_bytes = f.read()
+    size = len(contents_bytes)
+    return contents_bytes, size
 
 
-def load_pack_index_file(path: Union[str, os.PathLike], f):
+def load_pack_index_file(
+    path: Union[str, os.PathLike], f: Union[IO[bytes], _GitFile]
+) -> "PackIndex":
     """Load an index file from a file-like object.
 
     Args:
@@ -428,7 +443,9 @@ def load_pack_index_file(path: Union[str, os.PathLike], f):
         return PackIndex1(path, file=f, contents=contents, size=size)
 
 
-def bisect_find_sha(start, end, sha, unpack_name):
+def bisect_find_sha(
+    start: int, end: int, sha: bytes, unpack_name: Callable[[int], bytes]
+) -> Optional[int]:
     """Find a SHA in a data blob with sorted SHAs.
 
     Args:
@@ -465,7 +482,7 @@ class PackIndex:
     hash_algorithm = 1
     hash_size = 20
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         if not isinstance(other, PackIndex):
             return False
 
@@ -476,7 +493,7 @@ class PackIndex:
                 return False
         return True
 
-    def __ne__(self, other):
+    def __ne__(self, other: object) -> bool:
         return not self.__eq__(other)
 
     def __len__(self) -> int:
@@ -495,10 +512,10 @@ class PackIndex:
         """
         raise NotImplementedError(self.iterentries)
 
-    def get_pack_checksum(self) -> bytes:
+    def get_pack_checksum(self) -> Optional[bytes]:
         """Return the SHA1 checksum stored for the corresponding packfile.
 
-        Returns: 20-byte binary digest
+        Returns: 20-byte binary digest, or None if not available
         """
         raise NotImplementedError(self.get_pack_checksum)
 
@@ -552,7 +569,11 @@ class PackIndex:
 class MemoryPackIndex(PackIndex):
     """Pack index that is stored entirely in memory."""
 
-    def __init__(self, entries, pack_checksum=None) -> None:
+    def __init__(
+        self,
+        entries: list[tuple[bytes, int, Optional[int]]],
+        pack_checksum: Optional[bytes] = None,
+    ) -> None:
         """Create a new MemoryPackIndex.
 
         Args:
@@ -567,33 +588,35 @@ class MemoryPackIndex(PackIndex):
         self._entries = entries
         self._pack_checksum = pack_checksum
 
-    def get_pack_checksum(self):
+    def get_pack_checksum(self) -> Optional[bytes]:
         return self._pack_checksum
 
     def __len__(self) -> int:
         return len(self._entries)
 
-    def object_offset(self, sha):
+    def object_offset(self, sha: bytes) -> int:
         if len(sha) == 40:
             sha = hex_to_sha(sha)
         return self._by_sha[sha]
 
-    def object_sha1(self, offset):
+    def object_sha1(self, offset: int) -> bytes:
         return self._by_offset[offset]
 
-    def _itersha(self):
+    def _itersha(self) -> Iterator[bytes]:
         return iter(self._by_sha)
 
-    def iterentries(self):
+    def iterentries(self) -> Iterator[PackIndexEntry]:
         return iter(self._entries)
 
     @classmethod
-    def for_pack(cls, pack):
-        return MemoryPackIndex(pack.sorted_entries(), pack.calculate_checksum())
+    def for_pack(cls, pack: "Pack") -> "MemoryPackIndex":
+        return MemoryPackIndex(
+            list(pack.sorted_entries()), pack.data.get_stored_checksum()
+        )
 
     @classmethod
-    def clone(cls, other_index):
-        return cls(other_index.iterentries(), other_index.get_pack_checksum())
+    def clone(cls, other_index: "PackIndex") -> "MemoryPackIndex":
+        return cls(list(other_index.iterentries()), other_index.get_pack_checksum())
 
 
 class FilePackIndex(PackIndex):
@@ -610,7 +633,13 @@ class FilePackIndex(PackIndex):
 
     _fan_out_table: list[int]
 
-    def __init__(self, filename, file=None, contents=None, size=None) -> None:
+    def __init__(
+        self,
+        filename: Union[str, os.PathLike],
+        file: Optional[BinaryIO] = None,
+        contents: Optional[Union[bytes, "mmap.mmap"]] = None,
+        size: Optional[int] = None,
+    ) -> None:
         """Create a pack index object.
 
         Provide it with the name of the index file to consider, and it will map
@@ -626,13 +655,14 @@ class FilePackIndex(PackIndex):
         if contents is None:
             self._contents, self._size = _load_file_contents(self._file, size)
         else:
-            self._contents, self._size = (contents, size)
+            self._contents = contents
+            self._size = size if size is not None else len(contents)
 
     @property
     def path(self) -> str:
-        return self._filename
+        return os.fspath(self._filename)
 
-    def __eq__(self, other):
+    def __eq__(self, other: object) -> bool:
         # Quick optimization:
         if (
             isinstance(other, FilePackIndex)
@@ -644,8 +674,9 @@ class FilePackIndex(PackIndex):
 
     def close(self) -> None:
         self._file.close()
-        if getattr(self._contents, "close", None) is not None:
-            self._contents.close()
+        close_fn = getattr(self._contents, "close", None)
+        if close_fn is not None:
+            close_fn()
 
     def __len__(self) -> int:
         """Return the number of entries in this pack index."""
@@ -684,7 +715,7 @@ class FilePackIndex(PackIndex):
         for i in range(len(self)):
             yield self._unpack_entry(i)
 
-    def _read_fan_out_table(self, start_offset: int):
+    def _read_fan_out_table(self, start_offset: int) -> list[int]:
         ret = []
         for i in range(0x100):
             fanout_entry = self._contents[
@@ -1604,6 +1635,9 @@ class DeltaChainIterator(Generic[T]):
             done.add(off)
             base_ofs = None
             if unpacked.pack_type_num == OFS_DELTA:
+                assert unpacked.offset is not None
+                assert unpacked.delta_base is not None
+                assert isinstance(unpacked.delta_base, int)
                 base_ofs = unpacked.offset - unpacked.delta_base
             elif unpacked.pack_type_num == REF_DELTA:
                 with suppress(KeyError):
@@ -1616,7 +1650,10 @@ class DeltaChainIterator(Generic[T]):
     def record(self, unpacked: UnpackedObject) -> None:
         type_num = unpacked.pack_type_num
         offset = unpacked.offset
+        assert offset is not None
         if type_num == OFS_DELTA:
+            assert unpacked.delta_base is not None
+            assert isinstance(unpacked.delta_base, int)
             base_offset = offset - unpacked.delta_base
             self._pending_ofs[base_offset].append(offset)
         elif type_num == REF_DELTA:
@@ -1690,6 +1727,7 @@ class DeltaChainIterator(Generic[T]):
             unpacked = self._resolve_object(offset, obj_type_num, base_chunks)
             yield self._result(unpacked)
 
+            assert unpacked.offset is not None
             unblocked = chain(
                 self._pending_ofs.pop(unpacked.offset, []),
                 self._pending_ref.pop(unpacked.sha(), []),
@@ -2858,7 +2896,10 @@ class Pack:
         )
         idx_stored_checksum = self.index.get_pack_checksum()
         data_stored_checksum = self.data.get_stored_checksum()
-        if idx_stored_checksum != data_stored_checksum:
+        if (
+            idx_stored_checksum is not None
+            and idx_stored_checksum != data_stored_checksum
+        ):
             raise ChecksumMismatch(
                 sha_to_hex(idx_stored_checksum),
                 sha_to_hex(data_stored_checksum),
@@ -2955,7 +2996,7 @@ class Pack:
                 yield child
         assert not ofs_pending
         if not allow_missing and todo:
-            raise UnresolvedDeltas(todo)
+            raise UnresolvedDeltas(list(todo))
 
     def iter_unpacked(self, include_comp=False):
         ofs_to_entries = {
@@ -3026,7 +3067,7 @@ class Pack:
                 base_offset, base_type, base_obj = get_ref(basename)
                 assert isinstance(base_type, int)
                 if base_offset == prev_offset:  # object is based on itself
-                    raise UnresolvedDeltas(sha_to_hex(basename))
+                    raise UnresolvedDeltas([basename])
             delta_stack.append((prev_offset, base_type, delta))
 
         # Now grab the base object (mustn't be a delta) and apply the
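
The newly annotated bisect_find_sha is a plain binary search over index positions, with unpack_name abstracting how the i-th SHA is fetched from the index contents. A self-contained sketch of the same idea against an in-memory list (not the dulwich internals):

    from typing import Callable, Optional

    def find_sha(start: int, end: int, sha: bytes,
                 unpack_name: Callable[[int], bytes]) -> Optional[int]:
        # The SHAs are stored in sorted order, so each probe halves the range.
        while start < end:
            i = (start + end) // 2
            probe = unpack_name(i)
            if probe < sha:
                start = i + 1
            elif probe > sha:
                end = i
            else:
                return i
        return None

    shas = sorted(bytes([n]) * 20 for n in range(5))
    assert find_sha(0, len(shas), bytes([3]) * 20, shas.__getitem__) == 3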

+ 5 - 2
dulwich/reflog.py

@@ -23,8 +23,9 @@
 
 import collections
 from collections.abc import Generator
-from typing import BinaryIO, Optional, Union
+from typing import IO, BinaryIO, Optional, Union
 
+from .file import _GitFile
 from .objects import ZERO_SHA, format_timezone, parse_timezone
 
 Entry = collections.namedtuple(
@@ -89,7 +90,9 @@ def parse_reflog_line(line: bytes) -> Entry:
     )
 
 
-def read_reflog(f: BinaryIO) -> Generator[Entry, None, None]:
+def read_reflog(
+    f: Union[BinaryIO, IO[bytes], _GitFile],
+) -> Generator[Entry, None, None]:
     """Read reflog.
 
     Args:

+ 11 - 1
dulwich/refs.py

@@ -876,6 +876,7 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(other)
         filename = self.refpath(name)
         f = GitFile(filename, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         try:
             f.write(SYMREF + other + b"\n")
             sha = self.follow(name)[-1]
@@ -935,6 +936,9 @@ class DiskRefsContainer(RefsContainer):
 
         ensure_dir_exists(os.path.dirname(filename))
         with GitFile(filename, "wb") as f:
+            assert isinstance(
+                f, _GitFile
+            )  # GitFile in write mode always returns _GitFile
             if old_ref is not None:
                 try:
                     # read again while holding the lock to handle race conditions
@@ -1006,6 +1010,9 @@ class DiskRefsContainer(RefsContainer):
         filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
         with GitFile(filename, "wb") as f:
+            assert isinstance(
+                f, _GitFile
+            )  # GitFile in write mode always returns _GitFile
             if os.path.exists(filename) or name in self.get_packed_refs():
                 f.abort()
                 return False
@@ -1051,6 +1058,7 @@ class DiskRefsContainer(RefsContainer):
         filename = self.refpath(name)
         ensure_dir_exists(os.path.dirname(filename))
         f = GitFile(filename, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         try:
             if old_ref is not None:
                 orig_ref = self.read_loose_ref(name)
@@ -1401,7 +1409,9 @@ class locked_ref:
 
         filename = self._refs_container.refpath(self._realname)
         ensure_dir_exists(os.path.dirname(filename))
-        self._file = GitFile(filename, "wb")
+        f = GitFile(filename, "wb")
+        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
+        self._file = f
         return self
 
     def __exit__(