Ver código fonte

Allow ready objects to be passed into parse_commit/parse_tree (#1677)

Jelmer Vernooij 1 mês atrás
pai
commit
fe71597b6e
3 arquivos alterados com 151 adições e 40 exclusões
  1. 30 6
      dulwich/objectspec.py
  2. 87 34
      dulwich/porcelain.py
  3. 34 0
      tests/test_objectspec.py

+ 30 - 6
dulwich/objectspec.py

@@ -51,17 +51,29 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
     return repo[objectish]
 
 
-def parse_tree(repo: "Repo", treeish: Union[bytes, str]) -> "Tree":
+def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "Tree":
     """Parse a string referring to a tree.
 
     Args:
       repo: A `Repo` object
-      treeish: A string referring to a tree
-    Returns: A git object
+      treeish: A string referring to a tree, or a Tree, Commit, or Tag object
+    Returns: A Tree object
     Raises:
       KeyError: If the object can not be found
     """
-    treeish = to_bytes(treeish)
+    # If already a Tree, return it directly
+    if isinstance(treeish, Tree):
+        return treeish
+
+    # If it's a Commit, return its tree
+    if isinstance(treeish, Commit):
+        return repo[treeish.tree]
+
+    # For Tag objects or strings, use the existing logic
+    if isinstance(treeish, Tag):
+        treeish = treeish.id
+    else:
+        treeish = to_bytes(treeish)
     try:
         treeish = parse_ref(repo, treeish)
     except KeyError:  # treeish is commit sha
@@ -77,6 +89,10 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str]) -> "Tree":
             raise KeyError(treeish)
     if o.type_name == b"commit":
         return repo[o.tree]
+    elif o.type_name == b"tag":
+        # Tag handling - dereference and recurse
+        obj_type, obj_sha = o.object
+        return parse_tree(repo, obj_sha)
     return o
 
 
@@ -234,12 +250,12 @@ def scan_for_short_id(object_store, prefix, tp):
     raise AmbiguousShortId(prefix, ret)
 
 
-def parse_commit(repo: "Repo", committish: Union[str, bytes]) -> "Commit":
+def parse_commit(repo: "Repo", committish: Union[str, bytes, Commit, Tag]) -> "Commit":
     """Parse a string referring to a single commit.
 
     Args:
       repo: A `Repo` object
-      committish: A string referring to a single commit.
+      committish: A string referring to a single commit, or a Commit or Tag object.
     Returns: A Commit object
     Raises:
       KeyError: When the reference commits can not be found
@@ -259,6 +275,14 @@ def parse_commit(repo: "Repo", committish: Union[str, bytes]) -> "Commit":
             raise ValueError(f"Expected commit, got {obj.type_name}")
         return obj
 
+    # If already a Commit object, return it directly
+    if isinstance(committish, Commit):
+        return committish
+
+    # If it's a Tag object, dereference it
+    if isinstance(committish, Tag):
+        return dereference_tag(committish)
+
     committish = to_bytes(committish)
     try:
         obj = repo[committish]

+ 87 - 34
dulwich/porcelain.py

@@ -116,6 +116,7 @@ from .object_store import tree_lookup_path
 from .objects import (
     Commit,
     Tag,
+    Tree,
     format_timezone,
     parse_timezone,
     pretty_format_tree_entry,
@@ -381,7 +382,7 @@ def check_diverged(repo, current_sha, new_sha) -> None:
 
 def archive(
     repo,
-    committish=None,
+    committish: Optional[Union[str, bytes, Commit, Tag]] = None,
     outstream=default_bytes_out_stream,
     errstream=default_bytes_err_stream,
 ) -> None:
@@ -1662,7 +1663,7 @@ def notes_list(repo, ref=b"commits"):
         return r.notes.list_notes(notes_ref, config=config)
 
 
-def reset(repo, mode, treeish="HEAD") -> None:
+def reset(repo, mode, treeish: Union[str, bytes, Commit, Tree, Tag] = "HEAD") -> None:
     """Reset current HEAD to the specified state.
 
     Args:
@@ -1673,10 +1674,16 @@ def reset(repo, mode, treeish="HEAD") -> None:
     with open_repo_closing(repo) as r:
         # Parse the target tree
         tree = parse_tree(r, treeish)
-        target_commit = parse_commit(r, treeish)
+        # Only parse as commit if treeish is not a Tree object
+        if isinstance(treeish, Tree):
+            # For Tree objects, we can't determine the commit, skip updating HEAD
+            target_commit = None
+        else:
+            target_commit = parse_commit(r, treeish)
 
         # Update HEAD to point to the target commit
-        r.refs[b"HEAD"] = target_commit.id
+        if target_commit is not None:
+            r.refs[b"HEAD"] = target_commit.id
 
         if mode == "soft":
             # Soft reset: only update HEAD, leave index and working tree unchanged
@@ -2762,7 +2769,7 @@ def pack_objects(
 
 def ls_tree(
     repo,
-    treeish=b"HEAD",
+    treeish: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
     outstream=sys.stdout,
     recursive=False,
     name_only=False,
@@ -2949,7 +2956,7 @@ def update_head(repo, target, detached=False, new_branch=None) -> None:
 
 def checkout(
     repo,
-    target: Union[bytes, str],
+    target: Union[str, bytes, Commit, Tag],
     force: bool = False,
     new_branch: Optional[Union[bytes, str]] = None,
 ) -> None:
@@ -2970,13 +2977,21 @@ def checkout(
       KeyError: If the target reference cannot be found
     """
     with open_repo_closing(repo) as r:
+        # Store the original target for later reference checks
+        original_target = target
         if isinstance(target, str):
-            target = target.encode(DEFAULT_ENCODING)
+            target_bytes = target.encode(DEFAULT_ENCODING)
+        elif isinstance(target, bytes):
+            target_bytes = target
+        else:
+            # For Commit/Tag objects, we'll use their SHA
+            target_bytes = target.id
+
         if isinstance(new_branch, str):
             new_branch = new_branch.encode(DEFAULT_ENCODING)
 
         # Parse the target to get the commit
-        target_commit = parse_commit(r, target)
+        target_commit = parse_commit(r, original_target)
         target_tree_id = target_commit.tree
 
         # Get current HEAD tree for comparison
@@ -3064,9 +3079,11 @@ def checkout(
             # Set up tracking if creating from a remote branch
             from .refs import LOCAL_REMOTE_PREFIX, parse_remote_ref
 
-            if target.startswith(LOCAL_REMOTE_PREFIX):
+            if isinstance(original_target, bytes) and target_bytes.startswith(
+                LOCAL_REMOTE_PREFIX
+            ):
                 try:
-                    remote_name, branch_name = parse_remote_ref(target)
+                    remote_name, branch_name = parse_remote_ref(target_bytes)
                     # Set tracking to refs/heads/<branch> on the remote
                     set_branch_tracking(
                         r, new_branch, remote_name, b"refs/heads/" + branch_name
@@ -3077,12 +3094,19 @@ def checkout(
         else:
             # Check if target is a branch name (with or without refs/heads/ prefix)
             branch_ref = None
-            if target in r.refs.keys():
-                if target.startswith(LOCAL_BRANCH_PREFIX):
-                    branch_ref = target
+            if (
+                isinstance(original_target, (str, bytes))
+                and target_bytes in r.refs.keys()
+            ):
+                if target_bytes.startswith(LOCAL_BRANCH_PREFIX):
+                    branch_ref = target_bytes
             else:
                 # Try adding refs/heads/ prefix
-                potential_branch = _make_branch_ref(target)
+                potential_branch = (
+                    _make_branch_ref(target_bytes)
+                    if isinstance(original_target, (str, bytes))
+                    else None
+                )
                 if potential_branch in r.refs.keys():
                     branch_ref = potential_branch
 
@@ -3094,7 +3118,12 @@ def checkout(
                 update_head(r, target_commit.id.decode("ascii"), detached=True)
 
 
-def reset_file(repo, file_path: str, target: bytes = b"HEAD", symlink_fn=None) -> None:
+def reset_file(
+    repo,
+    file_path: str,
+    target: Union[str, bytes, Commit, Tree, Tag] = b"HEAD",
+    symlink_fn=None,
+) -> None:
     """Reset the file to specific commit or branch.
 
     Args:
@@ -3451,7 +3480,9 @@ def describe(repo, abbrev=None):
         return f"g{find_unique_abbrev(r.object_store, latest_commit.id)}"
 
 
-def get_object_by_path(repo, path, committish=None):
+def get_object_by_path(
+    repo, path, committish: Optional[Union[str, bytes, Commit, Tag]] = None
+):
     """Get an object by path.
 
     Args:
@@ -3599,7 +3630,7 @@ def _do_merge(
 
 def merge(
     repo,
-    committish,
+    committish: Union[str, bytes, Commit, Tag],
     no_commit=False,
     no_ff=False,
     message=None,
@@ -3629,7 +3660,9 @@ def merge(
         try:
             merge_commit_id = parse_commit(r, committish).id
         except KeyError:
-            raise Error(f"Cannot find commit '{committish}'")
+            raise Error(
+                f"Cannot find commit '{committish.decode() if isinstance(committish, bytes) else committish}'"
+            )
 
         result = _do_merge(
             r, merge_commit_id, no_commit, no_ff, message, author, committer
@@ -3666,7 +3699,12 @@ def unpack_objects(pack_path, target="."):
             return count
 
 
-def merge_tree(repo, base_tree, our_tree, their_tree):
+def merge_tree(
+    repo,
+    base_tree: Optional[Union[str, bytes, Tree, Commit, Tag]],
+    our_tree: Union[str, bytes, Tree, Commit, Tag],
+    their_tree: Union[str, bytes, Tree, Commit, Tag],
+):
     """Perform a three-way tree merge without touching the working directory.
 
     This is similar to git merge-tree, performing a merge at the tree level
@@ -3706,7 +3744,7 @@ def merge_tree(repo, base_tree, our_tree, their_tree):
 
 def cherry_pick(
     repo,
-    committish,
+    committish: Union[str, bytes, Commit, Tag, None],
     no_commit=False,
     continue_=False,
     abort=False,
@@ -3715,9 +3753,9 @@ def cherry_pick(
 
     Args:
       repo: Repository to cherry-pick into
-      committish: Commit to cherry-pick
+      committish: Commit to cherry-pick (can be None only when ``continue_`` or abort is True)
       no_commit: If True, do not create a commit after applying changes
-      continue\_: Continue an in-progress cherry-pick after resolving conflicts
+      ``continue_``: Continue an in-progress cherry-pick after resolving conflicts
       abort: Abort an in-progress cherry-pick
 
     Returns:
@@ -3728,6 +3766,10 @@ def cherry_pick(
     """
     from .merge import three_way_merge
 
+    # Validate that committish is provided when needed
+    if not (continue_ or abort) and committish is None:
+        raise ValueError("committish is required when not using --continue or --abort")
+
     with open_repo_closing(repo) as r:
         # Handle abort
         if abort:
@@ -3799,10 +3841,14 @@ def cherry_pick(
             raise Error("No HEAD reference found")
 
         # Parse the commit to cherry-pick
+        # committish cannot be None here due to validation above
+        assert committish is not None
         try:
             cherry_pick_commit = parse_commit(r, committish)
         except KeyError:
-            raise Error(f"Cannot find commit '{committish}'")
+            raise Error(
+                f"Cannot find commit '{committish.decode() if isinstance(committish, bytes) else committish}'"
+            )
 
         # Check if commit has parents
         if not cherry_pick_commit.parents:
@@ -3860,7 +3906,7 @@ def cherry_pick(
 
 def revert(
     repo,
-    commits,
+    commits: Union[str, bytes, Commit, Tag, list[Union[str, bytes, Commit, Tag]]],
     no_commit=False,
     message=None,
     author=None,
@@ -3889,7 +3935,7 @@ def revert(
     from .merge import three_way_merge
 
     # Normalize commits to a list
-    if isinstance(commits, (str, bytes)):
+    if isinstance(commits, (str, bytes, Commit, Tag)):
         commits = [commits]
 
     with open_repo_closing(repo) as r:
@@ -3917,13 +3963,13 @@ def revert(
 
             if not commit_to_revert.parents:
                 raise Error(
-                    f"Cannot revert commit {commit_to_revert.id} - it has no parents"
+                    f"Cannot revert commit {commit_to_revert.id.decode() if isinstance(commit_to_revert.id, bytes) else commit_to_revert.id} - it has no parents"
                 )
 
             # For simplicity, we only handle commits with one parent (no merge commits)
             if len(commit_to_revert.parents) > 1:
                 raise Error(
-                    f"Cannot revert merge commit {commit_to_revert.id} - not yet implemented"
+                    f"Cannot revert merge commit {commit_to_revert.id.decode() if isinstance(commit_to_revert.id, bytes) else commit_to_revert.id} - not yet implemented"
                 )
 
             parent_commit = r[commit_to_revert.parents[0]]
@@ -4204,7 +4250,7 @@ def rebase(
             raise Error(str(e))
 
 
-def annotate(repo, path, committish=None):
+def annotate(repo, path, committish: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Annotate the history of a file.
 
     :param repo: Path to the repository
@@ -4358,8 +4404,10 @@ def filter_branch(
 
 def bisect_start(
     repo=".",
-    bad=None,
-    good=None,
+    bad: Optional[Union[str, bytes, Commit, Tag]] = None,
+    good: Optional[
+        Union[str, bytes, Commit, Tag, list[Union[str, bytes, Commit, Tag]]]
+    ] = None,
     paths=None,
     no_checkout=False,
     term_bad="bad",
@@ -4401,7 +4449,7 @@ def bisect_start(
             return next_sha
 
 
-def bisect_bad(repo=".", rev=None):
+def bisect_bad(repo=".", rev: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Mark a commit as bad.
 
     Args:
@@ -4426,7 +4474,7 @@ def bisect_bad(repo=".", rev=None):
         return next_sha
 
 
-def bisect_good(repo=".", rev=None):
+def bisect_good(repo=".", rev: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Mark a commit as good.
 
     Args:
@@ -4451,7 +4499,12 @@ def bisect_good(repo=".", rev=None):
         return next_sha
 
 
-def bisect_skip(repo=".", revs=None):
+def bisect_skip(
+    repo=".",
+    revs: Optional[
+        Union[str, bytes, Commit, Tag, list[Union[str, bytes, Commit, Tag]]]
+    ] = None,
+):
     """Skip one or more commits.
 
     Args:
@@ -4484,7 +4537,7 @@ def bisect_skip(repo=".", revs=None):
         return next_sha
 
 
-def bisect_reset(repo=".", commit=None):
+def bisect_reset(repo=".", commit: Optional[Union[str, bytes, Commit, Tag]] = None):
     """Reset bisect state and return to original branch/commit.
 
     Args:

+ 34 - 0
tests/test_objectspec.py

@@ -160,6 +160,12 @@ class ParseCommitTests(TestCase):
         # Should raise ValueError as it's not a commit
         self.assertRaises(ValueError, parse_commit, r, tag.id)
 
+    def test_commit_object(self) -> None:
+        r = MemoryRepo()
+        [c1] = build_commit_graph(r.object_store, [[1]])
+        # Test that passing a Commit object directly returns the same object
+        self.assertEqual(c1, parse_commit(r, c1))
+
 
 class ParseRefTests(TestCase):
     def test_nonexistent(self) -> None:
@@ -335,3 +341,31 @@ class ParseTreeTests(TestCase):
         c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         r.refs[b"refs/heads/foo"] = c1.id
         self.assertEqual(r[c1.tree], parse_tree(r, b"foo"))
+
+    def test_tree_object(self) -> None:
+        r = MemoryRepo()
+        [c1] = build_commit_graph(r.object_store, [[1]])
+        tree = r[c1.tree]
+        # Test that passing a Tree object directly returns the same object
+        self.assertEqual(tree, parse_tree(r, tree))
+
+    def test_commit_object(self) -> None:
+        r = MemoryRepo()
+        [c1] = build_commit_graph(r.object_store, [[1]])
+        # Test that passing a Commit object returns its tree
+        self.assertEqual(r[c1.tree], parse_tree(r, c1))
+
+    def test_tag_object(self) -> None:
+        r = MemoryRepo()
+        [c1] = build_commit_graph(r.object_store, [[1]])
+        # Create an annotated tag pointing to the commit
+        tag = Tag()
+        tag.name = b"v1.0"
+        tag.message = b"Test tag"
+        tag.tag_time = 1234567890
+        tag.tag_timezone = 0
+        tag.object = (Commit, c1.id)
+        tag.tagger = b"Test Tagger <test@example.com>"
+        r.object_store.add_object(tag)
+        # parse_tree should follow the tag to the commit's tree
+        self.assertEqual(r[c1.tree], parse_tree(r, tag))