merge.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399
  1. """Git merge implementation."""
  2. from typing import Optional, cast
  3. try:
  4. import merge3
  5. except ImportError:
  6. merge3 = None # type: ignore
  7. from dulwich.object_store import BaseObjectStore
  8. from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree
  9. class MergeConflict(Exception):
  10. """Raised when a merge conflict occurs."""
  11. def __init__(self, path: bytes, message: str):
  12. self.path = path
  13. super().__init__(f"Merge conflict in {path!r}: {message}")
  14. def _can_merge_lines(
  15. base_lines: list[bytes], a_lines: list[bytes], b_lines: list[bytes]
  16. ) -> bool:
  17. """Check if lines can be merged without conflict."""
  18. # If one side is unchanged, we can take the other side
  19. if base_lines == a_lines:
  20. return True
  21. elif base_lines == b_lines:
  22. return True
  23. else:
  24. # For now, treat any difference as a conflict
  25. # A more sophisticated algorithm would check for non-overlapping changes
  26. return False
  27. if merge3 is not None:
  28. def _merge3_to_bytes(m: merge3.Merge3) -> bytes:
  29. """Convert merge3 result to bytes with conflict markers.
  30. Args:
  31. m: Merge3 object
  32. Returns:
  33. Merged content as bytes
  34. """
  35. result = []
  36. for group in m.merge_groups():
  37. if group[0] == "unchanged":
  38. result.extend(group[1])
  39. elif group[0] == "a":
  40. result.extend(group[1])
  41. elif group[0] == "b":
  42. result.extend(group[1])
  43. elif group[0] == "same":
  44. result.extend(group[1])
  45. elif group[0] == "conflict":
  46. # Check if this is a real conflict or just different changes
  47. base_lines, a_lines, b_lines = group[1], group[2], group[3]
  48. # Try to merge line by line
  49. if _can_merge_lines(base_lines, a_lines, b_lines):
  50. merged_lines = _merge_lines(base_lines, a_lines, b_lines)
  51. result.extend(merged_lines)
  52. else:
  53. # Real conflict - add conflict markers
  54. result.append(b"<<<<<<< ours\n")
  55. result.extend(a_lines)
  56. result.append(b"=======\n")
  57. result.extend(b_lines)
  58. result.append(b">>>>>>> theirs\n")
  59. return b"".join(result)
  60. def _merge_lines(
  61. base_lines: list[bytes], a_lines: list[bytes], b_lines: list[bytes]
  62. ) -> list[bytes]:
  63. """Merge lines when possible."""
  64. if base_lines == a_lines:
  65. return b_lines
  66. elif base_lines == b_lines:
  67. return a_lines
  68. else:
  69. # This shouldn't happen if _can_merge_lines returned True
  70. return a_lines
  71. def merge_blobs(
  72. base_blob: Optional[Blob],
  73. ours_blob: Optional[Blob],
  74. theirs_blob: Optional[Blob],
  75. ) -> tuple[bytes, bool]:
  76. """Perform three-way merge on blob contents.
  77. Args:
  78. base_blob: Common ancestor blob (can be None)
  79. ours_blob: Our version of the blob (can be None)
  80. theirs_blob: Their version of the blob (can be None)
  81. Returns:
  82. Tuple of (merged_content, had_conflicts)
  83. """
  84. # Handle deletion cases
  85. if ours_blob is None and theirs_blob is None:
  86. return b"", False
  87. if base_blob is None:
  88. # No common ancestor
  89. if ours_blob is None:
  90. assert theirs_blob is not None
  91. return theirs_blob.data, False
  92. elif theirs_blob is None:
  93. return ours_blob.data, False
  94. elif ours_blob.data == theirs_blob.data:
  95. return ours_blob.data, False
  96. else:
  97. # Both added different content - conflict
  98. m = merge3.Merge3(
  99. [],
  100. ours_blob.data.splitlines(True),
  101. theirs_blob.data.splitlines(True),
  102. )
  103. return _merge3_to_bytes(m), True
  104. # Get content for each version
  105. base_content = base_blob.data if base_blob else b""
  106. ours_content = ours_blob.data if ours_blob else b""
  107. theirs_content = theirs_blob.data if theirs_blob else b""
  108. # Check if either side deleted
  109. if ours_blob is None or theirs_blob is None:
  110. if ours_blob is None and theirs_blob is None:
  111. return b"", False
  112. elif ours_blob is None:
  113. # We deleted, check if they modified
  114. if base_content == theirs_content:
  115. return b"", False # They didn't modify, accept deletion
  116. else:
  117. # Conflict: we deleted, they modified
  118. m = merge3.Merge3(
  119. base_content.splitlines(True),
  120. [],
  121. theirs_content.splitlines(True),
  122. )
  123. return _merge3_to_bytes(m), True
  124. else:
  125. # They deleted, check if we modified
  126. if base_content == ours_content:
  127. return b"", False # We didn't modify, accept deletion
  128. else:
  129. # Conflict: they deleted, we modified
  130. m = merge3.Merge3(
  131. base_content.splitlines(True),
  132. ours_content.splitlines(True),
  133. [],
  134. )
  135. return _merge3_to_bytes(m), True
  136. # Both sides exist, check if merge is needed
  137. if ours_content == theirs_content:
  138. return ours_content, False
  139. elif base_content == ours_content:
  140. return theirs_content, False
  141. elif base_content == theirs_content:
  142. return ours_content, False
  143. # Perform three-way merge
  144. m = merge3.Merge3(
  145. base_content.splitlines(True),
  146. ours_content.splitlines(True),
  147. theirs_content.splitlines(True),
  148. )
  149. # Check for conflicts and generate merged content
  150. merged_content = _merge3_to_bytes(m)
  151. has_conflicts = b"<<<<<<< ours" in merged_content
  152. return merged_content, has_conflicts
  153. class Merger:
  154. """Handles git merge operations."""
  155. def __init__(self, object_store: BaseObjectStore):
  156. """Initialize merger.
  157. Args:
  158. object_store: Object store to read objects from
  159. """
  160. self.object_store = object_store
  161. @staticmethod
  162. def merge_blobs(
  163. base_blob: Optional[Blob],
  164. ours_blob: Optional[Blob],
  165. theirs_blob: Optional[Blob],
  166. ) -> tuple[bytes, bool]:
  167. """Perform three-way merge on blob contents.
  168. Args:
  169. base_blob: Common ancestor blob (can be None)
  170. ours_blob: Our version of the blob (can be None)
  171. theirs_blob: Their version of the blob (can be None)
  172. Returns:
  173. Tuple of (merged_content, had_conflicts)
  174. """
  175. return merge_blobs(base_blob, ours_blob, theirs_blob)
  176. def merge_trees(
  177. self, base_tree: Optional[Tree], ours_tree: Tree, theirs_tree: Tree
  178. ) -> tuple[Tree, list[bytes]]:
  179. """Perform three-way merge on trees.
  180. Args:
  181. base_tree: Common ancestor tree (can be None for no common ancestor)
  182. ours_tree: Our version of the tree
  183. theirs_tree: Their version of the tree
  184. Returns:
  185. tuple of (merged_tree, list_of_conflicted_paths)
  186. """
  187. conflicts = []
  188. merged_entries = {}
  189. # Get all paths from all trees
  190. all_paths = set()
  191. if base_tree:
  192. for entry in base_tree.items():
  193. all_paths.add(entry.path)
  194. for entry in ours_tree.items():
  195. all_paths.add(entry.path)
  196. for entry in theirs_tree.items():
  197. all_paths.add(entry.path)
  198. # Process each path
  199. for path in sorted(all_paths):
  200. base_entry = None
  201. if base_tree:
  202. try:
  203. base_entry = base_tree.lookup_path(
  204. self.object_store.__getitem__, path
  205. )
  206. except KeyError:
  207. pass
  208. try:
  209. ours_entry = ours_tree.lookup_path(self.object_store.__getitem__, path)
  210. except KeyError:
  211. ours_entry = None
  212. try:
  213. theirs_entry = theirs_tree.lookup_path(
  214. self.object_store.__getitem__, path
  215. )
  216. except KeyError:
  217. theirs_entry = None
  218. # Extract mode and sha
  219. base_mode, base_sha = base_entry if base_entry else (None, None)
  220. ours_mode, ours_sha = ours_entry if ours_entry else (None, None)
  221. theirs_mode, theirs_sha = theirs_entry if theirs_entry else (None, None)
  222. # Handle deletions
  223. if ours_sha is None and theirs_sha is None:
  224. continue # Deleted in both
  225. # Handle additions
  226. if base_sha is None:
  227. if ours_sha == theirs_sha and ours_mode == theirs_mode:
  228. # Same addition in both
  229. merged_entries[path] = (ours_mode, ours_sha)
  230. elif ours_sha is None:
  231. # Added only in theirs
  232. merged_entries[path] = (theirs_mode, theirs_sha)
  233. elif theirs_sha is None:
  234. # Added only in ours
  235. merged_entries[path] = (ours_mode, ours_sha)
  236. else:
  237. # Different additions - conflict
  238. conflicts.append(path)
  239. # For now, keep ours
  240. merged_entries[path] = (ours_mode, ours_sha)
  241. continue
  242. # Check for mode conflicts
  243. if (
  244. ours_mode != theirs_mode
  245. and ours_mode is not None
  246. and theirs_mode is not None
  247. ):
  248. conflicts.append(path)
  249. # For now, keep ours
  250. merged_entries[path] = (ours_mode, ours_sha)
  251. continue
  252. # Handle modifications
  253. if ours_sha == theirs_sha:
  254. # Same modification or no change
  255. if ours_sha is not None:
  256. merged_entries[path] = (ours_mode, ours_sha)
  257. elif base_sha == ours_sha and theirs_sha is not None:
  258. # Only theirs modified
  259. merged_entries[path] = (theirs_mode, theirs_sha)
  260. elif base_sha == theirs_sha and ours_sha is not None:
  261. # Only ours modified
  262. merged_entries[path] = (ours_mode, ours_sha)
  263. elif ours_sha is None:
  264. # We deleted
  265. if base_sha == theirs_sha:
  266. # They didn't modify, accept deletion
  267. pass
  268. else:
  269. # They modified, we deleted - conflict
  270. conflicts.append(path)
  271. elif theirs_sha is None:
  272. # They deleted
  273. if base_sha == ours_sha:
  274. # We didn't modify, accept deletion
  275. pass
  276. else:
  277. # We modified, they deleted - conflict
  278. conflicts.append(path)
  279. merged_entries[path] = (ours_mode, ours_sha)
  280. else:
  281. # Both modified differently
  282. # For trees and submodules, this is a conflict
  283. if S_ISGITLINK(ours_mode or 0) or S_ISGITLINK(theirs_mode or 0):
  284. conflicts.append(path)
  285. merged_entries[path] = (ours_mode, ours_sha)
  286. elif (ours_mode or 0) & 0o170000 == 0o040000 or (
  287. theirs_mode or 0
  288. ) & 0o170000 == 0o040000:
  289. # Tree conflict
  290. conflicts.append(path)
  291. merged_entries[path] = (ours_mode, ours_sha)
  292. else:
  293. # Try to merge blobs
  294. base_blob = (
  295. cast(Blob, self.object_store[base_sha]) if base_sha else None
  296. )
  297. ours_blob = (
  298. cast(Blob, self.object_store[ours_sha]) if ours_sha else None
  299. )
  300. theirs_blob = (
  301. cast(Blob, self.object_store[theirs_sha])
  302. if theirs_sha
  303. else None
  304. )
  305. merged_content, had_conflict = self.merge_blobs(
  306. base_blob, ours_blob, theirs_blob
  307. )
  308. if had_conflict:
  309. conflicts.append(path)
  310. # Store merged blob
  311. merged_blob = Blob.from_string(merged_content)
  312. self.object_store.add_object(merged_blob)
  313. merged_entries[path] = (ours_mode or theirs_mode, merged_blob.id)
  314. # Build merged tree
  315. merged_tree = Tree()
  316. for path, (mode, sha) in sorted(merged_entries.items()):
  317. merged_tree.add(path, mode, sha)
  318. return merged_tree, conflicts
  319. def three_way_merge(
  320. object_store: BaseObjectStore,
  321. base_commit: Optional[Commit],
  322. ours_commit: Commit,
  323. theirs_commit: Commit,
  324. ) -> tuple[Tree, list[bytes]]:
  325. """Perform a three-way merge between commits.
  326. Args:
  327. object_store: Object store to read/write objects
  328. base_commit: Common ancestor commit (None if no common ancestor)
  329. ours_commit: Our commit
  330. theirs_commit: Their commit
  331. Returns:
  332. tuple of (merged_tree, list_of_conflicted_paths)
  333. """
  334. merger = Merger(object_store)
  335. base_tree = cast(Tree, object_store[base_commit.tree]) if base_commit else None
  336. ours_tree = cast(Tree, object_store[ours_commit.tree])
  337. theirs_tree = cast(Tree, object_store[theirs_commit.tree])
  338. return merger.merge_trees(base_tree, ours_tree, theirs_tree)