Jelajahi Sumber

Add type annotations to dulwich/porcelain.py

Jelmer Vernooij 5 bulan lalu
induk
melakukan
5d9cf13d8f
1 mengubah file dengan 96 tambahan dan 82 penghapusan
  1. 96 82
      dulwich/porcelain.py

+ 96 - 82
dulwich/porcelain.py

@@ -106,7 +106,12 @@ from typing import (
 from . import replace_me
 from . import replace_me
 from .archive import tar_stream
 from .archive import tar_stream
 from .bisect import BisectState
 from .bisect import BisectState
-from .client import get_transport_and_path
+from .client import (
+    FetchPackResult,
+    LsRemoteResult,
+    SendPackResult,
+    get_transport_and_path,
+)
 from .config import Config, ConfigFile, StackedConfig, read_submodules
 from .config import Config, ConfigFile, StackedConfig, read_submodules
 from .diff_tree import (
 from .diff_tree import (
     CHANGE_ADD,
     CHANGE_ADD,
@@ -123,6 +128,7 @@ from .graph import can_fast_forward
 from .ignore import IgnoreFilterManager
 from .ignore import IgnoreFilterManager
 from .index import (
 from .index import (
     ConflictedIndexEntry,
     ConflictedIndexEntry,
+    Index,
     IndexEntry,
     IndexEntry,
     _fs_to_tree_path,
     _fs_to_tree_path,
     blob_from_path_and_stat,
     blob_from_path_and_stat,
@@ -136,7 +142,7 @@ from .index import (
     validate_path_element_hfs,
     validate_path_element_hfs,
     validate_path_element_ntfs,
     validate_path_element_ntfs,
 )
 )
-from .object_store import tree_lookup_path
+from .object_store import BaseObjectStore, tree_lookup_path
 from .objects import (
 from .objects import (
     Blob,
     Blob,
     Commit,
     Commit,
@@ -2100,14 +2106,14 @@ def get_remote_repo(
 
 
 
 
 def push(
 def push(
-    repo,
-    remote_location=None,
-    refspecs=None,
-    outstream=default_bytes_out_stream,
-    errstream=default_bytes_err_stream,
-    force=False,
-    **kwargs,
-):
+    repo: RepoPath,
+    remote_location: Optional[Union[str, bytes]] = None,
+    refspecs: Optional[Union[Union[str, bytes], list[Union[str, bytes]]]] = None,
+    outstream: BinaryIO = default_bytes_out_stream,
+    errstream: BinaryIO = default_bytes_err_stream,
+    force: bool = False,
+    **kwargs: Any,
+) -> SendPackResult:
     """Remote push with dulwich via dulwich.client.
     """Remote push with dulwich via dulwich.client.
 
 
     Args:
     Args:
@@ -2149,7 +2155,7 @@ def push(
         selected_refs = []
         selected_refs = []
         remote_changed_refs = {}
         remote_changed_refs = {}
 
 
-        def update_refs(refs):
+        def update_refs(refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
             selected_refs.extend(parse_reftuples(r.refs, refs, refspecs, force=force))
             selected_refs.extend(parse_reftuples(r.refs, refs, refspecs, force=force))
             new_refs = {}
             new_refs = {}
 
 
@@ -2215,17 +2221,17 @@ def push(
 
 
 
 
 def pull(
 def pull(
-    repo,
-    remote_location=None,
-    refspecs=None,
-    outstream=default_bytes_out_stream,
-    errstream=default_bytes_err_stream,
-    fast_forward=True,
-    ff_only=False,
-    force=False,
-    filter_spec=None,
-    protocol_version=None,
-    **kwargs,
+    repo: RepoPath,
+    remote_location: Optional[Union[str, bytes]] = None,
+    refspecs: Optional[Union[Union[str, bytes], list[Union[str, bytes]]]] = None,
+    outstream: BinaryIO = default_bytes_out_stream,
+    errstream: BinaryIO = default_bytes_err_stream,
+    fast_forward: bool = True,
+    ff_only: bool = False,
+    force: bool = False,
+    filter_spec: Optional[str] = None,
+    protocol_version: Optional[int] = None,
+    **kwargs: Any,
 ) -> None:
 ) -> None:
     """Pull from remote via dulwich.client.
     """Pull from remote via dulwich.client.
 
 
@@ -2257,7 +2263,7 @@ def pull(
         if refspecs is None:
         if refspecs is None:
             refspecs = [b"HEAD"]
             refspecs = [b"HEAD"]
 
 
-        def determine_wants(remote_refs, *args, **kwargs):
+        def determine_wants(remote_refs: dict[bytes, bytes], *args: Any, **kwargs: Any) -> list[bytes]:
             selected_refs.extend(
             selected_refs.extend(
                 parse_reftuples(remote_refs, r.refs, refspecs, force=force)
                 parse_reftuples(remote_refs, r.refs, refspecs, force=force)
             )
             )
@@ -2338,9 +2344,9 @@ def pull(
 
 
 def status(
 def status(
     repo: Union[str, os.PathLike, Repo] = ".",
     repo: Union[str, os.PathLike, Repo] = ".",
-    ignored=False,
-    untracked_files="normal",
-):
+    ignored: bool = False,
+    untracked_files: str = "normal",
+) -> GitStatus:
     """Returns staged, unstaged, and untracked changes relative to the HEAD.
     """Returns staged, unstaged, and untracked changes relative to the HEAD.
 
 
     Args:
     Args:
@@ -2386,7 +2392,11 @@ def status(
         return GitStatus(tracked_changes, unstaged_changes, untracked_changes)
         return GitStatus(tracked_changes, unstaged_changes, untracked_changes)
 
 
 
 
-def _walk_working_dir_paths(frompath, basepath, prune_dirnames=None):
+def _walk_working_dir_paths(
+    frompath: Union[str, bytes, os.PathLike],
+    basepath: Union[str, bytes, os.PathLike],
+    prune_dirnames: Optional[Callable[[str, list[str]], list[str]]] = None,
+) -> Iterator[tuple[str, bool]]:
     """Get path, is_dir for files in working dir from frompath.
     """Get path, is_dir for files in working dir from frompath.
 
 
     Args:
     Args:
@@ -2419,8 +2429,12 @@ def _walk_working_dir_paths(frompath, basepath, prune_dirnames=None):
 
 
 
 
 def get_untracked_paths(
 def get_untracked_paths(
-    frompath, basepath, index, exclude_ignored=False, untracked_files="all"
-):
+    frompath: Union[str, bytes, os.PathLike],
+    basepath: Union[str, bytes, os.PathLike],
+    index: Index,
+    exclude_ignored: bool = False,
+    untracked_files: str = "all",
+) -> Iterator[str]:
     """Get untracked paths.
     """Get untracked paths.
 
 
     Args:
     Args:
@@ -2450,7 +2464,7 @@ def get_untracked_paths(
     # List to store untracked directories found during traversal
     # List to store untracked directories found during traversal
     untracked_dir_list = []
     untracked_dir_list = []
 
 
-    def directory_has_non_ignored_files(dir_path, base_rel_path):
+    def directory_has_non_ignored_files(dir_path: str, base_rel_path: str) -> bool:
         """Recursively check if directory contains any non-ignored files."""
         """Recursively check if directory contains any non-ignored files."""
         try:
         try:
             for entry in os.listdir(dir_path):
             for entry in os.listdir(dir_path):
@@ -2468,7 +2482,7 @@ def get_untracked_paths(
             # If we can't read the directory, assume it has non-ignored files
             # If we can't read the directory, assume it has non-ignored files
             return True
             return True
 
 
-    def prune_dirnames(dirpath, dirnames):
+    def prune_dirnames(dirpath: str, dirnames: list[str]) -> list[str]:
         for i in range(len(dirnames) - 1, -1, -1):
         for i in range(len(dirnames) - 1, -1, -1):
             path = os.path.join(dirpath, dirnames[i])
             path = os.path.join(dirpath, dirnames[i])
             ip = os.path.join(os.path.relpath(path, basepath), "")
             ip = os.path.join(os.path.relpath(path, basepath), "")
@@ -2554,7 +2568,7 @@ def get_untracked_paths(
     yield from ignored_dirs
     yield from ignored_dirs
 
 
 
 
-def get_tree_changes(repo: RepoPath):
+def get_tree_changes(repo: RepoPath) -> dict[str, list[Union[str, bytes]]]:
     """Return add/delete/modify changes to tree by comparing index to HEAD.
     """Return add/delete/modify changes to tree by comparing index to HEAD.
 
 
     Args:
     Args:
@@ -2592,7 +2606,7 @@ def get_tree_changes(repo: RepoPath):
         return tracked_changes
         return tracked_changes
 
 
 
 
-def daemon(path=".", address=None, port=None) -> None:
+def daemon(path: Union[str, os.PathLike] = ".", address: Optional[str] = None, port: Optional[int] = None) -> None:
     """Run a daemon serving Git requests over TCP/IP.
     """Run a daemon serving Git requests over TCP/IP.
 
 
     Args:
     Args:
@@ -2606,7 +2620,7 @@ def daemon(path=".", address=None, port=None) -> None:
     server.serve_forever()
     server.serve_forever()
 
 
 
 
-def web_daemon(path=".", address=None, port=None) -> None:
+def web_daemon(path: Union[str, os.PathLike] = ".", address: Optional[str] = None, port: Optional[int] = None) -> None:
     """Run a daemon serving Git requests over HTTP.
     """Run a daemon serving Git requests over HTTP.
 
 
     Args:
     Args:
@@ -2633,7 +2647,7 @@ def web_daemon(path=".", address=None, port=None) -> None:
     server.serve_forever()
     server.serve_forever()
 
 
 
 
-def upload_pack(path=".", inf=None, outf=None) -> int:
+def upload_pack(path: Union[str, os.PathLike] = ".", inf: Optional[BinaryIO] = None, outf: Optional[BinaryIO] = None) -> int:
     """Upload a pack file after negotiating its contents using smart protocol.
     """Upload a pack file after negotiating its contents using smart protocol.
 
 
     Args:
     Args:
@@ -2648,7 +2662,7 @@ def upload_pack(path=".", inf=None, outf=None) -> int:
     path = os.path.expanduser(path)
     path = os.path.expanduser(path)
     backend = FileSystemBackend(path)
     backend = FileSystemBackend(path)
 
 
-    def send_fn(data) -> None:
+    def send_fn(data: bytes) -> None:
         outf.write(data)
         outf.write(data)
         outf.flush()
         outf.flush()
 
 
@@ -2659,7 +2673,7 @@ def upload_pack(path=".", inf=None, outf=None) -> int:
     return 0
     return 0
 
 
 
 
-def receive_pack(path=".", inf=None, outf=None) -> int:
+def receive_pack(path: Union[str, os.PathLike] = ".", inf: Optional[BinaryIO] = None, outf: Optional[BinaryIO] = None) -> int:
     """Receive a pack file after negotiating its contents using smart protocol.
     """Receive a pack file after negotiating its contents using smart protocol.
 
 
     Args:
     Args:
@@ -2674,7 +2688,7 @@ def receive_pack(path=".", inf=None, outf=None) -> int:
     path = os.path.expanduser(path)
     path = os.path.expanduser(path)
     backend = FileSystemBackend(path)
     backend = FileSystemBackend(path)
 
 
-    def send_fn(data) -> None:
+    def send_fn(data: bytes) -> None:
         outf.write(data)
         outf.write(data)
         outf.flush()
         outf.flush()
 
 
@@ -2697,7 +2711,7 @@ def _make_tag_ref(name: Union[str, bytes]) -> Ref:
     return LOCAL_TAG_PREFIX + name
     return LOCAL_TAG_PREFIX + name
 
 
 
 
-def branch_delete(repo: RepoPath, name) -> None:
+def branch_delete(repo: RepoPath, name: Union[str, bytes, list[Union[str, bytes]]]) -> None:
     """Delete a branch.
     """Delete a branch.
 
 
     Args:
     Args:
@@ -2714,7 +2728,7 @@ def branch_delete(repo: RepoPath, name) -> None:
 
 
 
 
 def branch_create(
 def branch_create(
-    repo: Union[str, os.PathLike, Repo], name, objectish=None, force=False
+    repo: Union[str, os.PathLike, Repo], name: Union[str, bytes], objectish: Optional[Union[str, bytes]] = None, force: bool = False
 ) -> None:
 ) -> None:
     """Create a branch.
     """Create a branch.
 
 
@@ -2810,7 +2824,7 @@ def branch_create(
                     repo_config.write_to_path()
                     repo_config.write_to_path()
 
 
 
 
-def branch_list(repo: RepoPath):
+def branch_list(repo: RepoPath) -> list[bytes]:
     """List all branches.
     """List all branches.
 
 
     Args:
     Args:
@@ -2841,7 +2855,7 @@ def branch_list(repo: RepoPath):
             branches.sort(reverse=reverse)
             branches.sort(reverse=reverse)
         elif sort_key in ("committerdate", "authordate"):
         elif sort_key in ("committerdate", "authordate"):
             # Sort by date
             # Sort by date
-            def get_commit_date(branch_name):
+            def get_commit_date(branch_name: bytes) -> int:
                 ref = LOCAL_BRANCH_PREFIX + branch_name
                 ref = LOCAL_BRANCH_PREFIX + branch_name
                 sha = r.refs[ref]
                 sha = r.refs[ref]
                 commit = r.object_store[sha]
                 commit = r.object_store[sha]
@@ -2866,7 +2880,7 @@ def branch_list(repo: RepoPath):
         return branches
         return branches
 
 
 
 
-def active_branch(repo: RepoPath):
+def active_branch(repo: RepoPath) -> bytes:
     """Return the active branch in the repository, if any.
     """Return the active branch in the repository, if any.
 
 
     Args:
     Args:
@@ -2884,7 +2898,7 @@ def active_branch(repo: RepoPath):
         return active_ref[len(LOCAL_BRANCH_PREFIX) :]
         return active_ref[len(LOCAL_BRANCH_PREFIX) :]
 
 
 
 
-def get_branch_remote(repo: Union[str, os.PathLike, Repo]):
+def get_branch_remote(repo: Union[str, os.PathLike, Repo]) -> bytes:
     """Return the active branch's remote name, if any.
     """Return the active branch's remote name, if any.
 
 
     Args:
     Args:
@@ -2904,7 +2918,7 @@ def get_branch_remote(repo: Union[str, os.PathLike, Repo]):
     return remote_name
     return remote_name
 
 
 
 
-def get_branch_merge(repo: RepoPath, branch_name=None):
+def get_branch_merge(repo: RepoPath, branch_name: Optional[bytes] = None) -> bytes:
     """Return the branch's merge reference (upstream branch), if any.
     """Return the branch's merge reference (upstream branch), if any.
 
 
     Args:
     Args:
@@ -2925,8 +2939,8 @@ def get_branch_merge(repo: RepoPath, branch_name=None):
 
 
 
 
 def set_branch_tracking(
 def set_branch_tracking(
-    repo: Union[str, os.PathLike, Repo], branch_name, remote_name, remote_ref
-):
+    repo: Union[str, os.PathLike, Repo], branch_name: bytes, remote_name: bytes, remote_ref: bytes
+) -> None:
     """Set up branch tracking configuration.
     """Set up branch tracking configuration.
 
 
     Args:
     Args:
@@ -2943,17 +2957,17 @@ def set_branch_tracking(
 
 
 
 
 def fetch(
 def fetch(
-    repo,
-    remote_location=None,
-    outstream=sys.stdout,
-    errstream=default_bytes_err_stream,
-    message=None,
-    depth=None,
-    prune=False,
-    prune_tags=False,
-    force=False,
-    **kwargs,
-):
+    repo: RepoPath,
+    remote_location: Optional[Union[str, bytes]] = None,
+    outstream: TextIO = sys.stdout,
+    errstream: BinaryIO = default_bytes_err_stream,
+    message: Optional[bytes] = None,
+    depth: Optional[int] = None,
+    prune: bool = False,
+    prune_tags: bool = False,
+    force: bool = False,
+    **kwargs: Any,
+) -> FetchPackResult:
     """Fetch objects from a remote server.
     """Fetch objects from a remote server.
 
 
     Args:
     Args:
@@ -3050,7 +3064,7 @@ def for_each_ref(
     return ret
     return ret
 
 
 
 
-def ls_remote(remote, config: Optional[Config] = None, **kwargs):
+def ls_remote(remote: Union[str, bytes], config: Optional[Config] = None, **kwargs: Any) -> LsRemoteResult:
     """List the refs in a remote.
     """List the refs in a remote.
 
 
     Args:
     Args:
@@ -3079,14 +3093,14 @@ def repack(repo: RepoPath) -> None:
 
 
 
 
 def pack_objects(
 def pack_objects(
-    repo,
-    object_ids,
-    packf,
-    idxf,
-    delta_window_size=None,
-    deltify=None,
-    reuse_deltas=True,
-    pack_index_version=None,
+    repo: RepoPath,
+    object_ids: list[bytes],
+    packf: BinaryIO,
+    idxf: Optional[BinaryIO],
+    delta_window_size: Optional[int] = None,
+    deltify: Optional[bool] = None,
+    reuse_deltas: bool = True,
+    pack_index_version: Optional[int] = None,
 ) -> None:
 ) -> None:
     """Pack objects into a file.
     """Pack objects into a file.
 
 
@@ -3116,11 +3130,11 @@ def pack_objects(
 
 
 
 
 def ls_tree(
 def ls_tree(
-    repo,
+    repo: RepoPath,
     treeish: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
     treeish: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
-    outstream=sys.stdout,
-    recursive=False,
-    name_only=False,
+    outstream: TextIO = sys.stdout,
+    recursive: bool = False,
+    name_only: bool = False,
 ) -> None:
 ) -> None:
     """List contents of a tree.
     """List contents of a tree.
 
 
@@ -3132,7 +3146,7 @@ def ls_tree(
       name_only: Only print item name
       name_only: Only print item name
     """
     """
 
 
-    def list_tree(store, treeid, base) -> None:
+    def list_tree(store: BaseObjectStore, treeid: bytes, base: bytes) -> None:
         for name, mode, sha in store[treeid].iteritems():
         for name, mode, sha in store[treeid].iteritems():
             if base:
             if base:
                 name = posixpath.join(base, name)
                 name = posixpath.join(base, name)
@@ -3226,7 +3240,7 @@ def _quote_path(path: str) -> str:
     return quoted
     return quoted
 
 
 
 
-def check_ignore(repo: RepoPath, paths, no_index=False, quote_path=True):
+def check_ignore(repo: RepoPath, paths: list[Union[str, bytes, os.PathLike]], no_index: bool = False, quote_path: bool = True) -> Iterator[str]:
     r"""Debug gitignore files.
     r"""Debug gitignore files.
 
 
     Args:
     Args:
@@ -3278,7 +3292,7 @@ def check_ignore(repo: RepoPath, paths, no_index=False, quote_path=True):
                 yield _quote_path(output_path) if quote_path else output_path
                 yield _quote_path(output_path) if quote_path else output_path
 
 
 
 
-def update_head(repo: RepoPath, target, detached=False, new_branch=None) -> None:
+def update_head(repo: RepoPath, target: Union[str, bytes], detached: bool = False, new_branch: Optional[Union[str, bytes]] = None) -> None:
     """Update HEAD to point at a new branch/commit.
     """Update HEAD to point at a new branch/commit.
 
 
     Note that this does not actually update the working tree.
     Note that this does not actually update the working tree.
@@ -3554,10 +3568,10 @@ def checkout(
 
 
 
 
 def reset_file(
 def reset_file(
-    repo,
+    repo: Repo,
     file_path: str,
     file_path: str,
     target: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
     target: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
-    symlink_fn=None,
+    symlink_fn: Optional[Callable[[bytes, bytes], None]] = None,
 ) -> None:
 ) -> None:
     """Reset the file to specific commit or branch.
     """Reset the file to specific commit or branch.
 
 
@@ -3599,10 +3613,10 @@ def checkout_branch(
 
 
 def sparse_checkout(
 def sparse_checkout(
     repo: Union[str, os.PathLike, Repo],
     repo: Union[str, os.PathLike, Repo],
-    patterns=None,
+    patterns: Optional[list[str]] = None,
     force: bool = False,
     force: bool = False,
-    cone: Union[bool, None] = None,
-):
+    cone: Optional[bool] = None,
+) -> None:
     """Perform a sparse checkout in the repository (either 'full' or 'cone mode').
     """Perform a sparse checkout in the repository (either 'full' or 'cone mode').
 
 
     Perform sparse checkout in either 'cone' (directory-based) mode or
     Perform sparse checkout in either 'cone' (directory-based) mode or
@@ -3794,7 +3808,7 @@ def stash_pop(repo: Union[str, os.PathLike, Repo]) -> None:
         stash.pop(0)
         stash.pop(0)
 
 
 
 
-def stash_drop(repo: Union[str, os.PathLike, Repo], index) -> None:
+def stash_drop(repo: Union[str, os.PathLike, Repo], index: int) -> None:
     """Drop a stash from the stack."""
     """Drop a stash from the stack."""
     with open_repo_closing(repo) as r:
     with open_repo_closing(repo) as r:
         from .stash import Stash
         from .stash import Stash
@@ -3803,13 +3817,13 @@ def stash_drop(repo: Union[str, os.PathLike, Repo], index) -> None:
         stash.drop(index)
         stash.drop(index)
 
 
 
 
-def ls_files(repo: RepoPath):
+def ls_files(repo: RepoPath) -> list[bytes]:
     """List all files in an index."""
     """List all files in an index."""
     with open_repo_closing(repo) as r:
     with open_repo_closing(repo) as r:
         return sorted(r.open_index())
         return sorted(r.open_index())
 
 
 
 
-def find_unique_abbrev(object_store, object_id, min_length=7):
+def find_unique_abbrev(object_store: BaseObjectStore, object_id: Union[str, bytes], min_length: int = 7) -> str:
     """Find the shortest unique abbreviation for an object ID.
     """Find the shortest unique abbreviation for an object ID.
 
 
     Args:
     Args: