Преглед изворни кода

Update porcelain function type annotations to accept object instances

Now that parse_commit() and parse_tree() accept Commit/Tree/Tag objects,
update the type annotations of porcelain functions to reflect this.
Jelmer Vernooij пре 1 месец
родитељ
комит
8ac1e1c5f0
2 измењених фајлова са 95 додато и 38 уклоњено
  1. 8 4
      dulwich/objectspec.py
  2. 87 34
      dulwich/porcelain.py

+ 8 - 4
dulwich/objectspec.py

@@ -64,11 +64,11 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "
     # If already a Tree, return it directly
     if isinstance(treeish, Tree):
         return treeish
-    
+
     # If it's a Commit, return its tree
     if isinstance(treeish, Commit):
         return repo[treeish.tree]
-    
+
     # For Tag objects or strings, use the existing logic
     if isinstance(treeish, Tag):
         treeish = treeish.id
@@ -250,12 +250,12 @@ def scan_for_short_id(object_store, prefix, tp):
     raise AmbiguousShortId(prefix, ret)
 
 
-def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit]) -> "Commit":
+def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "Commit":
     """Parse a string referring to a single commit.
 
     Args:
       repo: A` Repo` object
-      committish: A string referring to a single commit, or a Commit object.
+      committish: A string referring to a single commit, or a Commit or Tag object.
     Returns: A Commit object
     Raises:
       KeyError: When the reference commits can not be found
@@ -279,6 +279,10 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit]) -> "Commit
     if isinstance(committish, Commit):
         return committish
 
+    # If it's a Tag object, dereference it
+    if isinstance(committish, Tag):
+        return dereference_tag(committish)
+
     committish = to_bytes(committish)
     try:
         obj = repo[committish]

+ 87 - 34
dulwich/porcelain.py

@@ -116,6 +116,7 @@ from .object_store import tree_lookup_path
 from .objects import (
     Commit,
     Tag,
+    Tree,
     format_timezone,
     parse_timezone,
     pretty_format_tree_entry,
@@ -381,7 +382,7 @@ def check_diverged(repo, current_sha, new_sha) -> None:
 
 def archive(
     repo,
-    committish=None,
+    committish: Optional[Union[str, bytes, Commit, Tag]] = None,
     outstream=default_bytes_out_stream,
     errstream=default_bytes_err_stream,
 ) -> None:
@@ -1662,7 +1663,7 @@ def notes_list(repo, ref=b"commits"):
         return r.notes.list_notes(notes_ref, config=config)
 
 
-def reset(repo, mode, treeish="HEAD") -> None:
+def reset(repo, mode, treeish: Union[str, bytes, Commit, Tree, Tag] = "HEAD") -> None:
     """Reset current HEAD to the specified state.
 
     Args:
@@ -1673,10 +1674,16 @@ def reset(repo, mode, treeish="HEAD") -> None:
     with open_repo_closing(repo) as r:
         # Parse the target tree
         tree = parse_tree(r, treeish)
-        target_commit = parse_commit(r, treeish)
+        # Only parse as commit if treeish is not a Tree object
+        if isinstance(treeish, Tree):
+            # For Tree objects, we can't determine the commit, skip updating HEAD
+            target_commit = None
+        else:
+            target_commit = parse_commit(r, treeish)
 
         # Update HEAD to point to the target commit
-        r.refs[b"HEAD"] = target_commit.id
+        if target_commit is not None:
+            r.refs[b"HEAD"] = target_commit.id
 
         if mode == "soft":
             # Soft reset: only update HEAD, leave index and working tree unchanged
