|
@@ -1,6 +1,6 @@
|
|
|
"""Git merge implementation."""
|
|
|
|
|
|
-from typing import Optional, cast
|
|
|
+from typing import Optional
|
|
|
|
|
|
try:
|
|
|
import merge3
|
|
@@ -8,13 +8,13 @@ except ImportError:
|
|
|
merge3 = None # type: ignore
|
|
|
|
|
|
from dulwich.object_store import BaseObjectStore
|
|
|
-from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree
|
|
|
+from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree, is_blob, is_tree
|
|
|
|
|
|
|
|
|
class MergeConflict(Exception):
|
|
|
"""Raised when a merge conflict occurs."""
|
|
|
|
|
|
- def __init__(self, path: bytes, message: str):
|
|
|
+ def __init__(self, path: bytes, message: str) -> None:
|
|
|
self.path = path
|
|
|
super().__init__(f"Merge conflict in {path!r}: {message}")
|
|
|
|
|
@@ -183,7 +183,7 @@ def merge_blobs(
|
|
|
class Merger:
|
|
|
"""Handles git merge operations."""
|
|
|
|
|
|
- def __init__(self, object_store: BaseObjectStore):
|
|
|
+ def __init__(self, object_store: BaseObjectStore) -> None:
|
|
|
"""Initialize merger.
|
|
|
|
|
|
Args:
|
|
@@ -341,17 +341,35 @@ class Merger:
|
|
|
merged_entries[path] = (ours_mode, ours_sha)
|
|
|
else:
|
|
|
# Try to merge blobs
|
|
|
- base_blob = (
|
|
|
- cast(Blob, self.object_store[base_sha]) if base_sha else None
|
|
|
- )
|
|
|
- ours_blob = (
|
|
|
- cast(Blob, self.object_store[ours_sha]) if ours_sha else None
|
|
|
- )
|
|
|
- theirs_blob = (
|
|
|
- cast(Blob, self.object_store[theirs_sha])
|
|
|
- if theirs_sha
|
|
|
- else None
|
|
|
- )
|
|
|
+ base_blob = None
|
|
|
+ if base_sha:
|
|
|
+ base_obj = self.object_store[base_sha]
|
|
|
+ if is_blob(base_obj):
|
|
|
+ base_blob = base_obj
|
|
|
+ else:
|
|
|
+ raise TypeError(
|
|
|
+ f"Expected blob for {path!r}, got {base_obj.type_name.decode()}"
|
|
|
+ )
|
|
|
+
|
|
|
+ ours_blob = None
|
|
|
+ if ours_sha:
|
|
|
+ ours_obj = self.object_store[ours_sha]
|
|
|
+ if is_blob(ours_obj):
|
|
|
+ ours_blob = ours_obj
|
|
|
+ else:
|
|
|
+ raise TypeError(
|
|
|
+ f"Expected blob for {path!r}, got {ours_obj.type_name.decode()}"
|
|
|
+ )
|
|
|
+
|
|
|
+ theirs_blob = None
|
|
|
+ if theirs_sha:
|
|
|
+ theirs_obj = self.object_store[theirs_sha]
|
|
|
+ if is_blob(theirs_obj):
|
|
|
+ theirs_blob = theirs_obj
|
|
|
+ else:
|
|
|
+ raise TypeError(
|
|
|
+ f"Expected blob for {path!r}, got {theirs_obj.type_name.decode()}"
|
|
|
+ )
|
|
|
|
|
|
merged_content, had_conflict = self.merge_blobs(
|
|
|
base_blob, ours_blob, theirs_blob
|
|
@@ -368,7 +386,8 @@ class Merger:
|
|
|
# Build merged tree
|
|
|
merged_tree = Tree()
|
|
|
for path, (mode, sha) in sorted(merged_entries.items()):
|
|
|
- merged_tree.add(path, mode, sha)
|
|
|
+ if mode is not None and sha is not None:
|
|
|
+ merged_tree.add(path, mode, sha)
|
|
|
|
|
|
return merged_tree, conflicts
|
|
|
|
|
@@ -392,8 +411,24 @@ def three_way_merge(
|
|
|
"""
|
|
|
merger = Merger(object_store)
|
|
|
|
|
|
- base_tree = cast(Tree, object_store[base_commit.tree]) if base_commit else None
|
|
|
- ours_tree = cast(Tree, object_store[ours_commit.tree])
|
|
|
- theirs_tree = cast(Tree, object_store[theirs_commit.tree])
|
|
|
+ base_tree = None
|
|
|
+ if base_commit:
|
|
|
+ base_obj = object_store[base_commit.tree]
|
|
|
+ if is_tree(base_obj):
|
|
|
+ base_tree = base_obj
|
|
|
+ else:
|
|
|
+ raise TypeError(f"Expected tree, got {base_obj.type_name.decode()}")
|
|
|
+
|
|
|
+ ours_obj = object_store[ours_commit.tree]
|
|
|
+ if is_tree(ours_obj):
|
|
|
+ ours_tree = ours_obj
|
|
|
+ else:
|
|
|
+ raise TypeError(f"Expected tree, got {ours_obj.type_name.decode()}")
|
|
|
+
|
|
|
+ theirs_obj = object_store[theirs_commit.tree]
|
|
|
+ if is_tree(theirs_obj):
|
|
|
+ theirs_tree = theirs_obj
|
|
|
+ else:
|
|
|
+ raise TypeError(f"Expected tree, got {theirs_obj.type_name.decode()}")
|
|
|
|
|
|
return merger.merge_trees(base_tree, ours_tree, theirs_tree)
|