merge.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. """Git merge implementation."""
import stat
from typing import Optional, cast

try:
    import merge3
except ImportError:
    merge3 = None

from dulwich.object_store import BaseObjectStore
from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree
  9. class MergeConflict(Exception):
  10. """Raised when a merge conflict occurs."""
  11. def __init__(self, path: bytes, message: str):
  12. self.path = path
  13. super().__init__(f"Merge conflict in {path!r}: {message}")
  14. class Merger:
  15. """Handles git merge operations."""
  16. def __init__(self, object_store: BaseObjectStore):
  17. """Initialize merger.
  18. Args:
  19. object_store: Object store to read objects from
  20. """
  21. self.object_store = object_store
  22. def merge_blobs(
  23. self,
  24. base_blob: Optional[Blob],
  25. ours_blob: Optional[Blob],
  26. theirs_blob: Optional[Blob],
  27. ) -> tuple[bytes, bool]:
  28. """Perform three-way merge on blob contents.
  29. Args:
  30. base_blob: Common ancestor blob (can be None)
  31. ours_blob: Our version of the blob (can be None)
  32. theirs_blob: Their version of the blob (can be None)
  33. Returns:
  34. Tuple of (merged_content, had_conflicts)
  35. """
  36. if merge3 is None:
  37. raise ImportError(
  38. "merge3 is required for merging. Install with: pip install dulwich[merge]"
  39. )
  40. # Handle deletion cases
  41. if ours_blob is None and theirs_blob is None:
  42. return b"", False
  43. if base_blob is None:
  44. # No common ancestor
  45. if ours_blob is None:
  46. assert theirs_blob is not None
  47. return theirs_blob.data, False
  48. elif theirs_blob is None:
  49. return ours_blob.data, False
  50. elif ours_blob.data == theirs_blob.data:
  51. return ours_blob.data, False
  52. else:
  53. # Both added different content - conflict
  54. m = merge3.Merge3(
  55. [],
  56. ours_blob.data.splitlines(True),
  57. theirs_blob.data.splitlines(True),
  58. )
  59. return self._merge3_to_bytes(m), True
  60. # Get content for each version
  61. base_content = base_blob.data if base_blob else b""
  62. ours_content = ours_blob.data if ours_blob else b""
  63. theirs_content = theirs_blob.data if theirs_blob else b""
  64. # Check if either side deleted
  65. if ours_blob is None or theirs_blob is None:
  66. if ours_blob is None and theirs_blob is None:
  67. return b"", False
  68. elif ours_blob is None:
  69. # We deleted, check if they modified
  70. if base_content == theirs_content:
  71. return b"", False # They didn't modify, accept deletion
  72. else:
  73. # Conflict: we deleted, they modified
  74. m = merge3.Merge3(
  75. base_content.splitlines(True),
  76. [],
  77. theirs_content.splitlines(True),
  78. )
  79. return self._merge3_to_bytes(m), True
  80. else:
  81. # They deleted, check if we modified
  82. if base_content == ours_content:
  83. return b"", False # We didn't modify, accept deletion
  84. else:
  85. # Conflict: they deleted, we modified
  86. m = merge3.Merge3(
  87. base_content.splitlines(True),
  88. ours_content.splitlines(True),
  89. [],
  90. )
  91. return self._merge3_to_bytes(m), True
  92. # Both sides exist, check if merge is needed
  93. if ours_content == theirs_content:
  94. return ours_content, False
  95. elif base_content == ours_content:
  96. return theirs_content, False
  97. elif base_content == theirs_content:
  98. return ours_content, False
  99. # Perform three-way merge
  100. m = merge3.Merge3(
  101. base_content.splitlines(True),
  102. ours_content.splitlines(True),
  103. theirs_content.splitlines(True),
  104. )
  105. # Check for conflicts and generate merged content
  106. merged_content = self._merge3_to_bytes(m)
  107. has_conflicts = b"<<<<<<< ours" in merged_content
  108. return merged_content, has_conflicts
  109. def _merge3_to_bytes(self, m: merge3.Merge3) -> bytes:
  110. """Convert merge3 result to bytes with conflict markers.
  111. Args:
  112. m: Merge3 object
  113. Returns:
  114. Merged content as bytes
  115. """
  116. result = []
  117. for group in m.merge_groups():
  118. if group[0] == "unchanged":
  119. result.extend(group[1])
  120. elif group[0] == "a":
  121. result.extend(group[1])
  122. elif group[0] == "b":
  123. result.extend(group[1])
  124. elif group[0] == "same":
  125. result.extend(group[1])
  126. elif group[0] == "conflict":
  127. # Check if this is a real conflict or just different changes
  128. base_lines, a_lines, b_lines = group[1], group[2], group[3]
  129. # Try to merge line by line
  130. if self._can_merge_lines(base_lines, a_lines, b_lines):
  131. merged_lines = self._merge_lines(base_lines, a_lines, b_lines)
  132. result.extend(merged_lines)
  133. else:
  134. # Real conflict - add conflict markers
  135. result.append(b"<<<<<<< ours\n")
  136. result.extend(a_lines)
  137. result.append(b"=======\n")
  138. result.extend(b_lines)
  139. result.append(b">>>>>>> theirs\n")
  140. return b"".join(result)
  141. def _can_merge_lines(
  142. self, base_lines: list[bytes], a_lines: list[bytes], b_lines: list[bytes]
  143. ) -> bool:
  144. """Check if lines can be merged without conflict."""
  145. # If one side is unchanged, we can take the other side
  146. if base_lines == a_lines:
  147. return True
  148. elif base_lines == b_lines:
  149. return True
  150. else:
  151. # For now, treat any difference as a conflict
  152. # A more sophisticated algorithm would check for non-overlapping changes
  153. return False
  154. def _merge_lines(
  155. self, base_lines: list[bytes], a_lines: list[bytes], b_lines: list[bytes]
  156. ) -> list[bytes]:
  157. """Merge lines when possible."""
  158. if base_lines == a_lines:
  159. return b_lines
  160. elif base_lines == b_lines:
  161. return a_lines
  162. else:
  163. # This shouldn't happen if _can_merge_lines returned True
  164. return a_lines
  165. def merge_trees(
  166. self, base_tree: Optional[Tree], ours_tree: Tree, theirs_tree: Tree
  167. ) -> tuple[Tree, list[bytes]]:
  168. """Perform three-way merge on trees.
  169. Args:
  170. base_tree: Common ancestor tree (can be None for no common ancestor)
  171. ours_tree: Our version of the tree
  172. theirs_tree: Their version of the tree
  173. Returns:
  174. tuple of (merged_tree, list_of_conflicted_paths)
  175. """
  176. conflicts = []
  177. merged_entries = {}
  178. # Get all paths from all trees
  179. all_paths = set()
  180. if base_tree:
  181. for entry in base_tree.items():
  182. all_paths.add(entry.path)
  183. for entry in ours_tree.items():
  184. all_paths.add(entry.path)
  185. for entry in theirs_tree.items():
  186. all_paths.add(entry.path)
  187. # Process each path
  188. for path in sorted(all_paths):
  189. base_entry = None
  190. if base_tree:
  191. try:
  192. base_entry = base_tree.lookup_path(
  193. self.object_store.__getitem__, path
  194. )
  195. except KeyError:
  196. pass
  197. try:
  198. ours_entry = ours_tree.lookup_path(self.object_store.__getitem__, path)
  199. except KeyError:
  200. ours_entry = None
  201. try:
  202. theirs_entry = theirs_tree.lookup_path(
  203. self.object_store.__getitem__, path
  204. )
  205. except KeyError:
  206. theirs_entry = None
  207. # Extract mode and sha
  208. base_mode, base_sha = base_entry if base_entry else (None, None)
  209. ours_mode, ours_sha = ours_entry if ours_entry else (None, None)
  210. theirs_mode, theirs_sha = theirs_entry if theirs_entry else (None, None)
  211. # Handle deletions
  212. if ours_sha is None and theirs_sha is None:
  213. continue # Deleted in both
  214. # Handle additions
  215. if base_sha is None:
  216. if ours_sha == theirs_sha and ours_mode == theirs_mode:
  217. # Same addition in both
  218. merged_entries[path] = (ours_mode, ours_sha)
  219. elif ours_sha is None:
  220. # Added only in theirs
  221. merged_entries[path] = (theirs_mode, theirs_sha)
  222. elif theirs_sha is None:
  223. # Added only in ours
  224. merged_entries[path] = (ours_mode, ours_sha)
  225. else:
  226. # Different additions - conflict
  227. conflicts.append(path)
  228. # For now, keep ours
  229. merged_entries[path] = (ours_mode, ours_sha)
  230. continue
  231. # Check for mode conflicts
  232. if (
  233. ours_mode != theirs_mode
  234. and ours_mode is not None
  235. and theirs_mode is not None
  236. ):
  237. conflicts.append(path)
  238. # For now, keep ours
  239. merged_entries[path] = (ours_mode, ours_sha)
  240. continue
  241. # Handle modifications
  242. if ours_sha == theirs_sha:
  243. # Same modification or no change
  244. if ours_sha is not None:
  245. merged_entries[path] = (ours_mode, ours_sha)
  246. elif base_sha == ours_sha and theirs_sha is not None:
  247. # Only theirs modified
  248. merged_entries[path] = (theirs_mode, theirs_sha)
  249. elif base_sha == theirs_sha and ours_sha is not None:
  250. # Only ours modified
  251. merged_entries[path] = (ours_mode, ours_sha)
  252. elif ours_sha is None:
  253. # We deleted
  254. if base_sha == theirs_sha:
  255. # They didn't modify, accept deletion
  256. pass
  257. else:
  258. # They modified, we deleted - conflict
  259. conflicts.append(path)
  260. elif theirs_sha is None:
  261. # They deleted
  262. if base_sha == ours_sha:
  263. # We didn't modify, accept deletion
  264. pass
  265. else:
  266. # We modified, they deleted - conflict
  267. conflicts.append(path)
  268. merged_entries[path] = (ours_mode, ours_sha)
  269. else:
  270. # Both modified differently
  271. # For trees and submodules, this is a conflict
  272. if S_ISGITLINK(ours_mode or 0) or S_ISGITLINK(theirs_mode or 0):
  273. conflicts.append(path)
  274. merged_entries[path] = (ours_mode, ours_sha)
  275. elif (ours_mode or 0) & 0o170000 == 0o040000 or (
  276. theirs_mode or 0
  277. ) & 0o170000 == 0o040000:
  278. # Tree conflict
  279. conflicts.append(path)
  280. merged_entries[path] = (ours_mode, ours_sha)
  281. else:
  282. # Try to merge blobs
  283. base_blob = (
  284. cast(Blob, self.object_store[base_sha]) if base_sha else None
  285. )
  286. ours_blob = (
  287. cast(Blob, self.object_store[ours_sha]) if ours_sha else None
  288. )
  289. theirs_blob = (
  290. cast(Blob, self.object_store[theirs_sha])
  291. if theirs_sha
  292. else None
  293. )
  294. merged_content, had_conflict = self.merge_blobs(
  295. base_blob, ours_blob, theirs_blob
  296. )
  297. if had_conflict:
  298. conflicts.append(path)
  299. # Store merged blob
  300. merged_blob = Blob.from_string(merged_content)
  301. self.object_store.add_object(merged_blob)
  302. merged_entries[path] = (ours_mode or theirs_mode, merged_blob.id)
  303. # Build merged tree
  304. merged_tree = Tree()
  305. for path, (mode, sha) in sorted(merged_entries.items()):
  306. merged_tree.add(path, mode, sha)
  307. return merged_tree, conflicts
  308. def three_way_merge(
  309. object_store: BaseObjectStore,
  310. base_commit: Optional[Commit],
  311. ours_commit: Commit,
  312. theirs_commit: Commit,
  313. ) -> tuple[Tree, list[bytes]]:
  314. """Perform a three-way merge between commits.
  315. Args:
  316. object_store: Object store to read/write objects
  317. base_commit: Common ancestor commit (None if no common ancestor)
  318. ours_commit: Our commit
  319. theirs_commit: Their commit
  320. Returns:
  321. tuple of (merged_tree, list_of_conflicted_paths)
  322. """
  323. merger = Merger(object_store)
  324. base_tree = cast(Tree, object_store[base_commit.tree]) if base_commit else None
  325. ours_tree = cast(Tree, object_store[ours_commit.tree])
  326. theirs_tree = cast(Tree, object_store[theirs_commit.tree])
  327. return merger.merge_trees(base_tree, ours_tree, theirs_tree)