@@ -2762,7 +2769,7 @@ def pack_objects(
 
 def ls_tree(
     repo,
-    treeish=b"HEAD",
+    treeish: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
     outstream=sys.stdout,
     recursive=False,
     name_only=False,
@@ -2949,7 +2956,7 @@ def update_head(repo, target, detached=False, new_branch=None) -> None:
 
 def checkout(
     repo,
-    target: Union[bytes, str],
+    target: Union[str, bytes, Commit, Tag],
     force: bool = False,
     new_branch: Optional[Union[bytes, str]] = None,
 ) -> None:
@@ -2970,13 +2977,21 @@ def checkout(
       KeyError: If the target reference cannot be found
     """
     with open_repo_closing(repo) as r:
+        # Store the original target for later reference checks
+        original_target = target
         if isinstance(target, str):
-            target = target.encode(DEFAULT_ENCODING)
+            target_bytes = target.encode(DEFAULT_ENCODING)
+        elif isinstance(target, bytes):
+            target_bytes = target
+        else:
+            # For Commit/Tag objects, we'll use their SHA
+            target_bytes = target.id
+
         if isinstance(new_branch, str):
             new_branch = new_branch.encode(DEFAULT_ENCODING)
 
         # Parse the target to get the commit
-        target_commit = parse_commit(r, target)
+        target_commit = parse_commit(r, original_target)
         target_tree_id = target_commit.tree
 
         # Get current HEAD tree for comparison
@@ -3064,9 +3079,11 @@ def checkout(
             # Set up tracking if creating from a remote branch
             from .refs import LOCAL_REMOTE_PREFIX, parse_remote_ref
 
-            if target.startswith(LOCAL_REMOTE_PREFIX):
+            if isinstance(original_target, bytes) and target_bytes.startswith(
+                LOCAL_REMOTE_PREFIX
+            ):
                 try:
-                    remote_name, branch_name = parse_remote_ref(target)
+                    remote_name, branch_name = parse_remote_ref(target_bytes)
                     # Set tracking to refs/heads/<branch> on the remote
                     set_branch_tracking(
                         r, new_branch, remote_name, b"refs/heads/" + branch_name
@@ -3077,12 +3094,19 @@ def checkout(
         else:
             # Check if target is a branch name (with or without refs/heads/ prefix)
             branch_ref = None
-            if target in r.refs.keys():
-                if target.startswith(LOCAL_BRANCH_PREFIX):
-                    branch_ref = target
+            if (
+                isinstance(original_target, (str, bytes))
+                and target_bytes in r.refs.keys()
+            ):
+                if target_bytes.startswith(LOCAL_BRANCH_PREFIX):
+                    branch_ref = target_bytes
             else:
                 # Try adding refs/heads/ prefix
-                potential_branch = _make_branch_ref(target)
+                potential_branch = (
+                    _make_branch_ref(target_bytes)
+                    if isinstance(original_target, (str, bytes))
+                    else None
+                )
                 if potential_branch in r.refs.keys():
                     branch_ref = potential_branch
 
@@ -3094,7 +3118,12 @@ def checkout(
                 update_head(r, target_commit.id.decode("ascii"), detached=True)
 
 
-def reset_file(repo, file_path: str, target: bytes = b"HEAD", symlink_fn=None) -> None:
+def reset_file(
+    repo,
+    file_path: str,
+    target: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
+    symlink_fn=None,
+) -> None:
     """Reset the file to specific commit or branch.
 
     Args:
@@ -3451,7 +3480,9 @@ def describe(repo, abbrev=None):
         return f"g{find_unique_abbrev(r.object_store, latest_commit.id)}"
 
 
-def get_object_by_path(repo, path, committish=None):
+def get_object_by_path(
+    repo, path, committish: Optional[Union[str, bytes, Commit, Tag]] = None
+):
     """Get an object by path.
 
     Args:
@@ -3599,7 +3630,7 @@ def _do_merge(
 
 def merge(
     repo,
-    committish,
+    committish: Union[str, bytes, Commit, Tag],
     no_commit=False,
     no_ff=False,
     message=None,
@@ -3629,7 +3660,9 @@ def merge(
         try:
             merge_commit_id = parse_commit(r, committish).id
         except KeyError:
-            raise Error(f"Cannot find commit '{committish}'")
+            raise Error(
+                f"Cannot find commit '{committish.decode() if isinstance(committish, bytes) else committish}'"
+            )
 
         result = _do_merge(
             r, merge_commit_id, no_commit, no_ff, message, author, committer
@@ -3666,7 +3699,12 @@ def unpack_objects(pack_path, target="."):
             return count
 
 
-def merge_tree(repo, base_tree, our_tree, their_tree):
+def merge_tree(
+    repo,
+    base_tree: Optional[Union[str, bytes, Tree, Commit, Tag]],
+    our_tree: Union[str, bytes, Tree, Commit, Tag],
+    their_tree: Union[str, bytes, Tree, Commit, Tag],
+):
     """Perform a three-way tree merge without touching the working directory.
 
     This is similar to git merge-tree, performing a merge at the tree level
@@ -3706,7 +3744,7 @@ def merge_tree(repo, base_tree, our_tree, their_tree):
 
 def cherry_pick(
     repo,
-    committish,
+    committish: Union[str, bytes, Commit, Tag, None],
     no_commit=False,
     continue_=False,
     abort=False,
@@ -3715,9 +3753,9 @@ def cherry_pick(
 
     Args:
       repo: Repository to cherry-pick into
-      committish: Commit to cherry-pick
+      committish: Commit to cherry-pick (can be None only when ``continue_`` or abort is True)
       no_commit: If True, do not create a commit after applying changes
-      continue\_: Continue an in-progress cherry-pick after resolving conflicts
+      ``continue_``: Continue an in-progress cherry-pick after resolving conflicts
       abort: Abort an in-progress cherry-pick
 
     Returns:
@@ -3728,6 +3766,10 @@ def cherry_pick(
     """
     from .merge import three_way_merge
 
+    # Validate that committish is provided when needed
+    if not (continue_ or abort) and committish is None:
+        raise ValueError("committish is required when not using --continue or --abort")
+
     with open_repo_closing(repo) as r:
         # Handle abort
         if abort:
@@ -3799,10 +3841,14 @@ def cherry_pick(
             raise Error("No HEAD reference found")
 
         # Parse the commit to cherry-pick
+        # committish cannot be None here due to validation above
+        assert committish is not None
         try:
             cherry_pick_commit = parse_commit(r, committish)
         except KeyError:
-            raise Error(f"Cannot find commit '{committish}'")
+            raise Error(
+                f"Cannot find commit '{committish.decode() if isinstance(committish, bytes) else committish}'"
+            )
 
         # Check if commit has parents
         if not cherry_pick_commit.parents:
@@ -3860,7 +3906,7 @@ def cherry_pick(
 
 def revert(
     repo,
-    commits,
+    commits: Union[str, bytes, Commit, Tag, list[Union[str, bytes, Commit, Tag]]],
     no_commit=False,
     message=None,
     author=None,
@@ -3889,7 +3935,7 @@ def revert(
     from .merge import three_way_merge
 
     # Normalize commits to a list
-    if isinstance(commits, (str, bytes)):
+    if isinstance(commits, (str, bytes, Commit, Tag)):
         commits = [commits]
 
     with open_repo_closing(repo) as r:
@@ -3917,13 +3963,13 @@ def revert(
 
             if not commit_to_revert.parents:
                 raise Error(
-                    f"Cannot revert commit {commit_to_revert.id} - it has no parents"
+                    f"Cannot revert commit {commit_to_revert.id.decode() if isinstance(commit_to_revert.id, bytes) else commit_to_revert.id} - it has no parents"
                 )
 
             # For simplicity, we only handle commits with one parent (no merge commits)
             if len(commit_to_revert.parents) > 1:
                 raise Error(
-                    f"Cannot revert merge commit {commit_to_revert.id} - not yet implemented"
+                    f"Cannot revert merge commit {commit_to_revert.id.decode() if isinstance(commit_to_revert.id, bytes) else commit_to_revert.id} - not yet implemented"
                 )
 
             parent_commit = r[commit_to_revert.parents[0]]
@@ -4204,7 +4250,7 @@ def rebase(
             raise Error(str(e))
 
 
-def annotate(repo, path, committish=None):
+def annotate(repo, path, committish: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Annotate the history of a file.
 
     :param repo: Path to the repository
@@ -4358,8 +4404,10 @@ def filter_branch(
 
 def bisect_start(
     repo=".",
-    bad=None,
-    good=None,
+    bad: Optional[Union[str, bytes, Commit, Tag]] = None,
+    good: Optional[
+        Union[str, bytes, Commit, Tag, list[Union[str, bytes, Commit, Tag]]]
+    ] = None,
     paths=None,
     no_checkout=False,
     term_bad="bad",
@@ -4401,7 +4449,7 @@ def bisect_start(
             return next_sha
 
 
-def bisect_bad(repo=".", rev=None):
+def bisect_bad(repo=".", rev: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Mark a commit as bad.
 
     Args:
@@ -4426,7 +4474,7 @@ def bisect_bad(repo=".", rev=None):
         return next_sha
 
 
-def bisect_good(repo=".", rev=None):
+def bisect_good(repo=".", rev: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Mark a commit as good.
 
     Args:
@@ -4451,7 +4499,12 @@ def bisect_good(repo=".", rev=None):
         return next_sha
 
 
-def bisect_skip(repo=".", revs=None):
+def bisect_skip(
+    repo=".",
+    revs: Optional[
+        Union[str, bytes, Commit, Tag, list[Union[str, bytes, Commit, Tag]]]
+    ] = None,
+):
     """Skip one or more commits.
 
     Args:
@@ -4484,7 +4537,7 @@ def bisect_skip(repo=".", revs=None):
         return next_sha
 
 
-def bisect_reset(repo=".", commit=None):
+def bisect_reset(repo=".", commit: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Reset bisect state and return to original branch/commit.
 
     Args: