瀏覽代碼

Allow Tree, Commit, and Tag objects to be parsed by parse_tree()

Jelmer Vernooij 1 月之前
父節點
當前提交
31c0076f11
共有 2 個文件被更改,包括 48 次插入4 次删除
  1. 20 4
      dulwich/objectspec.py
  2. 28 0
      tests/test_objectspec.py

+ 20 - 4
dulwich/objectspec.py

@@ -51,17 +51,29 @@ def parse_object(repo: "Repo", objectish: Union[bytes, str]) -> "ShaFile":
     return repo[objectish]
 
 
-def parse_tree(repo: "Repo", treeish: Union[bytes, str]) -> "Tree":
+def parse_tree(repo: "Repo", treeish: Union[bytes, str, Tree, Commit, Tag]) -> "Tree":
     """Parse a string referring to a tree.
 
     Args:
       repo: A `Repo` object
-      treeish: A string referring to a tree
-    Returns: A git object
+      treeish: A string referring to a tree, or a Tree, Commit, or Tag object
+    Returns: A Tree object
     Raises:
       KeyError: If the object can not be found
     """
-    treeish = to_bytes(treeish)
+    # If already a Tree, return it directly
+    if isinstance(treeish, Tree):
+        return treeish
+    
+    # If it's a Commit, return its tree
+    if isinstance(treeish, Commit):
+        return repo[treeish.tree]
+    
+    # For Tag objects or strings, use the existing logic
+    if isinstance(treeish, Tag):
+        treeish = treeish.id
+    else:
+        treeish = to_bytes(treeish)
     try:
         treeish = parse_ref(repo, treeish)
     except KeyError:  # treeish is commit sha
@@ -77,6 +89,10 @@ def parse_tree(repo: "Repo", treeish: Union[bytes, str]) -> "Tree":
             raise KeyError(treeish)
     if o.type_name == b"commit":
         return repo[o.tree]
+    elif o.type_name == b"tag":
+        # Tag handling - dereference and recurse
+        obj_type, obj_sha = o.object
+        return parse_tree(repo, obj_sha)
     return o
 
 

+ 28 - 0
tests/test_objectspec.py

@@ -341,3 +341,31 @@ class ParseTreeTests(TestCase):
         c1, c2, c3 = build_commit_graph(r.object_store, [[1], [2, 1], [3, 1, 2]])
         r.refs[b"refs/heads/foo"] = c1.id
         self.assertEqual(r[c1.tree], parse_tree(r, b"foo"))
+
+    def test_tree_object(self) -> None:
+        r = MemoryRepo()
+        [c1] = build_commit_graph(r.object_store, [[1]])
+        tree = r[c1.tree]
+        # Test that passing a Tree object directly returns the same object
+        self.assertEqual(tree, parse_tree(r, tree))
+
+    def test_commit_object(self) -> None:
+        r = MemoryRepo()
+        [c1] = build_commit_graph(r.object_store, [[1]])
+        # Test that passing a Commit object returns its tree
+        self.assertEqual(r[c1.tree], parse_tree(r, c1))
+
+    def test_tag_object(self) -> None:
+        r = MemoryRepo()
+        [c1] = build_commit_graph(r.object_store, [[1]])
+        # Create an annotated tag pointing to the commit
+        tag = Tag()
+        tag.name = b"v1.0"
+        tag.message = b"Test tag"
+        tag.tag_time = 1234567890
+        tag.tag_timezone = 0
+        tag.object = (Commit, c1.id)
+        tag.tagger = b"Test Tagger <test@example.com>"
+        r.object_store.add_object(tag)
+        # parse_tree should follow the tag to the commit's tree
+        self.assertEqual(r[c1.tree], parse_tree(r, tag))