Răsfoiți Sursa

merge: Add typing

Jelmer Vernooij 3 luni în urmă
părinte
comite
67b77e31eb
1 a modificat fișierele cu 54 adăugiri și 19 ștergeri
  1. 54 19
      dulwich/merge.py

+ 54 - 19
dulwich/merge.py

@@ -1,6 +1,6 @@
 """Git merge implementation."""
 
-from typing import Optional, cast
+from typing import Optional
 
 try:
     import merge3
@@ -8,13 +8,13 @@ except ImportError:
     merge3 = None  # type: ignore
 
 from dulwich.object_store import BaseObjectStore
-from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree
+from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree, is_blob, is_tree
 
 
 class MergeConflict(Exception):
     """Raised when a merge conflict occurs."""
 
-    def __init__(self, path: bytes, message: str):
+    def __init__(self, path: bytes, message: str) -> None:
         self.path = path
         super().__init__(f"Merge conflict in {path!r}: {message}")
 
@@ -183,7 +183,7 @@ def merge_blobs(
 class Merger:
     """Handles git merge operations."""
 
-    def __init__(self, object_store: BaseObjectStore):
+    def __init__(self, object_store: BaseObjectStore) -> None:
         """Initialize merger.
 
         Args:
@@ -341,17 +341,35 @@ class Merger:
                     merged_entries[path] = (ours_mode, ours_sha)
                 else:
                     # Try to merge blobs
-                    base_blob = (
-                        cast(Blob, self.object_store[base_sha]) if base_sha else None
-                    )
-                    ours_blob = (
-                        cast(Blob, self.object_store[ours_sha]) if ours_sha else None
-                    )
-                    theirs_blob = (
-                        cast(Blob, self.object_store[theirs_sha])
-                        if theirs_sha
-                        else None
-                    )
+                    base_blob = None
+                    if base_sha:
+                        base_obj = self.object_store[base_sha]
+                        if is_blob(base_obj):
+                            base_blob = base_obj
+                        else:
+                            raise TypeError(
+                                f"Expected blob for {path!r}, got {base_obj.type_name.decode()}"
+                            )
+
+                    ours_blob = None
+                    if ours_sha:
+                        ours_obj = self.object_store[ours_sha]
+                        if is_blob(ours_obj):
+                            ours_blob = ours_obj
+                        else:
+                            raise TypeError(
+                                f"Expected blob for {path!r}, got {ours_obj.type_name.decode()}"
+                            )
+
+                    theirs_blob = None
+                    if theirs_sha:
+                        theirs_obj = self.object_store[theirs_sha]
+                        if is_blob(theirs_obj):
+                            theirs_blob = theirs_obj
+                        else:
+                            raise TypeError(
+                                f"Expected blob for {path!r}, got {theirs_obj.type_name.decode()}"
+                            )
 
                     merged_content, had_conflict = self.merge_blobs(
                         base_blob, ours_blob, theirs_blob
@@ -368,7 +386,8 @@ class Merger:
         # Build merged tree
         merged_tree = Tree()
         for path, (mode, sha) in sorted(merged_entries.items()):
-            merged_tree.add(path, mode, sha)
+            if mode is not None and sha is not None:
+                merged_tree.add(path, mode, sha)
 
         return merged_tree, conflicts
 
@@ -392,8 +411,24 @@ def three_way_merge(
     """
     merger = Merger(object_store)
 
-    base_tree = cast(Tree, object_store[base_commit.tree]) if base_commit else None
-    ours_tree = cast(Tree, object_store[ours_commit.tree])
-    theirs_tree = cast(Tree, object_store[theirs_commit.tree])
+    base_tree = None
+    if base_commit:
+        base_obj = object_store[base_commit.tree]
+        if is_tree(base_obj):
+            base_tree = base_obj
+        else:
+            raise TypeError(f"Expected tree, got {base_obj.type_name.decode()}")
+
+    ours_obj = object_store[ours_commit.tree]
+    if is_tree(ours_obj):
+        ours_tree = ours_obj
+    else:
+        raise TypeError(f"Expected tree, got {ours_obj.type_name.decode()}")
+
+    theirs_obj = object_store[theirs_commit.tree]
+    if is_tree(theirs_obj):
+        theirs_tree = theirs_obj
+    else:
+        raise TypeError(f"Expected tree, got {theirs_obj.type_name.decode()}")
 
     return merger.merge_trees(base_tree, ours_tree, theirs_tree)