merge_drivers.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. # merge_drivers.py -- Merge driver support for dulwich
  2. # Copyright (C) 2025 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  5. # General Public License as published by the Free Software Foundation; version 2.0
  6. # or (at your option) any later version. You can redistribute it and/or
  7. # modify it under the terms of either of these two licenses.
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # You should have received a copy of the licenses; if not, see
  16. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  17. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  18. # License, Version 2.0.
  19. #
  20. """Merge driver support for dulwich."""
import os
import re
import shlex
import subprocess
import tempfile
from typing import Any, Callable, Optional, Protocol

from .config import Config
  26. class MergeDriver(Protocol):
  27. """Protocol for merge drivers."""
  28. def merge(
  29. self,
  30. ancestor: bytes,
  31. ours: bytes,
  32. theirs: bytes,
  33. path: Optional[str] = None,
  34. marker_size: int = 7,
  35. ) -> tuple[bytes, bool]:
  36. """Perform a three-way merge.
  37. Args:
  38. ancestor: Content of the common ancestor version
  39. ours: Content of our version
  40. theirs: Content of their version
  41. path: Optional path of the file being merged
  42. marker_size: Size of conflict markers (default 7)
  43. Returns:
  44. Tuple of (merged content, success flag)
  45. If success is False, the content may contain conflict markers
  46. """
  47. ...
  48. class ProcessMergeDriver:
  49. """Merge driver that runs an external process."""
  50. def __init__(self, command: str, name: str = "custom"):
  51. """Initialize process merge driver.
  52. Args:
  53. command: Command to run for merging
  54. name: Name of the merge driver
  55. """
  56. self.command = command
  57. self.name = name
  58. def merge(
  59. self,
  60. ancestor: bytes,
  61. ours: bytes,
  62. theirs: bytes,
  63. path: Optional[str] = None,
  64. marker_size: int = 7,
  65. ) -> tuple[bytes, bool]:
  66. """Perform merge using external process.
  67. The command is executed with the following placeholders:
  68. - %O: path to ancestor version (base)
  69. - %A: path to our version
  70. - %B: path to their version
  71. - %L: conflict marker size
  72. - %P: original path of the file
  73. The command should write the merge result to the file at %A.
  74. Exit code 0 means successful merge, non-zero means conflicts.
  75. """
  76. with tempfile.TemporaryDirectory() as tmpdir:
  77. # Write temporary files
  78. ancestor_path = os.path.join(tmpdir, "ancestor")
  79. ours_path = os.path.join(tmpdir, "ours")
  80. theirs_path = os.path.join(tmpdir, "theirs")
  81. with open(ancestor_path, "wb") as f:
  82. f.write(ancestor)
  83. with open(ours_path, "wb") as f:
  84. f.write(ours)
  85. with open(theirs_path, "wb") as f:
  86. f.write(theirs)
  87. # Prepare command with placeholders
  88. cmd = self.command
  89. cmd = cmd.replace("%O", ancestor_path)
  90. cmd = cmd.replace("%A", ours_path)
  91. cmd = cmd.replace("%B", theirs_path)
  92. cmd = cmd.replace("%L", str(marker_size))
  93. if path:
  94. cmd = cmd.replace("%P", path)
  95. # Execute merge command
  96. try:
  97. result = subprocess.run(
  98. cmd,
  99. shell=True,
  100. capture_output=True,
  101. text=False,
  102. )
  103. # Read merged content from ours file
  104. with open(ours_path, "rb") as f:
  105. merged_content = f.read()
  106. # Exit code 0 means clean merge, non-zero means conflicts
  107. success = result.returncode == 0
  108. return merged_content, success
  109. except subprocess.SubprocessError:
  110. # If the command fails completely, return original with conflicts
  111. return ours, False
  112. class MergeDriverRegistry:
  113. """Registry for merge drivers."""
  114. def __init__(self, config: Optional[Config] = None):
  115. """Initialize merge driver registry.
  116. Args:
  117. config: Git configuration object
  118. """
  119. self._drivers: dict[str, MergeDriver] = {}
  120. self._factories: dict[str, Any] = {}
  121. self._config = config
  122. # Register built-in drivers
  123. self._register_builtin_drivers()
  124. def _register_builtin_drivers(self) -> None:
  125. """Register built-in merge drivers."""
  126. # The "text" driver is the default three-way merge
  127. # We don't register it here as it's handled by the default merge code
  128. def register_driver(self, name: str, driver: MergeDriver) -> None:
  129. """Register a merge driver instance.
  130. Args:
  131. name: Name of the merge driver
  132. driver: Driver instance
  133. """
  134. self._drivers[name] = driver
  135. def register_factory(self, name: str, factory: Callable[[], MergeDriver]) -> None:
  136. """Register a factory function for creating merge drivers.
  137. Args:
  138. name: Name of the merge driver
  139. factory: Factory function that returns a MergeDriver
  140. """
  141. self._factories[name] = factory
  142. def get_driver(self, name: str) -> Optional[MergeDriver]:
  143. """Get a merge driver by name.
  144. Args:
  145. name: Name of the merge driver
  146. Returns:
  147. MergeDriver instance or None if not found
  148. """
  149. # First check registered drivers
  150. if name in self._drivers:
  151. return self._drivers[name]
  152. # Then check factories
  153. if name in self._factories:
  154. driver = self._factories[name]()
  155. self._drivers[name] = driver
  156. return driver
  157. # Finally check configuration
  158. if self._config:
  159. driver = self._create_from_config(name)
  160. if driver:
  161. self._drivers[name] = driver
  162. return driver
  163. return None
  164. def _create_from_config(self, name: str) -> Optional[MergeDriver]:
  165. """Create a merge driver from git configuration.
  166. Args:
  167. name: Name of the merge driver
  168. Returns:
  169. MergeDriver instance or None if not configured
  170. """
  171. if not self._config:
  172. return None
  173. # Look for merge.<name>.driver configuration
  174. try:
  175. command = self._config.get(("merge", name), "driver")
  176. if command:
  177. return ProcessMergeDriver(command.decode(), name)
  178. except KeyError:
  179. pass
  180. return None
  181. # Global registry instance
  182. _merge_driver_registry: Optional[MergeDriverRegistry] = None
  183. def get_merge_driver_registry(config: Optional[Config] = None) -> MergeDriverRegistry:
  184. """Get the global merge driver registry.
  185. Args:
  186. config: Git configuration object
  187. Returns:
  188. MergeDriverRegistry instance
  189. """
  190. global _merge_driver_registry
  191. if _merge_driver_registry is None:
  192. _merge_driver_registry = MergeDriverRegistry(config)
  193. elif config is not None:
  194. # Update config if provided
  195. _merge_driver_registry._config = config
  196. return _merge_driver_registry