
Add extensive type annotations to dulwich package

Jelmer Vernooij, 5 months ago
Parent commit: 0c63eec777
6 changed files, 331 additions and 474 deletions
  1. dulwich/cli.py (+31 -31)
  2. dulwich/client.py (+35 -72)
  3. dulwich/object_store.py (+61 -76)
  4. dulwich/pack.py (+86 -95)
  5. dulwich/refs.py (+92 -174)
  6. dulwich/repo.py (+26 -26)
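
The pattern is uniform across all six files: signatures gain parameter and return annotations while function bodies stay untouched. A minimal sketch of the before/after shape (the function below is illustrative, not taken from dulwich):

from typing import Optional

# Before: the contract lives only in the docstring.
def format_size(size, suffix=None):
    """Format a byte count using an optional unit suffix."""
    unit = suffix if suffix is not None else "B"
    return f"{size} {unit}"

# After: the same body, with the contract stated in the signature.
# Optional[str] marks a parameter that defaults to None.
def format_size_annotated(size: int, suffix: Optional[str] = None) -> str:
    """Format a byte count using an optional unit suffix."""
    unit = suffix if suffix is not None else "B"
    return f"{size} {unit}"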

+ 31 - 31
dulwich/cli.py

@@ -55,7 +55,7 @@ class CommitMessageError(Exception):
     """Raised when there's an issue with the commit message."""


-def signal_int(signal, frame) -> None:
+def signal_int(signal: int, frame) -> None:
     """Handle interrupt signal by exiting.

     Args:
@@ -65,7 +65,7 @@ def signal_int(signal, frame) -> None:
     sys.exit(1)


-def signal_quit(signal, frame) -> None:
+def signal_quit(signal: int, frame) -> None:
     """Handle quit signal by entering debugger.

     Args:
@@ -77,7 +77,7 @@ def signal_quit(signal, frame) -> None:
     pdb.set_trace()


-def parse_relative_time(time_str):
+def parse_relative_time(time_str: str) -> int:
     """Parse a relative time string like '2 weeks ago' into seconds.

     Args:
@@ -126,7 +126,7 @@ def parse_relative_time(time_str):
         raise


-def format_bytes(bytes):
+def format_bytes(bytes: int) -> str:
     """Format bytes as human-readable string.

     Args:
@@ -142,7 +142,7 @@ def format_bytes(bytes):
     return f"{bytes:.1f} TB"


-def launch_editor(template_content=b""):
+def launch_editor(template_content: bytes = b"") -> bytes:
     """Launch an editor for the user to enter text.

     Args:
@@ -176,7 +176,7 @@ def launch_editor(template_content=b""):
 class PagerBuffer:
     """Binary buffer wrapper for Pager to mimic sys.stdout.buffer."""

-    def __init__(self, pager):
+    def __init__(self, pager: "Pager") -> None:
         """Initialize PagerBuffer.

         Args:
@@ -184,40 +184,40 @@ class PagerBuffer:
         """
         self.pager = pager

-    def write(self, data: bytes):
+    def write(self, data: bytes) -> int:
         """Write bytes to pager."""
         if isinstance(data, bytes):
             text = data.decode("utf-8", errors="replace")
             return self.pager.write(text)
         return self.pager.write(data)

-    def flush(self):
+    def flush(self) -> None:
         """Flush the pager."""
         return self.pager.flush()

-    def writelines(self, lines):
+    def writelines(self, lines) -> None:
         """Write multiple lines to pager."""
         for line in lines:
             self.write(line)

-    def readable(self):
+    def readable(self) -> bool:
         """Return whether the buffer is readable (it's not)."""
         return False

-    def writable(self):
+    def writable(self) -> bool:
         """Return whether the buffer is writable."""
         return not self.pager._closed

-    def seekable(self):
+    def seekable(self) -> bool:
         """Return whether the buffer is seekable (it's not)."""
         return False

-    def close(self):
+    def close(self) -> None:
         """Close the pager."""
         return self.pager.close()

     @property
-    def closed(self):
+    def closed(self) -> bool:
         """Return whether the buffer is closed."""
         return self.pager.closed

@@ -225,7 +225,7 @@ class PagerBuffer:
 class Pager:
     """File-like object that pages output through external pager programs."""

-    def __init__(self, pager_cmd="cat"):
+    def __init__(self, pager_cmd: str = "cat") -> None:
         """Initialize Pager.

         Args:
@@ -241,7 +241,7 @@ class Pager:
         """Get the pager command to use."""
         return self.pager_cmd

-    def _ensure_pager_started(self):
+    def _ensure_pager_started(self) -> None:
         """Start the pager process if not already started."""
         if self.pager_process is None and not self._closed:
             try:
@@ -280,7 +280,7 @@ class Pager:
             # No pager available, write directly to stdout
             return sys.stdout.write(text)

-    def flush(self):
+    def flush(self) -> None:
         """Flush the pager."""
         if self._closed or self._pager_died:
             return
@@ -293,7 +293,7 @@ class Pager:
         else:
             sys.stdout.flush()

-    def close(self):
+    def close(self) -> None:
         """Close the pager."""
         if self._closed:
             return
@@ -308,16 +308,16 @@ class Pager:
                 pass
             self.pager_process = None

-    def __enter__(self):
+    def __enter__(self) -> "Pager":
         """Context manager entry."""
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         """Context manager exit."""
         self.close()

     # Additional file-like methods for compatibility
-    def writelines(self, lines):
+    def writelines(self, lines) -> None:
         """Write a list of lines to the pager."""
         if self._pager_died:
             return
@@ -325,19 +325,19 @@ class Pager:
             self.write(line)

     @property
-    def closed(self):
+    def closed(self) -> bool:
         """Return whether the pager is closed."""
         return self._closed

-    def readable(self):
+    def readable(self) -> bool:
         """Return whether the pager is readable (it's not)."""
         return False

-    def writable(self):
+    def writable(self) -> bool:
         """Return whether the pager is writable."""
         return not self._closed

-    def seekable(self):
+    def seekable(self) -> bool:
         """Return whether the pager is seekable (it's not)."""
         return False

@@ -345,7 +345,7 @@ class Pager:
 class _StreamContextAdapter:
     """Adapter to make streams work with context manager protocol."""

-    def __init__(self, stream):
+    def __init__(self, stream) -> None:
         self.stream = stream
         # Expose buffer if it exists
         if hasattr(stream, "buffer"):
@@ -356,15 +356,15 @@ class _StreamContextAdapter:
     def __enter__(self):
         return self.stream

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
         # For stdout/stderr, we don't close them
         pass

-    def __getattr__(self, name):
+    def __getattr__(self, name: str):
         return getattr(self.stream, name)


-def get_pager(config=None, cmd_name=None):
+def get_pager(config=None, cmd_name: Optional[str] = None):
     """Get a pager instance if paging should be used, otherwise return sys.stdout.

     Args:
@@ -447,12 +447,12 @@ def get_pager(config=None, cmd_name=None):
     return Pager(pager_cmd)


-def disable_pager():
+def disable_pager() -> None:
     """Disable pager for this session."""
     get_pager._disabled = True


-def enable_pager():
+def enable_pager() -> None:
     """Enable pager for this session."""
     get_pager._disabled = False
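
One detail worth flagging in cli.py: disable_pager() and enable_pager() keep their state as an attribute on the get_pager function object itself (get_pager._disabled). Plain function attributes are opaque to type checkers, so annotating these two functions does not make the flag itself type-safe. A sketch of one common alternative, a small callable object whose flag is a declared attribute (names here are illustrative, not dulwich API):

from typing import Optional

class _PagerGate:
    """Callable pager lookup whose disabled flag is a typed attribute."""

    def __init__(self) -> None:
        self.disabled: bool = False

    def __call__(self, cmd_name: Optional[str] = None) -> str:
        # Stand-in for the real lookup: fall back to plain output when disabled.
        return "cat" if self.disabled else "less"

get_pager_sketch = _PagerGate()
get_pager_sketch.disabled = True  # visible to the type checker, unlike get_pager._disabled
print(get_pager_sketch())         # -> "cat"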
 
 

+ 35 - 72
dulwich/client.py

@@ -53,6 +53,7 @@ from io import BufferedReader, BytesIO
 from typing import (
     IO,
     TYPE_CHECKING,
+    Any,
     Callable,
     ClassVar,
     Optional,
@@ -149,7 +150,7 @@ logger = logging.getLogger(__name__)
 class InvalidWants(Exception):
     """Invalid wants."""

-    def __init__(self, wants) -> None:
+    def __init__(self, wants: Any) -> None:
         """Initialize InvalidWants exception.

         Args:
@@ -163,7 +164,7 @@ class InvalidWants(Exception):
 class HTTPUnauthorized(Exception):
     """Raised when authentication fails."""

-    def __init__(self, www_authenticate, url) -> None:
+    def __init__(self, www_authenticate: Any, url: str) -> None:
         """Initialize HTTPUnauthorized exception.

         Args:
@@ -178,7 +179,7 @@ class HTTPUnauthorized(Exception):
 class HTTPProxyUnauthorized(Exception):
     """Raised when proxy authentication fails."""

-    def __init__(self, proxy_authenticate, url) -> None:
+    def __init__(self, proxy_authenticate: Any, url: str) -> None:
         """Initialize HTTPProxyUnauthorized exception.

         Args:
@@ -190,12 +191,12 @@ class HTTPProxyUnauthorized(Exception):
         self.url = url


-def _fileno_can_read(fileno):
+def _fileno_can_read(fileno: int) -> bool:
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0


-def _win32_peek_avail(handle):
+def _win32_peek_avail(handle: Any) -> int:
     """Wrapper around PeekNamedPipe to check how many bytes are available."""
     from ctypes import byref, windll, wintypes

@@ -233,7 +234,7 @@ class ReportStatusParser:
         self._pack_status = None
         self._ref_statuses: list[bytes] = []

-    def check(self):
+    def check(self) -> Any:
         """Check if there were any errors and, if so, raise exceptions.

         Raises:
@@ -257,7 +258,7 @@ class ReportStatusParser:
             else:
                 raise GitProtocolError(f"invalid ref status {status!r}")

-    def handle_packet(self, pkt) -> None:
+    def handle_packet(self, pkt: Optional[bytes]) -> None:
         """Handle a packet.

         Raises:
@@ -276,13 +277,7 @@ class ReportStatusParser:
             self._ref_statuses.append(ref_status)


-def negotiate_protocol_version(proto) -> int:
-    """Negotiate the protocol version to use.
-
-    Args:
-      proto: Protocol instance to negotiate with
-    Returns: Protocol version (0, 1, or 2)
-    """
+def negotiate_protocol_version(proto: Any) -> int:
     pkt = proto.read_pkt_line()
     if pkt is not None and pkt.strip() == b"version 2":
         return 2
@@ -290,13 +285,7 @@ def negotiate_protocol_version(proto) -> int:
     return 0


-def read_server_capabilities(pkt_seq):
-    """Read server capabilities from a packet sequence.
-
-    Args:
-      pkt_seq: Sequence of packets from server
-    Returns: Set of server capabilities
-    """
+def read_server_capabilities(pkt_seq: Any) -> set:
     server_capabilities = []
     for pkt in pkt_seq:
         server_capabilities.append(pkt)
@@ -304,7 +293,7 @@


 def read_pkt_refs_v2(
-    pkt_seq,
+    pkt_seq: Any,
 ) -> tuple[dict[bytes, bytes], dict[bytes, bytes], dict[bytes, bytes]]:
     """Read packet references in protocol v2 format.

@@ -334,13 +323,7 @@ def read_pkt_refs_v2(
     return refs, symrefs, peeled


-def read_pkt_refs_v1(pkt_seq) -> tuple[dict[bytes, bytes], set[bytes]]:
-    """Read packet references in protocol v1 format.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (refs dict, server capabilities set)
-    """
+def read_pkt_refs_v1(pkt_seq: Any) -> tuple[dict[bytes, bytes], set[bytes]]:
     server_capabilities = None
     refs = {}
     # Receive refs from server
@@ -389,11 +372,11 @@ class _DeprecatedDictProxy:
             stacklevel=3,
         )

-    def __contains__(self, name) -> bool:
+    def __contains__(self, name: bytes) -> bool:
         self._warn_deprecated()
         return name in self.refs

-    def __getitem__(self, name):
+    def __getitem__(self, name: bytes) -> bytes:
         self._warn_deprecated()
         return self.refs[name]

@@ -401,11 +384,11 @@ class _DeprecatedDictProxy:
         self._warn_deprecated()
         return len(self.refs)

-    def __iter__(self):
+    def __iter__(self) -> Any:
         self._warn_deprecated()
         return iter(self.refs)

-    def __getattribute__(self, name):
+    def __getattribute__(self, name: str) -> Any:
         # Avoid infinite recursion by checking against class variable directly
         if name != "_FORWARDED_ATTRS" and name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
@@ -425,7 +408,7 @@ class FetchPackResult(_DeprecatedDictProxy):
     """

     def __init__(
-        self, refs, symrefs, agent, new_shallow=None, new_unshallow=None
+        self, refs: dict, symrefs: dict, agent: Optional[bytes], new_shallow: Optional[Any] = None, new_unshallow: Optional[Any] = None
     ) -> None:
         """Initialize FetchPackResult.

@@ -442,8 +425,7 @@ class FetchPackResult(_DeprecatedDictProxy):
         self.new_shallow = new_shallow
         self.new_unshallow = new_unshallow

-    def __eq__(self, other):
-        """Check equality with another FetchPackResult."""
+    def __eq__(self, other: Any) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
@@ -466,7 +448,7 @@ class LsRemoteResult(_DeprecatedDictProxy):
       symrefs: Dictionary with remote symrefs
     """

-    def __init__(self, refs, symrefs) -> None:
+    def __init__(self, refs: dict, symrefs: dict) -> None:
         """Initialize LsRemoteResult.

         Args:
@@ -486,8 +468,7 @@ class LsRemoteResult(_DeprecatedDictProxy):
             stacklevel=3,
         )

-    def __eq__(self, other):
-        """Check equality with another LsRemoteResult."""
+    def __eq__(self, other: Any) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
@@ -508,7 +489,7 @@ class SendPackResult(_DeprecatedDictProxy):
         failed to update), or None if it was updated successfully
     """

-    def __init__(self, refs, agent=None, ref_status=None) -> None:
+    def __init__(self, refs: dict, agent: Optional[bytes] = None, ref_status: Optional[dict] = None) -> None:
         """Initialize SendPackResult.

         Args:
@@ -520,8 +501,7 @@ class SendPackResult(_DeprecatedDictProxy):
         self.agent = agent
         self.ref_status = ref_status

-    def __eq__(self, other):
-        """Check equality with another SendPackResult."""
+    def __eq__(self, other: Any) -> bool:
         if isinstance(other, dict):
             self._warn_deprecated()
             return self.refs == other
@@ -532,13 +512,7 @@ class SendPackResult(_DeprecatedDictProxy):
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"


-def _read_shallow_updates(pkt_seq):
-    """Read shallow/unshallow updates from a packet sequence.
-
-    Args:
-      pkt_seq: Sequence of packets
-    Returns: Tuple of (new_shallow set, new_unshallow set)
-    """
+def _read_shallow_updates(pkt_seq: Any) -> tuple[set, set]:
     new_shallow = set()
     new_unshallow = set()
     for pkt in pkt_seq:
@@ -558,19 +532,16 @@


 class _v1ReceivePackHeader:
-    """Handler for v1 receive-pack header."""
-
-    def __init__(self, capabilities, old_refs, new_refs) -> None:
+    def __init__(self, capabilities: list, old_refs: dict, new_refs: dict) -> None:
         self.want: list[bytes] = []
         self.have: list[bytes] = []
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False

-    def __iter__(self):
-        """Iterate over the receive-pack header lines."""
+    def __iter__(self) -> Any:
         return self._it

-    def _handle_receive_pack_head(self, capabilities, old_refs, new_refs):
+    def _handle_receive_pack_head(self, capabilities: list, old_refs: dict, new_refs: dict) -> Any:
         """Handle the head of a 'git-receive-pack' request.

         Args:
@@ -632,15 +603,7 @@ def _read_side_band64k_data(pkt_seq: Iterable[bytes]) -> Iterator[tuple[int, bytes]]:
         yield channel, pkt[1:]


-def find_capability(capabilities, key, value):
-    """Find a capability in the list of capabilities.
-
-    Args:
-      capabilities: List of capabilities
-      key: Capability key to search for
-      value: Optional specific value to match
-    Returns: The matching capability or None
-    """
+def find_capability(capabilities: list, key: bytes, value: Optional[bytes]) -> Optional[bytes]:
     for capability in capabilities:
         k, v = parse_capability(capability)
         if k != key:
@@ -651,14 +614,14 @@


 def _handle_upload_pack_head(
-    proto,
-    capabilities,
-    graph_walker,
-    wants,
-    can_read,
+    proto: Any,
+    capabilities: list,
+    graph_walker: Any,
+    wants: list,
+    can_read: Callable,
     depth: Optional[int],
-    protocol_version,
-):
+    protocol_version: Optional[int],
+) -> None:
     """Handle the head of a 'git-upload-pack' request.

     Args:
@@ -773,7 +736,7 @@ def _handle_upload_pack_tail(
         if progress is None:
             # Just ignore progress data

-            def progress(x) -> None:
+            def progress(x: bytes) -> None:
                 pass

         for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
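
Most of the new client.py annotations are deliberately loose: proto, pkt_seq, and graph_walker all land as Any in this pass, which type-checks but verifies nothing. A later pass could narrow them structurally; the sketch below shows what that might look like for a pkt-line sequence, assuming only that packets are optional byte strings (the Protocol and function here are illustrative, not the dulwich implementation):

from collections.abc import Iterator
from typing import Optional, Protocol

class PktSource(Protocol):
    """Anything iterable over pkt-lines; None stands for a flush-pkt."""

    def __iter__(self) -> Iterator[Optional[bytes]]: ...

def collect_capabilities(pkt_seq: PktSource) -> set[bytes]:
    # Simplified reading loop: gather whitespace-separated capability tokens.
    caps: set[bytes] = set()
    for pkt in pkt_seq:
        if pkt is not None:
            caps.update(pkt.split())
    return caps

print(collect_capabilities([b"multi_ack thin-pack", None]))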

+ 61 - 76
dulwich/object_store.py

@@ -95,7 +95,7 @@ PACK_MODE = 0o444 if sys.platform != "win32" else 0o644
 DEFAULT_TEMPFILE_GRACE_PERIOD = 14 * 24 * 60 * 60  # 2 weeks


-def find_shallow(store, heads, depth):
+def find_shallow(store: 'BaseObjectStore', heads: Any, depth: int) -> tuple:
     """Find shallow commits according to a given depth.

     Args:
@@ -110,7 +110,7 @@ def find_shallow(store, heads, depth):
     parents = {}
     commit_graph = store.get_commit_graph()

-    def get_parents(sha):
+    def get_parents(sha: bytes) -> list[bytes]:
         result = parents.get(sha, None)
         if not result:
             # Try to use commit graph first if available
@@ -150,11 +150,11 @@ def find_shallow(store, heads, depth):


 def get_depth(
-    store,
-    head,
-    get_parents=lambda commit: commit.parents,
-    max_depth=None,
-):
+    store: 'BaseObjectStore',
+    head: bytes,
+    get_parents: Callable = lambda commit: commit.parents,
+    max_depth: Optional[int] = None,
+) -> int:
     """Return the current available depth for the given head.

     For commits with multiple parents, the largest possible depth will be
@@ -206,17 +206,7 @@ class BaseObjectStore:
     def determine_wants_all(
         self, refs: dict[Ref, ObjectID], depth: Optional[int] = None
     ) -> list[ObjectID]:
-        """Determine all objects that are wanted by the client.
-
-        Args:
-          refs: Dictionary mapping ref names to object IDs
-          depth: Shallow fetch depth (None for full fetch)
-
-        Returns:
-          List of object IDs that are wanted
-        """
-
-        def _want_deepen(sha):
+        def _want_deepen(sha: bytes) -> bool:
             if not depth:
                 return False
             if depth == DEPTH_INFINITE:
@@ -231,7 +221,7 @@ class BaseObjectStore:
             and not sha == ZERO_SHA
         ]

-    def contains_loose(self, sha) -> bool:
+    def contains_loose(self, sha: bytes) -> bool:
         """Check if a particular object is present by SHA1 and is loose."""
         raise NotImplementedError(self.contains_loose)

@@ -243,11 +233,11 @@ class BaseObjectStore:
         return self.contains_loose(sha1)

     @property
-    def packs(self):
+    def packs(self) -> Any:
         """Iterable of pack objects."""
         raise NotImplementedError

-    def get_raw(self, name) -> tuple[int, bytes]:
+    def get_raw(self, name: bytes) -> tuple[int, bytes]:
         """Obtain the raw text for an object.

         Args:
@@ -261,15 +251,15 @@ class BaseObjectStore:
         type_num, uncomp = self.get_raw(sha1)
         return ShaFile.from_raw_string(type_num, uncomp, sha=sha1)

-    def __iter__(self):
+    def __iter__(self) -> Any:
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)

-    def add_object(self, obj) -> None:
+    def add_object(self, obj: Any) -> None:
         """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)

-    def add_objects(self, objects, progress=None) -> None:
+    def add_objects(self, objects: Any, progress: Optional[Callable] = None) -> None:
         """Add a set of objects to this object store.

         Args:
@@ -280,14 +270,14 @@ class BaseObjectStore:

     def tree_changes(
         self,
-        source,
-        target,
-        want_unchanged=False,
-        include_trees=False,
-        change_type_same=False,
-        rename_detector=None,
-        paths=None,
-    ):
+        source: Optional[bytes],
+        target: Optional[bytes],
+        want_unchanged: bool = False,
+        include_trees: bool = False,
+        change_type_same: bool = False,
+        rename_detector: Optional[Any] = None,
+        paths: Optional[Any] = None,
+    ) -> Any:
         """Find the differences between the contents of two trees.

         Args:
@@ -320,7 +310,7 @@ class BaseObjectStore:
                 (change.old.sha, change.new.sha),
             )

-    def iter_tree_contents(self, tree_id, include_trees=False):
+    def iter_tree_contents(self, tree_id: bytes, include_trees: bool = False) -> Any:
         """Iterate the contents of a tree and all subtrees.

         Iteration is depth-first pre-order, as in e.g. os.walk.
@@ -362,13 +352,13 @@ class BaseObjectStore:

     def find_missing_objects(
         self,
-        haves,
-        wants,
-        shallow=None,
-        progress=None,
-        get_tagged=None,
-        get_parents=lambda commit: commit.parents,
-    ):
+        haves: Any,
+        wants: Any,
+        shallow: Optional[Any] = None,
+        progress: Optional[Callable] = None,
+        get_tagged: Optional[Callable] = None,
+        get_parents: Callable = lambda commit: commit.parents,
+    ) -> Any:
         """Find the missing objects required for a set of revisions.

         Args:
@@ -395,7 +385,7 @@ class BaseObjectStore:
         )
         return iter(finder)

-    def find_common_revisions(self, graphwalker):
+    def find_common_revisions(self, graphwalker: Any) -> list[bytes]:
         """Find which revisions this store has in common using graphwalker.

         Args:
@@ -412,7 +402,7 @@ class BaseObjectStore:
         return haves

     def generate_pack_data(
-        self, have, want, shallow=None, progress=None, ofs_delta=True
+        self, have: Any, want: Any, shallow: Optional[Any] = None, progress: Optional[Callable] = None, ofs_delta: bool = True
     ) -> tuple[int, Iterator[UnpackedObject]]:
         """Generate pack data objects for a set of wants/haves.

@@ -435,7 +425,7 @@ class BaseObjectStore:
             progress=progress,
         )

-    def peel_sha(self, sha):
+    def peel_sha(self, sha: bytes) -> bytes:
         """Peel all tags from a SHA.

         Args:
@@ -453,10 +443,10 @@ class BaseObjectStore:

     def _get_depth(
         self,
-        head,
-        get_parents=lambda commit: commit.parents,
-        max_depth=None,
-    ):
+        head: bytes,
+        get_parents: Callable = lambda commit: commit.parents,
+        max_depth: Optional[int] = None,
+    ) -> int:
         """Return the current available depth for the given head.

         For commits with multiple parents, the largest possible depth will be
@@ -496,7 +486,7 @@ class BaseObjectStore:
             if sha.startswith(prefix):
                 yield sha

-    def get_commit_graph(self):
+    def get_commit_graph(self) -> Optional[Any]:
         """Get the commit graph for this object store.

         Returns:
@@ -504,7 +494,7 @@ class BaseObjectStore:
         """
         return None

-    def write_commit_graph(self, refs=None, reachable=True) -> None:
+    def write_commit_graph(self, refs: Optional[Any] = None, reachable: bool = True) -> None:
         """Write a commit graph file for this object store.

         Args:
@@ -518,7 +508,7 @@ class BaseObjectStore:
         """
         raise NotImplementedError(self.write_commit_graph)

-    def get_object_mtime(self, sha):
+    def get_object_mtime(self, sha: bytes) -> float:
         """Get the modification time of an object.

         Args:
@@ -545,14 +535,14 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):

     def __init__(
         self,
-        pack_compression_level=-1,
-        pack_index_version=None,
-        pack_delta_window_size=None,
-        pack_window_memory=None,
-        pack_delta_cache_size=None,
-        pack_depth=None,
-        pack_threads=None,
-        pack_big_file_threshold=None,
+        pack_compression_level: int = -1,
+        pack_index_version: Optional[int] = None,
+        pack_delta_window_size: Optional[int] = None,
+        pack_window_memory: Optional[int] = None,
+        pack_delta_cache_size: Optional[int] = None,
+        pack_depth: Optional[int] = None,
+        pack_threads: Optional[int] = None,
+        pack_big_file_threshold: Optional[int] = None,
     ) -> None:
         """Initialize a PackBasedObjectStore.

@@ -581,7 +571,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         raise NotImplementedError(self.add_pack)

     def add_pack_data(
-        self, count: int, unpacked_objects: Iterator[UnpackedObject], progress=None
+        self, count: int, unpacked_objects: Iterator[UnpackedObject], progress: Optional[Callable] = None
     ) -> None:
         """Add pack data to this object store.

@@ -609,15 +599,10 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             return commit()

     @property
-    def alternates(self):
-        """Get the list of alternate object stores.
-
-        Returns:
-          List of alternate BaseObjectStore instances
-        """
+    def alternates(self) -> list:
         return []

-    def contains_packed(self, sha) -> bool:
+    def contains_packed(self, sha: bytes) -> bool:
         """Check if a particular object is present by SHA1 and is packed.

         This does not check alternates.
@@ -642,7 +627,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 return True
         return False

-    def _add_cached_pack(self, base_name, pack) -> None:
+    def _add_cached_pack(self, base_name: str, pack: Any) -> None:
         """Add a newly appeared pack to the cache by path."""
         prev_pack = self._pack_cache.get(base_name)
         if prev_pack is not pack:
@@ -682,7 +667,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             (name, pack) = pack_cache.popitem()
             pack.close()

-    def _iter_cached_packs(self):
+    def _iter_cached_packs(self) -> Any:
         return self._pack_cache.values()

     def _update_pack_cache(self) -> list[Pack]:
@@ -696,7 +681,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         self._clear_cached_packs()

     @property
-    def packs(self):
+    def packs(self) -> Any:
         """List with pack objects."""
         return list(self._iter_cached_packs()) + list(self._update_pack_cache())

@@ -714,19 +699,19 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
                 count += 1
         return count

-    def _iter_alternate_objects(self):
+    def _iter_alternate_objects(self) -> Any:
         """Iterate over the SHAs of all the objects in alternate stores."""
         for alternate in self.alternates:
             yield from alternate

-    def _iter_loose_objects(self):
+    def _iter_loose_objects(self) -> Any:
         """Iterate over the SHAs of all loose objects."""
         raise NotImplementedError(self._iter_loose_objects)

-    def _get_loose_object(self, sha) -> Optional[ShaFile]:
+    def _get_loose_object(self, sha: bytes) -> Optional[ShaFile]:
         raise NotImplementedError(self._get_loose_object)

-    def delete_loose_object(self, sha) -> None:
+    def delete_loose_object(self, sha: bytes) -> None:
         """Delete a loose object.

         This method only handles loose objects. For packed objects,
@@ -734,10 +719,10 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
         """
         raise NotImplementedError(self.delete_loose_object)

-    def _remove_pack(self, name) -> None:
+    def _remove_pack(self, name: str) -> None:
         raise NotImplementedError(self._remove_pack)

-    def pack_loose_objects(self):
+    def pack_loose_objects(self) -> int:
         """Pack loose objects.

         Returns: Number of objects packed
@@ -750,7 +735,7 @@ class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
             self.delete_loose_object(obj.id)
         return len(objects)

-    def repack(self, exclude=None):
+    def repack(self, exclude: Optional[set] = None) -> None:
         """Repack the packs in this repository.

         Note that this implementation is fairly naive and currently keeps all
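
A recurring spelling in object_store.py is get_parents: Callable = lambda commit: commit.parents. A bare Callable accepts any call signature, so the checker cannot catch a wrong callback. A sketch of the tighter spelling, assuming the callback maps a commit-like object to its parent SHAs (the commit class below is a hypothetical stand-in, not a dulwich type):

from typing import Callable

class FakeCommit:
    """Hypothetical stand-in for a commit object with a .parents attribute."""

    def __init__(self, parents: list[bytes]) -> None:
        self.parents = parents

# Pins down both the argument and the return type, unlike a bare Callable.
GetParents = Callable[[FakeCommit], list[bytes]]

def count_parents(commit: FakeCommit,
                  get_parents: GetParents = lambda c: c.parents) -> int:
    return len(get_parents(commit))

print(count_parents(FakeCommit([b"ab" * 20, b"cd" * 20])))  # -> 2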

+ 86 - 95
dulwich/pack.py

@@ -1343,7 +1343,7 @@ class PackStreamCopier(PackStreamReader):
     appropriate and written out to the given file-like object.
     """

-    def __init__(self, read_all, read_some, outfile, delta_iter=None) -> None:
+    def __init__(self, read_all: Callable, read_some: Callable, outfile: Any, delta_iter: Optional[Any] = None) -> None:
         """Initialize the copier.

         Args:
@@ -1359,19 +1359,13 @@ class PackStreamCopier(PackStreamReader):
         self.outfile = outfile
         self._delta_iter = delta_iter

-    def _read(self, read, size):
-        """Read data from the read callback and write it to the file.
-
-        Args:
-          read: Read callback function
-          size: Number of bytes to read
-        Returns: Data read
-        """
+    def _read(self, read: Callable, size: int) -> bytes:
+        """Read data from the read callback and write it to the file."""
         data = super()._read(read, size)
         self.outfile.write(data)
         return data

-    def verify(self, progress=None) -> None:
+    def verify(self, progress: Optional[Callable] = None) -> None:
         """Verify a pack stream and write it to the output file.

         See PackStreamReader.iterobjects for a list of exceptions this may
@@ -1456,15 +1450,15 @@ class PackData:
     def __init__(
         self,
         filename: Union[str, os.PathLike],
-        file=None,
-        size=None,
+        file: Optional[Any] = None,
+        size: Optional[int] = None,
         *,
-        delta_window_size=None,
-        window_memory=None,
-        delta_cache_size=None,
-        depth=None,
-        threads=None,
-        big_file_threshold=None,
+        delta_window_size: Optional[int] = None,
+        window_memory: Optional[int] = None,
+        delta_cache_size: Optional[int] = None,
+        depth: Optional[int] = None,
+        threads: Optional[int] = None,
+        big_file_threshold: Optional[int] = None,
     ) -> None:
         """Create a PackData object representing the pack in the given filename.

@@ -1497,7 +1491,7 @@ class PackData:
         )

     @property
-    def filename(self):
+    def filename(self) -> str:
         """Get the filename of the pack file.

         Returns:
@@ -1506,7 +1500,7 @@ class PackData:
         return os.path.basename(self._filename)

     @property
-    def path(self):
+    def path(self) -> str:
         """Get the full path of the pack file.

         Returns:
@@ -1515,7 +1509,7 @@ class PackData:
         return self._filename

     @classmethod
-    def from_file(cls, file, size=None):
+    def from_file(cls, file: Any, size: Optional[int] = None) -> 'PackData':
         """Create a PackData object from an open file.

         Args:
@@ -1528,7 +1522,7 @@ class PackData:
         return cls(str(file), file=file, size=size)

     @classmethod
-    def from_path(cls, path: Union[str, os.PathLike]):
+    def from_path(cls, path: Union[str, os.PathLike]) -> 'PackData':
         """Create a PackData object from a file path.

         Args:
@@ -1543,26 +1537,20 @@ class PackData:
         """Close the underlying pack file."""
         self._file.close()

-    def __enter__(self):
+    def __enter__(self) -> 'PackData':
         """Enter context manager."""
         return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None:
         """Exit context manager."""
         self.close()

-    def __eq__(self, other):
-        """Check equality based on pack checksum."""
+    def __eq__(self, other: Any) -> bool:
         if isinstance(other, PackData):
             return self.get_stored_checksum() == other.get_stored_checksum()
         return False

-    def _get_size(self):
-        """Get the size of the pack file.
-
-        Returns: Size in bytes
-        Raises: AssertionError if file is too small to be a pack
-        """
+    def _get_size(self) -> int:
         if self._size is not None:
             return self._size
         self._size = os.path.getsize(self._filename)
@@ -1575,20 +1563,14 @@ class PackData:
         """Returns the number of objects in this pack."""
         return self._num_objects

-    def calculate_checksum(self):
+    def calculate_checksum(self) -> bytes:
         """Calculate the checksum for this pack.

         Returns: 20-byte binary SHA1 digest
         """
         return compute_file_sha(self._file, end_ofs=-20).digest()

-    def iter_unpacked(self, *, include_comp: bool = False):
-        """Iterate over unpacked objects in the pack.
-
-        Args:
-          include_comp: If True, include compressed object data
-        Yields: UnpackedObject instances
-        """
+    def iter_unpacked(self, *, include_comp: bool = False) -> Any:
         self._file.seek(self._header_size)

         if self._num_objects is None:
@@ -1626,7 +1608,7 @@ class PackData:
         self,
         progress: Optional[ProgressFn] = None,
         resolve_ext_ref: Optional[ResolveExtRefFn] = None,
-    ):
+    ) -> Any:
         """Return entries in this pack, sorted by SHA.

         Args:
@@ -1639,7 +1621,7 @@ class PackData:
             self.iterentries(progress=progress, resolve_ext_ref=resolve_ext_ref)
         )

-    def create_index_v1(self, filename, progress=None, resolve_ext_ref=None):
+    def create_index_v1(self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None) -> bytes:
         """Create a version 1 file for this data file.

         Args:
@@ -1654,7 +1636,7 @@ class PackData:
         with GitFile(filename, "wb") as f:
             return write_pack_index_v1(f, entries, self.calculate_checksum())

-    def create_index_v2(self, filename, progress=None, resolve_ext_ref=None):
+    def create_index_v2(self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None) -> bytes:
         """Create a version 2 index file for this data file.

         Args:
@@ -1670,8 +1652,8 @@ class PackData:
             return write_pack_index_v2(f, entries, self.calculate_checksum())

     def create_index_v3(
-        self, filename, progress=None, resolve_ext_ref=None, hash_algorithm=1
-    ):
+        self, filename: str, progress: Optional[Callable] = None, resolve_ext_ref: Optional[Callable] = None, hash_algorithm: int = 1
+    ) -> bytes:
         """Create a version 3 index file for this data file.

         Args:
@@ -1690,8 +1672,8 @@ class PackData:
             )

     def create_index(
-        self, filename, progress=None, version=2, resolve_ext_ref=None, hash_algorithm=1
-    ):
+        self, filename: str, progress: Optional[Callable] = None, version: int = 2, resolve_ext_ref: Optional[Callable] = None, hash_algorithm: int = 1
+    ) -> bytes:
         """Create an index file for this data file.

         Args:
@@ -1720,7 +1702,7 @@ class PackData:
         else:
             raise ValueError(f"unknown index format {version}")

-    def get_stored_checksum(self):
+    def get_stored_checksum(self) -> bytes:
         """Return the expected checksum stored in this pack."""
         self._file.seek(-20, SEEK_END)
         return self._file.read(20)
@@ -1791,6 +1773,7 @@ class DeltaChainIterator(Generic[T]):
             file_obj: File object to read pack data from
             resolve_ext_ref: Optional function to resolve external references
         """
+    def __init__(self, file_obj: Any, *, resolve_ext_ref: Optional[Callable] = None) -> None:
         self._file = file_obj
         self._resolve_ext_ref = resolve_ext_ref
         self._pending_ofs: dict[int, list[int]] = defaultdict(list)
@@ -1799,7 +1782,7 @@ class DeltaChainIterator(Generic[T]):
         self._ext_refs: list[bytes] = []

     @classmethod
-    def for_pack_data(cls, pack_data: PackData, resolve_ext_ref=None):
+    def for_pack_data(cls, pack_data: PackData, resolve_ext_ref: Optional[Callable] = None) -> 'DeltaChainIterator':
         """Create a DeltaChainIterator from pack data.

         Args:
@@ -1822,8 +1805,8 @@ class DeltaChainIterator(Generic[T]):
         shas: Iterable[bytes],
         *,
         allow_missing: bool = False,
-        resolve_ext_ref=None,
-    ):
+        resolve_ext_ref: Optional[Callable] = None,
+    ) -> 'DeltaChainIterator':
         """Create a DeltaChainIterator for a subset of objects.

         Args:
@@ -1895,7 +1878,7 @@ class DeltaChainIterator(Generic[T]):
         """
         self._file = pack_data._file

-    def _walk_all_chains(self):
+    def _walk_all_chains(self) -> Any:
         for offset, type_num in self._full_ofs:
             yield from self._follow_chain(offset, type_num, None)
         yield from self._walk_ref_chains()
@@ -1905,7 +1888,7 @@ class DeltaChainIterator(Generic[T]):
         if self._pending_ref:
             raise UnresolvedDeltas([sha_to_hex(s) for s in self._pending_ref])

-    def _walk_ref_chains(self):
+    def _walk_ref_chains(self) -> Any:
         if not self._resolve_ext_ref:
             self._ensure_no_pending()
             return
@@ -1927,11 +1910,11 @@ class DeltaChainIterator(Generic[T]):

         self._ensure_no_pending()

-    def _result(self, unpacked: UnpackedObject) -> T:
+    def _result(self, unpacked: UnpackedObject) -> Any:
         raise NotImplementedError

     def _resolve_object(
-        self, offset: int, obj_type_num: int, base_chunks: list[bytes]
+        self, offset: int, obj_type_num: int, base_chunks: Optional[list[bytes]]
     ) -> UnpackedObject:
         self._file.seek(offset)
         unpacked, _ = unpack_object(
@@ -1948,7 +1931,7 @@ class DeltaChainIterator(Generic[T]):
             unpacked.obj_chunks = apply_delta(base_chunks, unpacked.decomp_chunks)
         return unpacked

-    def _follow_chain(self, offset: int, obj_type_num: int, base_chunks: list[bytes]):
+    def _follow_chain(self, offset: int, obj_type_num: int, base_chunks: list[bytes]) -> Iterator[T]:
         # Unlike PackData.get_object_at, there is no need to cache offsets as
         # this approach by design inflates each object exactly once.
         todo = [(offset, obj_type_num, base_chunks)]
@@ -1971,7 +1954,8 @@ class DeltaChainIterator(Generic[T]):
         """Iterate over objects in the pack."""
         return self._walk_all_chains()

-    def ext_refs(self):
+    @property
+    def ext_refs(self) -> list[bytes]:
         """Return external references."""
         return self._ext_refs
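
Converting ext_refs() into a property is the one behavioral change in this hunk: existing callers written as iterator.ext_refs() would now be calling the returned list. A short illustration of the caller-side difference (illustrative code, not dulwich's):

class Holder:
    def __init__(self) -> None:
        self._ext_refs: list[bytes] = [b"\xde\xad\xbe\xef"]

    @property
    def ext_refs(self) -> list[bytes]:
        return self._ext_refs

h = Holder()
print(h.ext_refs)   # new spelling: plain attribute access
# h.ext_refs()      # old spelling now raises TypeError: 'list' object is not callable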
 
 
@@ -1979,7 +1963,7 @@ class DeltaChainIterator(Generic[T]):
 class UnpackedObjectIterator(DeltaChainIterator[UnpackedObject]):
 class UnpackedObjectIterator(DeltaChainIterator[UnpackedObject]):
     """Delta chain iterator that yield unpacked objects."""
     """Delta chain iterator that yield unpacked objects."""
 
 
-    def _result(self, unpacked):
+    def _result(self, unpacked: UnpackedObject) -> UnpackedObject:
         """Return the unpacked object.
         """Return the unpacked object.
 
 
         Args:
         Args:
@@ -1996,7 +1980,7 @@ class PackIndexer(DeltaChainIterator[PackIndexEntry]):
 
 
     _compute_crc32 = True
     _compute_crc32 = True
 
 
-    def _result(self, unpacked):
+    def _result(self, unpacked: UnpackedObject) -> tuple:
         """Convert unpacked object to pack index entry.
         """Convert unpacked object to pack index entry.
 
 
         Args:
         Args:
@@ -2011,7 +1995,7 @@ class PackIndexer(DeltaChainIterator[PackIndexEntry]):
 class PackInflater(DeltaChainIterator[ShaFile]):
 class PackInflater(DeltaChainIterator[ShaFile]):
     """Delta chain iterator that yields ShaFile objects."""
     """Delta chain iterator that yields ShaFile objects."""
 
 
-    def _result(self, unpacked):
+    def _result(self, unpacked: UnpackedObject) -> Any:
         """Convert unpacked object to ShaFile.
         """Convert unpacked object to ShaFile.
 
 
         Args:
         Args:
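
The three iterator subclasses above differ only in `_result`, the hook that turns each decoded `UnpackedObject` into the value the caller sees. A minimal standalone sketch of that template-method pattern (hypothetical classes, not dulwich's):

    from typing import Generic, Iterator, TypeVar

    T = TypeVar("T")

    class ChainWalker(Generic[T]):
        """Walk raw records; subclasses decide what each one becomes."""

        def __init__(self, records: list[bytes]) -> None:
            self._records = records

        def _result(self, record: bytes) -> T:
            raise NotImplementedError

        def __iter__(self) -> Iterator[T]:
            for record in self._records:
                yield self._result(record)

    class RawWalker(ChainWalker[bytes]):
        def _result(self, record: bytes) -> bytes:
            return record  # pass the record through unchanged

    class SizeWalker(ChainWalker[int]):
        def _result(self, record: bytes) -> int:
            return len(record)  # index-style summary per record
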
@@ -2032,6 +2016,7 @@ class SHA1Reader(BinaryIO):
-    def __init__(self, f) -> None:
+    def __init__(self, f: BinaryIO) -> None:
        """Initialize SHA1Reader.

        Args:
            f: File-like object to wrap
        """
        self.f = f
        self.sha1 = sha1(b"")

@@ -2065,7 +2050,7 @@ class SHA1Reader(BinaryIO):
        ):
            raise ChecksumMismatch(self.sha1.hexdigest(), sha_to_hex(stored))

-    def close(self):
+    def close(self) -> None:
        """Close the underlying file."""
        return self.f.close()

@@ -2141,16 +2126,19 @@ class SHA1Reader(BinaryIO):
        """
        raise UnsupportedOperation("write")

-    def __enter__(self):
-        """Enter context manager."""
+    def writelines(self, lines: Any) -> None:
+        raise UnsupportedOperation("writelines")
+
+    def write(self, data: bytes) -> int:
+        raise UnsupportedOperation("write")
+
+    def __enter__(self) -> 'SHA1Reader':
        return self

-    def __exit__(self, type, value, traceback):
-        """Exit context manager and close file."""
+    def __exit__(self, type: Optional[type], value: Optional[BaseException], traceback: Optional[Any]) -> None:
        self.close()

-    def __iter__(self):
-        """Return iterator over lines."""
+    def __iter__(self) -> 'SHA1Reader':
        return self

    def __next__(self) -> bytes:
@@ -2193,6 +2181,7 @@ class SHA1Writer(BinaryIO):
-    def __init__(self, f) -> None:
+    def __init__(self, f: BinaryIO) -> None:
        """Initialize SHA1Writer.

        Args:
            f: File-like object to wrap
        """
        self.f = f
        self.length = 0
        self.sha1 = sha1(b"")
@@ -2206,12 +2195,13 @@ class SHA1Writer(BinaryIO):
-    def write(self, data) -> int:
+    def write(self, data: bytes) -> int:
        """Write data to the file and update the SHA1.

        Returns:
            Number of bytes written
        """
        self.sha1.update(data)
        self.f.write(data)
        self.length += len(data)
        return len(data)

-    def write_sha(self):
+    def write_sha(self) -> bytes:
        """Write the SHA1 digest to the file.

        Returns:
@@ -2223,7 +2213,7 @@ class SHA1Writer(BinaryIO):
        self.length += len(sha)
        return sha

-    def close(self):
+    def close(self) -> bytes:
        """Close the file after writing SHA1.

        Returns:
@@ -2233,7 +2223,7 @@ class SHA1Writer(BinaryIO):
        self.f.close()
        return sha

-    def offset(self):
+    def offset(self) -> int:
        """Get the total number of bytes written.

        Returns:
@@ -2297,6 +2287,7 @@ class SHA1Writer(BinaryIO):
-    def writelines(self, lines) -> None:
+    def writelines(self, lines: Any) -> None:
        """Write multiple lines to the file.

        Args:
            lines: Iterable of lines to write
        """
        for line in lines:
            self.write(line)

@@ -2308,15 +2299,15 @@ class SHA1Writer(BinaryIO):
        """
        raise UnsupportedOperation("read")

-    def __enter__(self):
+    def __enter__(self) -> 'SHA1Writer':
        """Enter context manager."""
        return self

-    def __exit__(self, type, value, traceback):
+    def __exit__(self, type: Optional[type], value: Optional[BaseException], traceback: Optional[Any]) -> None:
        """Exit context manager and close file."""
        self.close()

-    def __iter__(self):
+    def __iter__(self) -> 'SHA1Writer':
        """Return iterator."""
        return self

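These `SHA1Writer` hunks pin down git's trailing-checksum convention: every byte written is fed to SHA-1, and `write_sha`/`close` append the 20-byte digest as the file's final bytes. A self-contained sketch of the same convention (hypothetical class, not the dulwich implementation):

    import hashlib
    import io
    from typing import BinaryIO

    class ChecksummedWriter:
        """Hash bytes as they are written; finish with a 20-byte trailer."""

        def __init__(self, f: BinaryIO) -> None:
            self.f = f
            self.sha1 = hashlib.sha1()

        def write(self, data: bytes) -> int:
            self.sha1.update(data)
            self.f.write(data)
            return len(data)

        def write_sha(self) -> bytes:
            digest = self.sha1.digest()  # 20 bytes, like a pack trailer
            self.f.write(digest)
            return digest

    buf = io.BytesIO()
    w = ChecksummedWriter(buf)
    w.write(b"example payload")
    trailer = w.write_sha()
    assert buf.getvalue().endswith(trailer) and len(trailer) == 20
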
@@ -2345,7 +2336,7 @@ class SHA1Writer(BinaryIO):
        raise UnsupportedOperation("truncate")


-def pack_object_header(type_num, delta_base, size):
+def pack_object_header(type_num: int, delta_base: Optional[Any], size: int) -> bytearray:
    """Create a pack object header for the given object info.

    Args:
@@ -2376,7 +2367,7 @@ def pack_object_header(type_num, delta_base, size):
    return bytearray(header)


-def pack_object_chunks(type, object, compression_level=-1):
+def pack_object_chunks(type: int, object: ShaFile, compression_level: int = -1) -> Iterator[bytes]:
    """Generate chunks for a pack object.

    Args:
@@ -2398,7 +2389,7 @@ def pack_object_chunks(type, object, compression_level=-1):
    yield compressor.flush()


-def write_pack_object(write, type, object, sha=None, compression_level=-1):
+def write_pack_object(write: Callable[[bytes], int], type: int, object: ShaFile, sha: Optional[bytes] = None, compression_level: int = -1) -> bytes:
    """Write pack object to a file.

    Args:
@@ -2449,7 +2440,7 @@ def write_pack(
        return data_sum, write_pack_index(f, entries, data_sum)


-def pack_header_chunks(num_objects):
+def pack_header_chunks(num_objects: int) -> Iterator[bytes]:
    """Yield chunks for a pack header."""
    yield b"PACK"  # Pack header
    yield struct.pack(b">L", 2)  # Pack version
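
`pack_header_chunks` spells out the fixed 12-byte layout of a pack header: the `PACK` magic, a big-endian version word (2), and a big-endian object count. Concretely, under that layout:

    import struct

    def pack_header(num_objects: int) -> bytes:
        # magic + version 2 + object count, per the chunks yielded above
        return b"PACK" + struct.pack(b">L", 2) + struct.pack(b">L", num_objects)

    header = pack_header(3)
    assert header == b"PACK\x00\x00\x00\x02\x00\x00\x00\x03"
    assert len(header) == 12
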
@@ -2522,7 +2513,7 @@ def deltify_pack_objects(
        delta_base is None for full text entries
    """

-    def objects_with_hints():
+    def objects_with_hints() -> Iterator[tuple[ShaFile, tuple[int, None]]]:
        for e in objects:
            if isinstance(e, ShaFile):
                yield (e, (e.type_num, None))
@@ -2651,7 +2642,7 @@ def pack_objects_to_data(
        )
    else:

-        def iter_without_path():
+        def iter_without_path() -> Iterator[tuple[ShaFile, tuple[int, None]]]:
            for o in objects:
                if isinstance(o, tuple):
                    yield full_unpacked_object(o[0])
@@ -2819,11 +2810,11 @@ class PackChunkGenerator:
            reuse_compressed=reuse_compressed,
        )

-    def sha1digest(self):
+    def sha1digest(self) -> bytes:
        """Return the SHA1 digest of the pack data."""
        return self.cs.digest()

-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
        """Iterate over pack data chunks."""
        return self._it

@@ -2925,7 +2916,7 @@ def write_pack_data(
    return chunk_generator.entries, chunk_generator.sha1digest()


-def write_pack_index_v1(f, entries, pack_checksum):
+def write_pack_index_v1(f: BinaryIO, entries: list[tuple[bytes, int, Optional[int]]], pack_checksum: bytes) -> None:
    """Write a new pack index file.

    Args:
@@ -2970,7 +2961,7 @@ def _delta_encode_size(size) -> bytes:
_MAX_COPY_LEN = 0xFFFF


-def _encode_copy_operation(start, length):
+def _encode_copy_operation(start: int, length: int) -> bytes:
    scratch = bytearray([0x80])
    for i in range(4):
        if start & 0xFF << i * 8:
@@ -2983,7 +2974,7 @@ def _encode_copy_operation(start, length):
    return bytes(scratch)


-def create_delta(base_buf, target_buf):
+def create_delta(base_buf: bytes, target_buf: bytes) -> Iterator[bytes]:
    """Use python difflib to work out how to transform base_buf to target_buf.

    Args:
@@ -3029,7 +3020,7 @@ def create_delta(base_buf, target_buf):
            yield memoryview(target_buf)[o : o + s]


-def apply_delta(src_buf, delta):
+def apply_delta(src_buf: bytes, delta: bytes) -> bytes:
    """Based on the similar function in git's patch-delta.c.

    Args:
@@ -3044,7 +3035,7 @@ def apply_delta(src_buf, delta):
    index = 0
    delta_length = len(delta)

-    def get_delta_header_size(delta, index):
+    def get_delta_header_size(delta: bytes, index: int) -> tuple[int, int]:
        size = 0
        i = 0
        while delta:
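
`get_delta_header_size` decodes the little-endian base-128 varint that prefixes a git delta: seven payload bits per byte, with the high bit set while more bytes follow. A standalone decoder under that scheme (hypothetical helper, mirroring the nested function above):

    def decode_delta_size(delta: bytes, index: int) -> tuple[int, int]:
        """Return (decoded size, index of the next unread byte)."""
        size = 0
        shift = 0
        while True:
            byte = delta[index]
            index += 1
            size |= (byte & 0x7F) << shift  # seven payload bits per byte
            shift += 7
            if not byte & 0x80:  # high bit clear: varint ends here
                break
        return size, index

    # 0xAC 0x02 -> 0x2C + (0x02 << 7) = 44 + 256 = 300
    assert decode_delta_size(bytes([0xAC, 0x02]), 0) == (300, 2)
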
@@ -3305,7 +3296,7 @@ class Pack:
        self.resolve_ext_ref = resolve_ext_ref

    @classmethod
-    def from_lazy_objects(cls, data_fn, idx_fn):
+    def from_lazy_objects(cls, data_fn: Callable, idx_fn: Callable) -> 'Pack':
        """Create a new pack object from callables to load pack data and index objects."""
        ret = cls("")
        ret._data_load = data_fn
@@ -3313,7 +3304,7 @@ class Pack:
        return ret

    @classmethod
-    def from_objects(cls, data, idx):
+    def from_objects(cls, data: PackData, idx: PackIndex) -> 'Pack':
        """Create a new pack object from pack data and index objects."""
        ret = cls("")
        ret._data = data
@@ -3323,7 +3314,7 @@ class Pack:
        ret.check_length_and_checksum()
        return ret

-    def name(self):
+    def name(self) -> bytes:
        """The SHA over the SHAs of the objects in this pack."""
        return self.index.objects_sha1()

@@ -3354,15 +3345,15 @@ class Pack:
        if self._idx is not None:
            self._idx.close()

-    def __enter__(self):
+    def __enter__(self) -> 'Pack':
        """Enter context manager."""
        return self

-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[Any]) -> None:
        """Exit context manager."""
        self.close()

-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
        """Check equality with another pack."""
        return isinstance(self, type(other)) and self.index == other.index

@@ -3374,7 +3365,7 @@ class Pack:
        """Return string representation of this pack."""
        return f"{self.__class__.__name__}({self._basename!r})"

-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
        """Iterate over all the sha1s of the objects in this pack."""
        return iter(self.index)

@@ -3410,7 +3401,7 @@ class Pack:
        """Return the stored checksum of the pack data."""
        return self.data.get_stored_checksum()

-    def pack_tuples(self):
+    def pack_tuples(self) -> list[tuple[ShaFile, None]]:
        """Return pack tuples for all objects in pack."""
        return [(o, None) for o in self.iterobjects()]

@@ -3492,7 +3483,7 @@ class Pack:
        if not allow_missing and todo:
            raise UnresolvedDeltas(list(todo))

-    def iter_unpacked(self, include_comp=False):
+    def iter_unpacked(self, include_comp: bool = False) -> Iterator[UnpackedObject]:
        """Iterate over all unpacked objects in this pack."""
        ofs_to_entries = {
            ofs: (sha, crc32) for (sha, ofs, crc32) in self.index.iterentries()

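`from_lazy_objects` and `from_objects` differ only in when the data and index are materialized: the lazy variant stores callables and defers file access until first use. The deferred-loading idea, reduced to a sketch (hypothetical class):

    from typing import Callable, Optional

    class LazyPack:
        """Defer loading pack data until someone actually asks for it."""

        def __init__(self, load_data: Callable[[], bytes]) -> None:
            self._load_data = load_data
            self._data: Optional[bytes] = None

        @property
        def data(self) -> bytes:
            if self._data is None:
                self._data = self._load_data()  # opened on first access only
            return self._data

    pack = LazyPack(lambda: b"PACK...")  # nothing read yet
    assert pack.data == b"PACK..."      # loader runs here, exactly once
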
+ 92 - 174
dulwich/refs.py

@@ -339,7 +339,7 @@ class RefsContainer:
        """
        raise NotImplementedError(self.read_loose_ref)

-    def follow(self, name) -> tuple[list[bytes], bytes]:
+    def follow(self, name: bytes) -> tuple[list[bytes], bytes]:
        """Follow a reference name.

        Returns: a tuple of (refnames, sha), where refnames are the names of
@@ -359,20 +359,12 @@ class RefsContainer:
                raise SymrefLoop(name, depth)
        return refnames, contents

-    def __contains__(self, refname) -> bool:
-        """Check if a ref exists.
-
-        Args:
-          refname: Name of the ref to check
-
-        Returns:
-          True if the ref exists
-        """
+    def __contains__(self, refname: bytes) -> bool:
        if self.read_ref(refname):
            return True
        return False

-    def __getitem__(self, name) -> ObjectID:
+    def __getitem__(self, name: bytes) -> ObjectID:
        """Get the SHA1 for a reference name.

        This method follows all symbolic references.
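
With `follow` now typed, its contract is easy to exercise: it returns every ref name traversed plus the object ID the chain ends on. A sketch against the dict-backed container defined later in this file (assumes dulwich's `SYMREF` prefix, `b"ref: "`):

    from dulwich.refs import SYMREF, DictRefsContainer

    refs = DictRefsContainer({
        b"HEAD": SYMREF + b"refs/heads/master",
        b"refs/heads/master": b"a" * 40,
    })
    refnames, sha = refs.follow(b"HEAD")
    assert refnames == [b"HEAD", b"refs/heads/master"]
    assert sha == b"a" * 40
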
@@ -384,13 +376,13 @@ class RefsContainer:

    def set_if_equals(
        self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> bool:
        """Set a refname to new_ref only if it currently equals old_ref.

@@ -412,7 +404,7 @@ class RefsContainer:
        raise NotImplementedError(self.set_if_equals)

    def add_if_new(
-        self, name, ref, committer=None, timestamp=None, timezone=None, message=None
+        self, name: bytes, ref: bytes, committer: Optional[bytes] = None, timestamp: Optional[int] = None, timezone: Optional[int] = None, message: Optional[bytes] = None
    ) -> bool:
        """Add a new reference only if it does not already exist.

@@ -426,7 +418,7 @@ class RefsContainer:
        """
        raise NotImplementedError(self.add_if_new)

-    def __setitem__(self, name, ref) -> None:
+    def __setitem__(self, name: bytes, ref: bytes) -> None:
        """Set a reference name to point to the given SHA1.

        This method follows all symbolic references if applicable for the
@@ -446,12 +438,12 @@ class RefsContainer:

    def remove_if_equals(
        self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> bool:
        """Remove a refname only if it currently equals old_ref.

@@ -471,7 +463,7 @@ class RefsContainer:
        """
        raise NotImplementedError(self.remove_if_equals)

-    def __delitem__(self, name) -> None:
+    def __delitem__(self, name: bytes) -> None:
        """Remove a refname.

        This method does not follow symbolic references, even if applicable for
@@ -486,7 +478,7 @@ class RefsContainer:
        """
        self.remove_if_equals(name, None)

-    def get_symrefs(self):
+    def get_symrefs(self) -> dict[bytes, bytes]:
        """Get a dict with all symrefs in this container.

        Returns: Dictionary mapping source ref to target ref
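
The `set_if_equals`/`remove_if_equals` signatures make the compare-and-swap contract explicit: an update lands only while the ref still holds the expected old value, so racing writers cannot silently clobber each other. With the `DictRefsContainer` introduced just below:

    from dulwich.refs import DictRefsContainer

    old, new = b"1" * 40, b"2" * 40
    refs = DictRefsContainer({b"refs/heads/master": old})

    assert refs.set_if_equals(b"refs/heads/master", old, new)      # expectation holds
    assert not refs.set_if_equals(b"refs/heads/master", old, new)  # stale, refused
    assert refs[b"refs/heads/master"] == new
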
@@ -517,41 +509,22 @@ class DictRefsContainer(RefsContainer):
    threadsafe.
    """

-    def __init__(self, refs, logger=None) -> None:
-        """Initialize DictRefsContainer."""
+    def __init__(self, refs: dict[bytes, bytes], logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None) -> None:
        super().__init__(logger=logger)
        self._refs = refs
        self._peeled: dict[bytes, ObjectID] = {}
        self._watchers: set[Any] = set()

-    def allkeys(self):
-        """Get all ref names.
-
-        Returns:
-          All ref names in the container
-        """
+    def allkeys(self) -> Iterator[bytes]:
        return self._refs.keys()

-    def read_loose_ref(self, name):
-        """Read a reference from the refs dictionary.
-
-        Args:
-          name: The ref name to read
-
-        Returns:
-          The ref value or None if not found
-        """
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
        return self._refs.get(name, None)

-    def get_packed_refs(self):
-        """Get packed refs (always empty for DictRefsContainer).
-
-        Returns:
-          Empty dictionary
-        """
+    def get_packed_refs(self) -> dict[bytes, bytes]:
        return {}

-    def _notify(self, ref, newsha) -> None:
+    def _notify(self, ref: bytes, newsha: Optional[bytes]) -> None:
        for watcher in self._watchers:
            watcher._notify((ref, newsha))
@@ -559,10 +532,10 @@ class DictRefsContainer(RefsContainer):
        self,
        name: Ref,
        other: Ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> None:
        """Make a ref point at another ref.

@@ -590,13 +563,13 @@ class DictRefsContainer(RefsContainer):

    def set_if_equals(
        self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> bool:
        """Set a refname to new_ref only if it currently equals old_ref.

@@ -638,9 +611,9 @@ class DictRefsContainer(RefsContainer):
        self,
        name: Ref,
        ref: ObjectID,
-        committer=None,
-        timestamp=None,
-        timezone=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
        message: Optional[bytes] = None,
    ) -> bool:
        """Add a new reference only if it does not already exist.
@@ -673,12 +646,12 @@ class DictRefsContainer(RefsContainer):

    def remove_if_equals(
        self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> bool:
        """Remove a refname only if it currently equals old_ref.

@@ -716,25 +689,17 @@ class DictRefsContainer(RefsContainer):
            )
        return True

-    def get_peeled(self, name):
-        """Get the peeled value of a ref.
-
-        Args:
-          name: Ref name to get peeled value for
-
-        Returns:
-          The peeled SHA or None if not available
-        """
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
        return self._peeled.get(name)

-    def _update(self, refs) -> None:
+    def _update(self, refs: dict[bytes, bytes]) -> None:
        """Update multiple refs; intended only for testing."""
        # TODO(dborowitz): replace this with a public function that uses
        # set_if_equal.
        for ref, sha in refs.items():
            self.set_if_equals(ref, None, sha)

-    def _update_peeled(self, peeled) -> None:
+    def _update_peeled(self, peeled: dict[bytes, bytes]) -> None:
        """Update cached peeled refs; intended only for testing."""
        self._peeled.update(peeled)

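The `InfoRefsContainer` below and the `read_info_refs` helper annotated later in this file both deal in the dumb-HTTP `info/refs` listing: one SHA, a tab, and a ref name per line, with `^{}`-suffixed names carrying peeled tag values. For example:

    import io
    from dulwich.refs import read_info_refs

    listing = (
        b"1111111111111111111111111111111111111111\trefs/heads/master\n"
        b"2222222222222222222222222222222222222222\trefs/tags/v1.0\n"
        b"3333333333333333333333333333333333333333\trefs/tags/v1.0^{}\n"
    )
    refs = read_info_refs(io.BytesIO(listing))
    assert refs[b"refs/heads/master"] == b"1" * 40
    assert refs[b"refs/tags/v1.0^{}"] == b"3" * 40  # peeled entry kept verbatim
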
@@ -742,56 +707,22 @@
class InfoRefsContainer(RefsContainer):
    """Refs container that reads refs from a info/refs file."""

-    def __init__(self, f) -> None:
-        """Initialize an InfoRefsContainer.
-
-        Args:
-          f: File-like object containing info/refs data
-        """
+    def __init__(self, f: Any) -> None:
        self._refs = {}
        self._peeled = {}
        refs = read_info_refs(f)
        (self._refs, self._peeled) = split_peeled_refs(refs)

-    def allkeys(self):
-        """Get all ref names.
-
-        Returns:
-          All ref names in the info/refs file
-        """
+    def allkeys(self) -> Iterator[bytes]:
        return self._refs.keys()

-    def read_loose_ref(self, name):
-        """Read a reference from the parsed info/refs.
-
-        Args:
-          name: The ref name to read
-
-        Returns:
-          The ref value or None if not found
-        """
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
        return self._refs.get(name, None)

-    def get_packed_refs(self):
-        """Get packed refs (always empty for InfoRefsContainer).
-
-        Returns:
-          Empty dictionary
-        """
+    def get_packed_refs(self) -> dict[bytes, bytes]:
        return {}

-    def get_peeled(self, name):
-        """Get the peeled value of a ref.
-
-        Args:
-          name: Ref name to get peeled value for
-
-        Returns:
-          The peeled SHA if available, otherwise the ref value itself
-
-        Raises:
-          KeyError: If the ref doesn't exist
-        """
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
        try:
            return self._peeled[name]
        except KeyError:
@@ -805,7 +736,7 @@ class DiskRefsContainer(RefsContainer):
        self,
        path: Union[str, bytes, os.PathLike],
        worktree_path: Optional[Union[str, bytes, os.PathLike]] = None,
-        logger=None,
+        logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None,
    ) -> None:
        """Initialize DiskRefsContainer."""
        super().__init__(logger=logger)
@@ -822,15 +753,7 @@ class DiskRefsContainer(RefsContainer):
        """Return string representation of DiskRefsContainer."""
        return f"{self.__class__.__name__}({self.path!r})"

-    def subkeys(self, base):
-        """Get all ref names under a base ref.
-
-        Args:
-          base: Base ref path to search under
-
-        Returns:
-          Set of ref names under the base (without base prefix)
-        """
+    def subkeys(self, base: bytes) -> set[bytes]:
        subkeys = set()
        path = self.refpath(base)
        for root, unused_dirs, files in os.walk(path):
@@ -849,12 +772,7 @@ class DiskRefsContainer(RefsContainer):
                subkeys.add(key[len(base) :].strip(b"/"))
        return subkeys

-    def allkeys(self):
-        """Get all ref names from disk.
-
-        Returns:
-          Set of all ref names (both loose and packed)
-        """
+    def allkeys(self) -> Iterator[bytes]:
        allkeys = set()
        if os.path.exists(self.refpath(HEADREF)):
            allkeys.add(HEADREF)
@@ -871,7 +789,7 @@ class DiskRefsContainer(RefsContainer):
        allkeys.update(self.get_packed_refs())
        return allkeys

-    def refpath(self, name):
+    def refpath(self, name: bytes) -> bytes:
        """Return the disk path of a ref."""
        if os.path.sep != "/":
            name = name.replace(b"/", os.fsencode(os.path.sep))
@@ -882,7 +800,7 @@ class DiskRefsContainer(RefsContainer):
        else:
            return os.path.join(self.path, name)

-    def get_packed_refs(self):
+    def get_packed_refs(self) -> dict[bytes, bytes]:
        """Get contents of the packed-refs file.

        Returns: Dictionary mapping ref names to SHA1s
@@ -950,7 +868,7 @@ class DiskRefsContainer(RefsContainer):

            self._packed_refs = packed_refs

-    def get_peeled(self, name):
+    def get_peeled(self, name: bytes) -> Optional[bytes]:
        """Return the cached peeled value of a ref, if available.

        Args:
@@ -969,7 +887,7 @@ class DiskRefsContainer(RefsContainer):
            # Known not peelable
            return self[name]

-    def read_loose_ref(self, name):
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
        """Read a reference file and return its contents.

        If the reference file is a symbolic reference, only read the first line of
@@ -999,7 +917,7 @@ class DiskRefsContainer(RefsContainer):
            # errors depending on the specific operating system
            return None

-    def _remove_packed_ref(self, name) -> None:
+    def _remove_packed_ref(self, name: bytes) -> None:
        if self._packed_refs is None:
            return
        filename = os.path.join(self.path, b"packed-refs")
@@ -1024,12 +942,12 @@ class DiskRefsContainer(RefsContainer):

    def set_symbolic_ref(
        self,
-        name,
-        other,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        other: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> None:
        """Make a ref point at another ref.

@@ -1065,13 +983,13 @@ class DiskRefsContainer(RefsContainer):

    def set_if_equals(
        self,
-        name,
-        old_ref,
-        new_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        new_ref: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> bool:
        """Set a refname to new_ref only if it currently equals old_ref.

@@ -1151,9 +1069,9 @@ class DiskRefsContainer(RefsContainer):
        self,
        name: bytes,
        ref: bytes,
-        committer=None,
-        timestamp=None,
-        timezone=None,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
        message: Optional[bytes] = None,
    ) -> bool:
        """Add a new reference only if it does not already exist.
@@ -1203,12 +1121,12 @@ class DiskRefsContainer(RefsContainer):

    def remove_if_equals(
        self,
-        name,
-        old_ref,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        old_ref: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
    ) -> bool:
        """Remove a refname only if it currently equals old_ref.

@@ -1309,7 +1227,7 @@ class DiskRefsContainer(RefsContainer):
            self.add_packed_refs(refs_to_pack)


-def _split_ref_line(line):
+def _split_ref_line(line: bytes) -> tuple[bytes, bytes]:
    """Split a single ref line into a tuple of SHA1 and name."""
    fields = line.rstrip(b"\n\r").split(b" ")
    if len(fields) != 2:
@@ -1322,7 +1240,7 @@ def _split_ref_line(line):
    return (sha, name)


-def read_packed_refs(f):
+def read_packed_refs(f: Any) -> Iterator[tuple[bytes, bytes]]:
    """Read a packed refs file.

    Args:
@@ -1338,7 +1256,7 @@ def read_packed_refs(f):
        yield _split_ref_line(line)


-def read_packed_refs_with_peeled(f):
+def read_packed_refs_with_peeled(f: Any) -> Iterator[tuple[bytes, bytes, Optional[bytes]]]:
    """Read a packed refs file including peeled refs.

    Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
@@ -1370,7 +1288,7 @@ def read_packed_refs_with_peeled(f):
        yield (sha, name, None)


-def write_packed_refs(f, packed_refs, peeled_refs=None) -> None:
+def write_packed_refs(f: Any, packed_refs: dict[bytes, bytes], peeled_refs: Optional[dict[bytes, bytes]] = None) -> None:
    """Write a packed refs file.

    Args:
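
`write_packed_refs` emits exactly what the two readers above consume: an optional `# pack-refs with: peeled` header, `<sha> <refname>` lines, and `^<sha>` continuation lines holding peeled tag values. A round trip through the annotated functions:

    import io
    from dulwich.refs import read_packed_refs_with_peeled, write_packed_refs

    buf = io.BytesIO()
    write_packed_refs(
        buf,
        {b"refs/tags/v1.0": b"1" * 40},
        peeled_refs={b"refs/tags/v1.0": b"2" * 40},
    )
    buf.seek(0)
    buf.readline()  # the reader assumes the header line was already consumed
    entries = list(read_packed_refs_with_peeled(buf))
    assert entries == [(b"1" * 40, b"refs/tags/v1.0", b"2" * 40)]
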
@@ -1388,7 +1306,7 @@ def write_packed_refs(f, packed_refs, peeled_refs=None) -> None:
            f.write(b"^" + peeled_refs[refname] + b"\n")


-def read_info_refs(f):
+def read_info_refs(f: Any) -> dict[bytes, bytes]:
    """Read info/refs file.

    Args:
@@ -1404,7 +1322,7 @@ def read_info_refs(f):
    return ret


-def write_info_refs(refs, store: ObjectContainer):
+def write_info_refs(refs: dict[bytes, bytes], store: ObjectContainer) -> Iterator[bytes]:
    """Generate info refs."""
    # TODO: Avoid recursive import :(
    from .object_store import peel_sha
@@ -1447,7 +1365,7 @@ def split_peeled_refs(refs: dict[bytes, bytes]) -> tuple[dict[bytes, bytes], dic
    return regular, peeled


-def _set_origin_head(refs, origin, origin_head) -> None:
+def _set_origin_head(refs: RefsContainer, origin: bytes, origin_head: Optional[bytes]) -> None:
    # set refs/remotes/origin/HEAD
    origin_base = b"refs/remotes/" + origin + b"/"
    if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
@@ -1491,7 +1409,7 @@ def _set_default_branch(
    return head_ref


-def _set_head(refs, head_ref, ref_message):
+def _set_head(refs: RefsContainer, head_ref: bytes, ref_message: Optional[bytes]) -> Optional[bytes]:
    if head_ref.startswith(LOCAL_TAG_PREFIX):
        # detach HEAD at specified tag
        head = refs[head_ref]
@@ -1541,7 +1459,7 @@ def _import_remote_refs(
    )


-def serialize_refs(store, refs):
+def serialize_refs(store: ObjectContainer, refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
    """Serialize refs with peeled refs.

    Args:

+ 26 - 26
dulwich/repo.py

@@ -341,7 +341,7 @@ class ParentsProvider:
        # Get commit graph once at initialization for performance
        self.commit_graph = store.get_commit_graph()

-    def get_parents(self, commit_id: bytes, commit=None) -> list[bytes]:
+    def get_parents(self, commit_id: bytes, commit: Optional[Any] = None) -> list[bytes]:
        try:
            return self.grafts[commit_id]
        except KeyError:
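
`get_parents` encodes a strict lookup order: grafts override recorded parentage, shallow boundary commits report no parents, and only then does the provider consult the commit graph or the commit object itself. The same precedence in miniature (hypothetical helper):

    def lookup_parents(
        commit_id: bytes,
        grafts: dict[bytes, list[bytes]],
        shallows: set[bytes],
        recorded: dict[bytes, list[bytes]],
    ) -> list[bytes]:
        if commit_id in grafts:       # grafts win over everything
            return grafts[commit_id]
        if commit_id in shallows:     # shallow cutoff: pretend root commit
            return []
        return recorded[commit_id]    # fall back to stored parents

    assert lookup_parents(b"c1", {b"c1": [b"p"]}, set(), {}) == [b"p"]
    assert lookup_parents(b"c2", {}, {b"c2"}, {}) == []
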
@@ -498,7 +498,7 @@ class BaseRepo:
    def fetch_pack_data(
        self,
        determine_wants: Callable,
-        graph_walker,
+        graph_walker: Any,
        progress: Optional[Callable],
        *,
        get_tagged: Optional[Callable] = None,
@@ -532,11 +532,11 @@ class BaseRepo:

    def find_missing_objects(
        self,
-        determine_wants,
-        graph_walker,
-        progress,
+        determine_wants: Callable,
+        graph_walker: Any,
+        progress: Optional[Callable],
        *,
-        get_tagged=None,
+        get_tagged: Optional[Callable] = None,
        depth: Optional[int] = None,
    ) -> Optional[MissingObjectFinder]:
        """Fetch the missing objects required for a set of revisions.
@@ -596,7 +596,7 @@ class BaseRepo:
                def __len__(self) -> int:
                    return 0

-                def __iter__(self):
+                def __iter__(self) -> Any:
                    yield from []

            return DummyMissingObjectFinder()  # type: ignore
@@ -615,7 +615,7 @@ class BaseRepo:

        parents_provider = ParentsProvider(self.object_store, shallows=current_shallow)

-        def get_parents(commit):
+        def get_parents(commit: Any) -> list[bytes]:
            """Get parents for a commit using the parents provider.

            Args:
@@ -642,7 +642,7 @@ class BaseRepo:
        want: list[ObjectID],
        progress: Optional[Callable[[str], None]] = None,
        ofs_delta: Optional[bool] = None,
-    ):
+    ) -> Any:
        """Generate pack data objects for a set of wants/haves.

        Args:
@@ -697,7 +697,7 @@ class BaseRepo:
        # TODO: move this method to WorkTree
        return self.refs[b"HEAD"]

-    def _get_object(self, sha, cls):
+    def _get_object(self, sha: bytes, cls: Any) -> Any:
        assert len(sha) in (20, 40)
        ret = self.get_object(sha)
        if not isinstance(ret, cls):
@@ -768,7 +768,7 @@ class BaseRepo:
        """
        raise NotImplementedError(self.get_description)

-    def set_description(self, description) -> None:
+    def set_description(self, description: bytes) -> None:
        """Set the description for this repository.

        Args:
@@ -776,14 +776,14 @@ class BaseRepo:
        """
        raise NotImplementedError(self.set_description)

-    def get_rebase_state_manager(self):
+    def get_rebase_state_manager(self) -> Any:
        """Get the appropriate rebase state manager for this repository.

        Returns: RebaseStateManager instance
        """
        raise NotImplementedError(self.get_rebase_state_manager)

-    def get_blob_normalizer(self):
+    def get_blob_normalizer(self) -> Any:
        """Return a BlobNormalizer object for checkin/checkout operations.

        Returns: BlobNormalizer instance
@@ -831,7 +831,7 @@ class BaseRepo:
        with f:
            return {line.strip() for line in f}

-    def update_shallow(self, new_shallow, new_unshallow) -> None:
+    def update_shallow(self, new_shallow: Any, new_unshallow: Any) -> None:
        """Update the list of shallow objects.

        Args:
@@ -873,7 +873,7 @@ class BaseRepo:

        return Notes(self.object_store, self.refs)

-    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs):
+    def get_walker(self, include: Optional[list[bytes]] = None, **kwargs) -> Any:
        """Obtain a walker for this repository.

        Args:
@@ -910,7 +910,7 @@ class BaseRepo:

        return Walker(self.object_store, include, **kwargs)

-    def __getitem__(self, name: Union[ObjectID, Ref]):
+    def __getitem__(self, name: Union[ObjectID, Ref]) -> Any:
        """Retrieve a Git object by SHA1 or ref.

        Args:
@@ -1002,7 +1002,7 @@ class BaseRepo:
        for sha in to_remove:
            del self._graftpoints[sha]

-    def _read_heads(self, name):
+    def _read_heads(self, name: str) -> Any:
        f = self.get_named_file(name)
        if f is None:
            return []
@@ -1028,17 +1028,17 @@ class BaseRepo:
        message: Optional[bytes] = None,
        committer: Optional[bytes] = None,
        author: Optional[bytes] = None,
-        commit_timestamp=None,
-        commit_timezone=None,
-        author_timestamp=None,
-        author_timezone=None,
+        commit_timestamp: Optional[Any] = None,
+        commit_timezone: Optional[Any] = None,
+        author_timestamp: Optional[Any] = None,
+        author_timezone: Optional[Any] = None,
        tree: Optional[ObjectID] = None,
        encoding: Optional[bytes] = None,
        ref: Optional[Ref] = b"HEAD",
        merge_heads: Optional[list[ObjectID]] = None,
        no_verify: bool = False,
        sign: bool = False,
-    ):
+    ) -> Any:
        """Create a new commit.

        If not specified, committer and author default to
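
With the timestamp and timezone parameters annotated, a fully explicit `do_commit` call is reproducible byte for byte. A sketch against an in-memory repository (assumes dulwich's `MemoryRepo` test helper and an empty `Tree`; the identities are made up):

    from dulwich.objects import Tree
    from dulwich.repo import MemoryRepo

    repo = MemoryRepo.init_bare([], {})
    tree = Tree()                      # empty tree keeps the sketch minimal
    repo.object_store.add_object(tree)

    commit_id = repo.do_commit(
        b"initial commit",
        committer=b"Example Committer <committer@example.com>",
        author=b"Example Author <author@example.com>",
        commit_timestamp=1700000000,
        commit_timezone=0,
        author_timestamp=1700000000,
        author_timezone=0,
        tree=tree.id,
    )
    assert len(commit_id) == 40        # hex object ID of the new commit
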
@@ -1356,7 +1356,7 @@ class Repo(BaseRepo):
        """
        return self._commondir

-    def _determine_file_mode(self):
+    def _determine_file_mode(self) -> bool:
        """Probe the file-system to determine whether permissions can be trusted.

        Returns: True if permissions can be trusted, False otherwise.
@@ -1379,7 +1379,7 @@ class Repo(BaseRepo):

        return mode_differs and st2_has_exec

-    def _determine_symlinks(self):
+    def _determine_symlinks(self) -> bool:
        """Probe the filesystem to determine whether symlinks can be created.

        Returns: True if symlinks can be created, False otherwise.
@@ -1387,7 +1387,7 @@ class Repo(BaseRepo):
        # TODO(jelmer): Actually probe disk / look at filesystem
        return sys.platform != "win32"

-    def _put_named_file(self, path, contents) -> None:
+    def _put_named_file(self, path: str, contents: bytes) -> None:
        """Write a file to the control dir with the given name and contents.

        Args:
@@ -1398,7 +1398,7 @@ class Repo(BaseRepo):
        with GitFile(os.path.join(self.controldir(), path), "wb") as f:
            f.write(contents)

-    def _del_named_file(self, path) -> None:
+    def _del_named_file(self, path: str) -> None:
        try:
            os.unlink(os.path.join(self.controldir(), path))
        except FileNotFoundError: