merge.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. """Git merge implementation."""
  2. from typing import Optional
  3. try:
  4. import merge3
  5. except ImportError:
  6. merge3 = None # type: ignore
  7. from dulwich.attrs import GitAttributes
  8. from dulwich.config import Config
  9. from dulwich.merge_drivers import get_merge_driver_registry
  10. from dulwich.object_store import BaseObjectStore
  11. from dulwich.objects import S_ISGITLINK, Blob, Commit, Tree, is_blob, is_tree
  class MergeConflict(Exception):
      """Exception signalling that a path could not be merged cleanly.

      Attributes:
        path: The path (as bytes) whose merge failed.
      """

      def __init__(self, path: bytes, message: str) -> None:
          """Create the exception.

          Args:
            path: Path of the conflicted file
            message: Human-readable description of the conflict
          """
          text = f"Merge conflict in {path!r}: {message}"
          super().__init__(text)
          self.path = path
  22. def _can_merge_lines(
  23. base_lines: list[bytes], a_lines: list[bytes], b_lines: list[bytes]
  24. ) -> bool:
  25. """Check if lines can be merged without conflict."""
  26. # If one side is unchanged, we can take the other side
  27. if base_lines == a_lines:
  28. return True
  29. elif base_lines == b_lines:
  30. return True
  31. else:
  32. # For now, treat any difference as a conflict
  33. # A more sophisticated algorithm would check for non-overlapping changes
  34. return False
  if merge3 is not None:

      def _merge3_to_bytes(m: merge3.Merge3) -> bytes:
          """Render a merge3 result as bytes, inserting conflict markers.

          Args:
            m: Merge3 object describing the three-way merge
          Returns:
            Merged content as bytes. Hunks that cannot be resolved are wrapped
            in ``<<<<<<< ours`` / ``=======`` / ``>>>>>>> theirs`` markers, each
            of which always starts on its own line.
          """

          def _terminated(lines: list[bytes]) -> list[bytes]:
              # A side whose last line has no terminator (e.g. a file without a
              # trailing newline) would otherwise glue the following marker onto
              # that line ("foo=======") and corrupt the conflict markup.
              if lines and not lines[-1].endswith((b"\n", b"\r")):
                  return [*lines[:-1], lines[-1] + b"\n"]
              return lines

          result: list[bytes] = []
          for group in m.merge_groups():
              kind = group[0]
              if kind in ("unchanged", "a", "b", "same"):
                  # Non-conflicting groups carry their resolved lines in group[1].
                  result.extend(group[1])
              elif kind == "conflict":
                  base_lines, a_lines, b_lines = group[1], group[2], group[3]
                  if _can_merge_lines(base_lines, a_lines, b_lines):
                      result.extend(_merge_lines(base_lines, a_lines, b_lines))
                  else:
                      # Real conflict - add conflict markers
                      result.append(b"<<<<<<< ours\n")
                      result.extend(_terminated(a_lines))
                      result.append(b"=======\n")
                      result.extend(_terminated(b_lines))
                      result.append(b">>>>>>> theirs\n")
          return b"".join(result)
  68. def _merge_lines(
  69. base_lines: list[bytes], a_lines: list[bytes], b_lines: list[bytes]
  70. ) -> list[bytes]:
  71. """Merge lines when possible."""
  72. if base_lines == a_lines:
  73. return b_lines
  74. elif base_lines == b_lines:
  75. return a_lines
  76. else:
  77. # This shouldn't happen if _can_merge_lines returned True
  78. return a_lines
  def _new_merge3(base: bytes, ours: bytes, theirs: bytes) -> "merge3.Merge3":
      """Build a line-based Merge3 object for three byte strings.

      Raises:
        ImportError: if the optional ``merge3`` package is not installed
      """
      # merge3 is an optional dependency (its import is guarded at module top);
      # fail with a clear message instead of an AttributeError on None.
      if merge3 is None:
          raise ImportError(
              "the merge3 package is required to merge diverging file contents"
          )
      return merge3.Merge3(
          base.splitlines(True), ours.splitlines(True), theirs.splitlines(True)
      )


  def _has_unresolved_conflict(m: "merge3.Merge3") -> bool:
      """Return True if any merge3 conflict group cannot be auto-resolved."""
      return any(
          group[0] == "conflict"
          and not _can_merge_lines(group[1], group[2], group[3])
          for group in m.merge_groups()
      )


  def merge_blobs(
      base_blob: Optional[Blob],
      ours_blob: Optional[Blob],
      theirs_blob: Optional[Blob],
      path: Optional[bytes] = None,
      gitattributes: Optional[GitAttributes] = None,
      config: Optional[Config] = None,
  ) -> tuple[bytes, bool]:
      """Perform three-way merge on blob contents.

      Args:
        base_blob: Common ancestor blob (can be None)
        ours_blob: Our version of the blob (can be None)
        theirs_blob: Their version of the blob (can be None)
        path: Optional path of the file being merged
        gitattributes: Optional GitAttributes object for checking merge drivers
        config: Optional Config object for loading merge driver configuration
      Returns:
        Tuple of (merged_content, had_conflicts)
      Raises:
        ImportError: if a textual merge of diverging content is needed but the
          optional ``merge3`` package is not installed
      """
      # A ``merge=<name>`` attribute (other than the built-in "text") selects a
      # merge driver that takes over the whole merge.
      merge_driver_name = None
      if path and gitattributes:
          attrs = gitattributes.match_path(path)
          merge_value = attrs.get(b"merge")
          if merge_value and isinstance(merge_value, bytes) and merge_value != b"text":
              merge_driver_name = merge_value.decode("utf-8", errors="replace")

      if merge_driver_name:
          registry = get_merge_driver_registry(config)
          driver = registry.get_driver(merge_driver_name)
          if driver:
              base_content = base_blob.data if base_blob else b""
              ours_content = ours_blob.data if ours_blob else b""
              theirs_content = theirs_blob.data if theirs_blob else b""
              merged_content, success = driver.merge(
                  ancestor=base_content,
                  ours=ours_content,
                  theirs=theirs_content,
                  path=path.decode("utf-8", errors="replace") if path else None,
                  marker_size=7,
              )
              # The driver reports success (no conflicts); invert for callers.
              return merged_content, not success

      # Fall back to the default line-based merge.
      if ours_blob is None and theirs_blob is None:
          return b"", False

      if base_blob is None:
          # No common ancestor.
          if ours_blob is None:
              assert theirs_blob is not None
              return theirs_blob.data, False
          if theirs_blob is None:
              return ours_blob.data, False
          if ours_blob.data == theirs_blob.data:
              return ours_blob.data, False
          # Both sides added different content - conflict.
          m = _new_merge3(b"", ours_blob.data, theirs_blob.data)
          return _merge3_to_bytes(m), True

      base_content = base_blob.data
      ours_content = ours_blob.data if ours_blob is not None else b""
      theirs_content = theirs_blob.data if theirs_blob is not None else b""

      if ours_blob is None:
          # We deleted; accept the deletion only if they left the file alone.
          if base_content == theirs_content:
              return b"", False
          m = _new_merge3(base_content, b"", theirs_content)
          return _merge3_to_bytes(m), True
      if theirs_blob is None:
          # They deleted; accept the deletion only if we left the file alone.
          if base_content == ours_content:
              return b"", False
          m = _new_merge3(base_content, ours_content, b"")
          return _merge3_to_bytes(m), True

      # Both sides exist: short-circuit the cases that need no real merge.
      if ours_content == theirs_content:
          return ours_content, False
      if base_content == ours_content:
          return theirs_content, False
      if base_content == theirs_content:
          return ours_content, False

      m = _new_merge3(base_content, ours_content, theirs_content)
      merged_content = _merge3_to_bytes(m)
      # Marker text can legitimately appear in file contents, so the substring
      # test is only a cheap pre-filter; confirm from the merge groups.
      has_conflicts = b"<<<<<<< ours" in merged_content and _has_unresolved_conflict(m)
      return merged_content, has_conflicts
  class Merger:
      """Handles git merge operations."""

      # File-type bits of a git tree entry mode (same layout as ``stat.S_IFMT``).
      _S_IFMT = 0o170000
      _S_IFDIR = 0o040000
      _S_IFREG = 0o100000

      def __init__(
          self,
          object_store: BaseObjectStore,
          gitattributes: Optional[GitAttributes] = None,
          config: Optional[Config] = None,
      ) -> None:
          """Initialize merger.

          Args:
            object_store: Object store to read objects from; merged blobs and
              merged subtrees are also added to it
            gitattributes: Optional GitAttributes object for checking merge drivers
            config: Optional Config object for loading merge driver configuration
          """
          self.object_store = object_store
          self.gitattributes = gitattributes
          self.config = config

      def merge_blobs(
          self,
          base_blob: Optional[Blob],
          ours_blob: Optional[Blob],
          theirs_blob: Optional[Blob],
          path: Optional[bytes] = None,
      ) -> tuple[bytes, bool]:
          """Perform three-way merge on blob contents.

          Args:
            base_blob: Common ancestor blob (can be None)
            ours_blob: Our version of the blob (can be None)
            theirs_blob: Their version of the blob (can be None)
            path: Optional path of the file being merged
          Returns:
            Tuple of (merged_content, had_conflicts)
          """
          return merge_blobs(
              base_blob, ours_blob, theirs_blob, path, self.gitattributes, self.config
          )

      def merge_trees(
          self, base_tree: Optional[Tree], ours_tree: Tree, theirs_tree: Tree
      ) -> tuple[Tree, list[bytes]]:
          """Perform three-way merge on trees.

          Subdirectories changed on both sides are merged recursively, so edits
          to different files in the same directory do not conflict. Mode and
          content are merged independently (a mode change on one side combines
          with a content change on the other). Merged blobs and merged subtrees
          are added to the object store; the returned top-level tree is not.

          Args:
            base_tree: Common ancestor tree (can be None for no common ancestor)
            ours_tree: Our version of the tree
            theirs_tree: Their version of the tree
          Returns:
            tuple of (merged_tree, list_of_conflicted_paths); conflicted paths
            are slash-separated paths relative to the root of the trees.
          """
          return self._merge_tree_level(base_tree, ours_tree, theirs_tree, b"")

      @classmethod
      def _is_dir(cls, mode: Optional[int]) -> bool:
          """Return True if a tree entry mode denotes a subdirectory."""
          return mode is not None and (mode & cls._S_IFMT) == cls._S_IFDIR

      @classmethod
      def _is_regular_file(cls, mode: Optional[int]) -> bool:
          """Return True if a tree entry mode denotes a regular (non-link) file."""
          return mode is not None and (mode & cls._S_IFMT) == cls._S_IFREG

      @staticmethod
      def _entries(tree: Optional[Tree]) -> dict[bytes, tuple[int, bytes]]:
          """Map each immediate entry name of ``tree`` to its (mode, sha)."""
          if tree is None:
              return {}
          return {entry.path: (entry.mode, entry.sha) for entry in tree.items()}

      def _load_tree(self, sha: bytes, path: bytes) -> Tree:
          """Fetch ``sha`` from the object store, requiring it to be a tree."""
          obj = self.object_store[sha]
          if not is_tree(obj):
              raise TypeError(
                  f"Expected tree for {path!r}, got {obj.type_name.decode()}"
              )
          return obj

      def _load_blob(self, sha: bytes, path: bytes) -> Blob:
          """Fetch ``sha`` from the object store, requiring it to be a blob."""
          obj = self.object_store[sha]
          if not is_blob(obj):
              raise TypeError(
                  f"Expected blob for {path!r}, got {obj.type_name.decode()}"
              )
          return obj

      def _merge_tree_level(
          self,
          base_tree: Optional[Tree],
          ours_tree: Tree,
          theirs_tree: Tree,
          prefix: bytes,
      ) -> tuple[Tree, list[bytes]]:
          """Merge one directory level; ``prefix`` is its path plus b"/" (or b"")."""
          conflicts: list[bytes] = []
          merged_entries: dict[bytes, tuple[int, bytes]] = {}

          base_entries = self._entries(base_tree)
          ours_entries = self._entries(ours_tree)
          theirs_entries = self._entries(theirs_tree)
          names = base_entries.keys() | ours_entries.keys() | theirs_entries.keys()

          for name in sorted(names):
              base = base_entries.get(name)
              ours = ours_entries.get(name)
              theirs = theirs_entries.get(name)
              path = prefix + name

              resolved: Optional[tuple[int, bytes]]
              if ours == theirs:
                  # Unchanged, or the same change (including deletion) on both sides.
                  resolved = ours
              elif ours == base:
                  # Only theirs changed it (content, mode, addition or deletion).
                  resolved = theirs
              elif theirs == base:
                  # Only ours changed it.
                  resolved = ours
              elif ours is None:
                  # We deleted, they modified: conflict, keep the modified file.
                  conflicts.append(path)
                  resolved = theirs
              elif theirs is None:
                  # They deleted, we modified: conflict, keep the modified file.
                  conflicts.append(path)
                  resolved = ours
              elif self._is_dir(ours[0]) and self._is_dir(theirs[0]):
                  resolved = self._merge_subtrees(base, ours, theirs, path, conflicts)
              elif self._is_regular_file(ours[0]) and self._is_regular_file(theirs[0]):
                  resolved = self._merge_files(base, ours, theirs, path, conflicts)
              else:
                  # Submodules, symlinks and file/directory type changes cannot
                  # be merged textually - conflict, keep ours.
                  conflicts.append(path)
                  resolved = ours

              if resolved is not None:
                  merged_entries[name] = resolved

          merged_tree = Tree()
          for name, (mode, sha) in merged_entries.items():
              merged_tree.add(name, mode, sha)
          return merged_tree, conflicts

      def _merge_subtrees(
          self,
          base: Optional[tuple[int, bytes]],
          ours: tuple[int, bytes],
          theirs: tuple[int, bytes],
          path: bytes,
          conflicts: list[bytes],
      ) -> Optional[tuple[int, bytes]]:
          """Recursively merge a directory changed on both sides.

          Appends nested conflicts to ``conflicts``; returns None when the merged
          directory ends up empty (git does not store empty trees).
          """
          base_subtree = (
              self._load_tree(base[1], path)
              if base is not None and self._is_dir(base[0])
              else None
          )
          subtree, sub_conflicts = self._merge_tree_level(
              base_subtree,
              self._load_tree(ours[1], path),
              self._load_tree(theirs[1], path),
              path + b"/",
          )
          conflicts.extend(sub_conflicts)
          if len(subtree) == 0:
              return None
          # The parent tree references the subtree by id, so it must be stored.
          self.object_store.add_object(subtree)
          return (ours[0], subtree.id)

      def _merge_files(
          self,
          base: Optional[tuple[int, bytes]],
          ours: tuple[int, bytes],
          theirs: tuple[int, bytes],
          path: bytes,
          conflicts: list[bytes],
      ) -> tuple[int, bytes]:
          """Merge a regular file changed on both sides; mode and content separately.

          Appends ``path`` to ``conflicts`` if either the mode or the content
          conflicts. On a mode conflict our mode is kept.
          """
          ours_mode, ours_sha = ours
          theirs_mode, theirs_sha = theirs
          base_mode, base_sha = base if base is not None else (None, None)
          conflicted = False

          # Three-way merge of the mode bits.
          if ours_mode == theirs_mode or theirs_mode == base_mode:
              mode = ours_mode
          elif ours_mode == base_mode:
              mode = theirs_mode
          else:
              conflicted = True
              mode = ours_mode

          # Three-way merge of the content.
          if ours_sha == theirs_sha or theirs_sha == base_sha:
              sha = ours_sha
          elif ours_sha == base_sha:
              sha = theirs_sha
          else:
              # Only a regular-file ancestor is a meaningful textual base.
              base_blob = (
                  self._load_blob(base_sha, path)
                  if base_sha is not None and self._is_regular_file(base_mode)
                  else None
              )
              merged_content, had_conflict = self.merge_blobs(
                  base_blob,
                  self._load_blob(ours_sha, path),
                  self._load_blob(theirs_sha, path),
                  path,
              )
              conflicted = conflicted or had_conflict
              merged_blob = Blob.from_string(merged_content)
              self.object_store.add_object(merged_blob)
              sha = merged_blob.id

          if conflicted:
              conflicts.append(path)
          return (mode, sha)
  def three_way_merge(
      object_store: BaseObjectStore,
      base_commit: Optional[Commit],
      ours_commit: Commit,
      theirs_commit: Commit,
      gitattributes: Optional[GitAttributes] = None,
      config: Optional[Config] = None,
  ) -> tuple[Tree, list[bytes]]:
      """Perform a three-way merge between commits.

      Args:
        object_store: Object store to read/write objects
        base_commit: Common ancestor commit (None if no common ancestor)
        ours_commit: Our commit
        theirs_commit: Their commit
        gitattributes: Optional GitAttributes object for checking merge drivers
        config: Optional Config object for loading merge driver configuration
      Returns:
        tuple of (merged_tree, list_of_conflicted_paths)
      Raises:
        TypeError: if a commit's tree id does not resolve to a tree object
      """

      def _commit_tree(commit: Commit) -> Tree:
          # Resolve a commit's root tree, rejecting non-tree objects.
          obj = object_store[commit.tree]
          if not is_tree(obj):
              raise TypeError(f"Expected tree, got {obj.type_name.decode()}")
          return obj

      merger = Merger(object_store, gitattributes, config)
      # With no common ancestor the base tree stays None, which merge_trees
      # accepts; it must not be asserted to be a Tree.
      base_tree = _commit_tree(base_commit) if base_commit is not None else None
      ours_tree = _commit_tree(ours_commit)
      theirs_tree = _commit_tree(theirs_commit)
      return merger.merge_trees(base_tree, ours_tree, theirs_tree)