Răsfoiți Sursa

Add type annotations to functions in dulwich/porcelain.py

Jelmer Vernooij 5 luni în urmă
părinte
comite
941310c9b5
1 a modificat fișierele cu 38 adăugiri și 70 ștergeri
  1. 38 70
      dulwich/porcelain.py

+ 38 - 70
dulwich/porcelain.py

@@ -91,7 +91,7 @@ from contextlib import AbstractContextManager, closing, contextmanager
 from dataclasses import dataclass
 from io import BytesIO, RawIOBase
 from pathlib import Path
-from typing import BinaryIO, Optional, TypeVar, Union, cast, overload
+from typing import BinaryIO, Callable, Optional, TypeVar, Union, cast, overload
 
 from . import replace_me
 from .archive import tar_stream
@@ -224,26 +224,10 @@ class NoneStream(RawIOBase):
         """
         return b""
 
-    def readinto(self, b) -> None:
-        """Read bytes into buffer (no-op for NoneStream).
-
-        Args:
-          b: Buffer to read into
-
-        Returns:
-          None
-        """
+    def readinto(self, b: Union[bytearray, memoryview]) -> None:
         return None
 
-    def write(self, b) -> None:
-        """Write bytes (no-op for NoneStream).
-
-        Args:
-          b: Bytes to write
-
-        Returns:
-          None
-        """
+    def write(self, b: bytes) -> None:
         return None
 
 
@@ -257,12 +241,7 @@ DEFAULT_ENCODING = "utf-8"
 class Error(Exception):
     """Porcelain-based error."""
 
-    def __init__(self, msg) -> None:
-        """Initialize an Error.
-
-        Args:
-          msg: Error message
-        """
+    def __init__(self, msg: str) -> None:
         super().__init__(msg)
 
 
@@ -278,7 +257,7 @@ class CheckoutError(Error):
     """Indicates that a checkout cannot be performed."""
 
 
-def parse_timezone_format(tz_str):
+def parse_timezone_format(tz_str: str) -> int:
     """Parse given string and attempt to return a timezone offset.
 
     Different formats are considered in the following order:
@@ -333,9 +312,9 @@ def parse_timezone_format(tz_str):
     raise TimezoneFormatError(tz_str)
 
 
-def get_user_timezones():
-    """Retrieve local timezone as described in https://raw.githubusercontent.com/git/git/v2.3.0/Documentation/date-formats.txt.
-
+def get_user_timezones() -> tuple[int, int]:
+    """Retrieve local timezone as described in
+    https://raw.githubusercontent.com/git/git/v2.3.0/Documentation/date-formats.txt
     Returns: A tuple containing author timezone, committer timezone.
     """
     local_timezone = time.localtime().tm_gmtoff
@@ -372,7 +351,7 @@ def open_repo(
 
 
 @contextmanager
-def _noop_context_manager(obj):
+def _noop_context_manager(obj: T) -> Iterator[T]:
     """Context manager that has the same api as closing but does nothing."""
     yield obj
 
@@ -401,9 +380,10 @@ def open_repo_closing(
 
 
 def path_to_tree_path(
-    repopath: Union[str, os.PathLike], path, tree_encoding=DEFAULT_ENCODING
-):
-    """Convert a path to a path usable in an index, e.g. bytes and relative to the repository root.
+    repopath: Union[str, os.PathLike], path: Union[str, os.PathLike], tree_encoding: str = DEFAULT_ENCODING
+) -> bytes:
+    """Convert a path to a path usable in an index, e.g. bytes and relative to
+    the repository root.
 
     Args:
       repopath: Repository path, absolute or relative to the cwd
@@ -446,18 +426,12 @@ def path_to_tree_path(
 class DivergedBranches(Error):
     """Branches have diverged and fast-forward is not possible."""
 
-    def __init__(self, current_sha, new_sha) -> None:
-        """Initialize a DivergedBranches error.
-
-        Args:
-          current_sha: SHA of the current branch head
-          new_sha: SHA of the new branch head
-        """
+    def __init__(self, current_sha: bytes, new_sha: bytes) -> None:
         self.current_sha = current_sha
         self.new_sha = new_sha
 
 
-def check_diverged(repo, current_sha, new_sha) -> None:
+def check_diverged(repo: BaseRepo, current_sha: bytes, new_sha: bytes) -> None:
     """Check if updating to a sha can be done with fast forwarding.
 
     Args:
@@ -474,10 +448,10 @@ def check_diverged(repo, current_sha, new_sha) -> None:
 
 
 def archive(
-    repo,
+    repo: Union[str, BaseRepo],
     committish: Optional[Union[str, bytes, Commit, Tag]] = None,
-    outstream=default_bytes_out_stream,
-    errstream=default_bytes_err_stream,
+    outstream: BinaryIO = default_bytes_out_stream,
+    errstream: BinaryIO = default_bytes_err_stream,
 ) -> None:
     """Create an archive.
 
@@ -507,7 +481,7 @@ def update_server_info(repo: RepoPath = ".") -> None:
         server_update_server_info(r)
 
 
-def write_commit_graph(repo: RepoPath = ".", reachable=True) -> None:
+def write_commit_graph(repo: RepoPath = ".", reachable: bool = True) -> None:
     """Write a commit graph file for a repository.
 
     Args:
@@ -522,7 +496,7 @@ def write_commit_graph(repo: RepoPath = ".", reachable=True) -> None:
             r.object_store.write_commit_graph(refs, reachable=reachable)
 
 
-def symbolic_ref(repo: RepoPath, ref_name, force=False) -> None:
+def symbolic_ref(repo: RepoPath, ref_name: Union[str, bytes], force: bool = False) -> None:
     """Set git symbolic ref into HEAD.
 
     Args:
@@ -537,30 +511,24 @@ def symbolic_ref(repo: RepoPath, ref_name, force=False) -> None:
         repo_obj.refs.set_symbolic_ref(b"HEAD", ref_path)
 
 
-def pack_refs(repo: RepoPath, all=False) -> None:
-    """Pack loose refs into a single file.
-
-    Args:
-      repo: Path to the repository
-      all: If True, pack all refs; if False, only pack already-packed refs
-    """
+def pack_refs(repo: RepoPath, all: bool = False) -> None:
     with open_repo_closing(repo) as repo_obj:
         repo_obj.refs.pack_refs(all=all)
 
 
 def commit(
-    repo=".",
-    message=None,
-    author=None,
-    author_timezone=None,
-    committer=None,
-    commit_timezone=None,
-    encoding=None,
-    no_verify=False,
-    signoff=False,
-    all=False,
-    amend=False,
-):
+    repo: RepoPath = ".",
+    message: Optional[Union[str, bytes, Callable]] = None,
+    author: Optional[bytes] = None,
+    author_timezone: Optional[int] = None,
+    committer: Optional[bytes] = None,
+    commit_timezone: Optional[int] = None,
+    encoding: Optional[bytes] = None,
+    no_verify: bool = False,
+    signoff: bool = False,
+    all: bool = False,
+    amend: bool = False,
+) -> bytes:
     """Create a new commit.
 
     Args:
@@ -656,11 +624,11 @@ def commit(
 
 def commit_tree(
     repo: RepoPath,
-    tree,
-    message=None,
-    author=None,
-    committer=None,
-):
+    tree: bytes,
+    message: Optional[Union[str, bytes]] = None,
+    author: Optional[bytes] = None,
+    committer: Optional[bytes] = None,
+) -> bytes:
     """Create a new commit object.
 
     Args: