# filter_branch.py - Git filter-branch functionality
# Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
#
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#

"""Git filter-branch implementation."""
import os
import tempfile
import warnings
from collections.abc import Sequence
from typing import Callable, Optional, TypedDict

from .index import Index, build_index_from_tree
from .object_store import BaseObjectStore
from .objects import Commit, Tag, Tree
from .refs import RefsContainer
  31. class CommitData(TypedDict, total=False):
  32. """TypedDict for commit data fields."""
  33. author: bytes
  34. author_time: int
  35. author_timezone: int
  36. committer: bytes
  37. commit_time: int
  38. commit_timezone: int
  39. message: bytes
  40. encoding: bytes
  41. class CommitFilter:
  42. """Filter for rewriting commits during filter-branch operations."""
  43. def __init__(
  44. self,
  45. object_store: BaseObjectStore,
  46. *,
  47. filter_fn: Optional[Callable[[Commit], Optional[CommitData]]] = None,
  48. filter_author: Optional[Callable[[bytes], Optional[bytes]]] = None,
  49. filter_committer: Optional[Callable[[bytes], Optional[bytes]]] = None,
  50. filter_message: Optional[Callable[[bytes], Optional[bytes]]] = None,
  51. tree_filter: Optional[Callable[[bytes, str], Optional[bytes]]] = None,
  52. index_filter: Optional[Callable[[bytes, str], Optional[bytes]]] = None,
  53. parent_filter: Optional[Callable[[Sequence[bytes]], list[bytes]]] = None,
  54. commit_filter: Optional[Callable[[Commit, bytes], Optional[bytes]]] = None,
  55. subdirectory_filter: Optional[bytes] = None,
  56. prune_empty: bool = False,
  57. tag_name_filter: Optional[Callable[[bytes], Optional[bytes]]] = None,
  58. ):
  59. """Initialize a commit filter.
  60. Args:
  61. object_store: Object store to read from and write to
  62. filter_fn: Optional callable that takes a Commit object and returns
  63. a dict of updated fields (author, committer, message, etc.)
  64. filter_author: Optional callable that takes author bytes and returns
  65. updated author bytes or None to keep unchanged
  66. filter_committer: Optional callable that takes committer bytes and returns
  67. updated committer bytes or None to keep unchanged
  68. filter_message: Optional callable that takes commit message bytes
  69. and returns updated message bytes
  70. tree_filter: Optional callable that takes (tree_sha, temp_dir) and returns
  71. new tree SHA after modifying working directory
  72. index_filter: Optional callable that takes (tree_sha, temp_index_path) and
  73. returns new tree SHA after modifying index
  74. parent_filter: Optional callable that takes parent list and returns
  75. modified parent list
  76. commit_filter: Optional callable that takes (Commit, tree_sha) and returns
  77. new commit SHA or None to skip commit
  78. subdirectory_filter: Optional subdirectory path to extract as new root
  79. prune_empty: Whether to prune commits that become empty
  80. tag_name_filter: Optional callable to rename tags
  81. """
  82. self.object_store = object_store
  83. self.filter_fn = filter_fn
  84. self.filter_author = filter_author
  85. self.filter_committer = filter_committer
  86. self.filter_message = filter_message
  87. self.tree_filter = tree_filter
  88. self.index_filter = index_filter
  89. self.parent_filter = parent_filter
  90. self.commit_filter = commit_filter
  91. self.subdirectory_filter = subdirectory_filter
  92. self.prune_empty = prune_empty
  93. self.tag_name_filter = tag_name_filter
  94. self._old_to_new: dict[bytes, bytes] = {}
  95. self._processed: set[bytes] = set()
  96. self._tree_cache: dict[bytes, bytes] = {} # Cache for filtered trees
  97. def _filter_tree_with_subdirectory(
  98. self, tree_sha: bytes, subdirectory: bytes
  99. ) -> Optional[bytes]:
  100. """Extract a subdirectory from a tree as the new root.
  101. Args:
  102. tree_sha: SHA of the tree to filter
  103. subdirectory: Path to subdirectory to extract
  104. Returns:
  105. SHA of the new tree containing only the subdirectory, or None if not found
  106. """
  107. try:
  108. tree = self.object_store[tree_sha]
  109. if not isinstance(tree, Tree):
  110. return None
  111. except KeyError:
  112. return None
  113. # Split subdirectory path
  114. parts = subdirectory.split(b"/")
  115. current_tree = tree
  116. # Navigate to subdirectory
  117. for part in parts:
  118. if not part:
  119. continue
  120. found = False
  121. for entry in current_tree.items():
  122. if entry.path == part:
  123. try:
  124. assert entry.sha is not None
  125. obj = self.object_store[entry.sha]
  126. if isinstance(obj, Tree):
  127. current_tree = obj
  128. found = True
  129. break
  130. except KeyError:
  131. return None
  132. if not found:
  133. # Subdirectory not found, return empty tree
  134. empty_tree = Tree()
  135. self.object_store.add_object(empty_tree)
  136. return empty_tree.id
  137. # Return the subdirectory tree
  138. return current_tree.id
  139. def _apply_tree_filter(self, tree_sha: bytes) -> bytes:
  140. """Apply tree filter by checking out tree and running filter.
  141. Args:
  142. tree_sha: SHA of the tree to filter
  143. Returns:
  144. SHA of the filtered tree
  145. """
  146. if tree_sha in self._tree_cache:
  147. return self._tree_cache[tree_sha]
  148. if not self.tree_filter:
  149. self._tree_cache[tree_sha] = tree_sha
  150. return tree_sha
  151. # Create temporary directory
  152. with tempfile.TemporaryDirectory() as tmpdir:
  153. # Check out tree to temp directory
  154. # We need a proper checkout implementation here
  155. # For now, pass tmpdir to filter and let it handle checkout
  156. new_tree_sha = self.tree_filter(tree_sha, tmpdir)
  157. if new_tree_sha is None:
  158. new_tree_sha = tree_sha
  159. self._tree_cache[tree_sha] = new_tree_sha
  160. return new_tree_sha
  161. def _apply_index_filter(self, tree_sha: bytes) -> bytes:
  162. """Apply index filter by creating temp index and running filter.
  163. Args:
  164. tree_sha: SHA of the tree to filter
  165. Returns:
  166. SHA of the filtered tree
  167. """
  168. if tree_sha in self._tree_cache:
  169. return self._tree_cache[tree_sha]
  170. if not self.index_filter:
  171. self._tree_cache[tree_sha] = tree_sha
  172. return tree_sha
  173. # Create temporary index file
  174. with tempfile.NamedTemporaryFile(delete=False) as tmp_index:
  175. tmp_index_path = tmp_index.name
  176. try:
  177. # Build index from tree
  178. build_index_from_tree(".", tmp_index_path, self.object_store, tree_sha)
  179. # Run index filter
  180. new_tree_sha = self.index_filter(tree_sha, tmp_index_path)
  181. if new_tree_sha is None:
  182. # Read back the modified index and create new tree
  183. index = Index(tmp_index_path)
  184. new_tree_sha = index.commit(self.object_store)
  185. self._tree_cache[tree_sha] = new_tree_sha
  186. return new_tree_sha
  187. finally:
  188. os.unlink(tmp_index_path)
  189. def process_commit(self, commit_sha: bytes) -> Optional[bytes]:
  190. """Process a single commit, creating a filtered version.
  191. Args:
  192. commit_sha: SHA of the commit to process
  193. Returns:
  194. SHA of the new commit, or None if object not found
  195. """
  196. if commit_sha in self._processed:
  197. return self._old_to_new.get(commit_sha, commit_sha)
  198. self._processed.add(commit_sha)
  199. try:
  200. commit = self.object_store[commit_sha]
  201. except KeyError:
  202. # Object not found
  203. return None
  204. if not isinstance(commit, Commit):
  205. # Not a commit, return as-is
  206. self._old_to_new[commit_sha] = commit_sha
  207. return commit_sha
  208. # Process parents first
  209. new_parents = []
  210. for parent in commit.parents:
  211. new_parent = self.process_commit(parent)
  212. if new_parent: # Skip None parents
  213. new_parents.append(new_parent)
  214. # Apply parent filter
  215. if self.parent_filter:
  216. new_parents = self.parent_filter(new_parents)
  217. # Apply tree filters
  218. new_tree = commit.tree
  219. # Subdirectory filter takes precedence
  220. if self.subdirectory_filter:
  221. filtered_tree = self._filter_tree_with_subdirectory(
  222. commit.tree, self.subdirectory_filter
  223. )
  224. if filtered_tree:
  225. new_tree = filtered_tree
  226. # Then apply tree filter
  227. if self.tree_filter:
  228. new_tree = self._apply_tree_filter(new_tree)
  229. # Or apply index filter
  230. elif self.index_filter:
  231. new_tree = self._apply_index_filter(new_tree)
  232. # Check if we should prune empty commits
  233. if self.prune_empty and len(new_parents) == 1:
  234. # Check if tree is same as parent's tree
  235. parent_commit = self.object_store[new_parents[0]]
  236. if isinstance(parent_commit, Commit) and parent_commit.tree == new_tree:
  237. # This commit doesn't change anything, skip it
  238. self._old_to_new[commit_sha] = new_parents[0]
  239. return new_parents[0]
  240. # Apply filters
  241. new_data: CommitData = {}
  242. # Custom filter function takes precedence
  243. if self.filter_fn:
  244. filtered = self.filter_fn(commit)
  245. if filtered:
  246. new_data.update(filtered)
  247. # Apply specific filters
  248. if self.filter_author and "author" not in new_data:
  249. new_author = self.filter_author(commit.author)
  250. if new_author is not None:
  251. new_data["author"] = new_author
  252. if self.filter_committer and "committer" not in new_data:
  253. new_committer = self.filter_committer(commit.committer)
  254. if new_committer is not None:
  255. new_data["committer"] = new_committer
  256. if self.filter_message and "message" not in new_data:
  257. new_message = self.filter_message(commit.message)
  258. if new_message is not None:
  259. new_data["message"] = new_message
  260. # Create new commit if anything changed
  261. if new_data or new_parents != commit.parents or new_tree != commit.tree:
  262. new_commit = Commit()
  263. new_commit.tree = new_tree
  264. new_commit.parents = new_parents
  265. new_commit.author = new_data.get("author", commit.author)
  266. new_commit.author_time = new_data.get("author_time", commit.author_time)
  267. new_commit.author_timezone = new_data.get(
  268. "author_timezone", commit.author_timezone
  269. )
  270. new_commit.committer = new_data.get("committer", commit.committer)
  271. new_commit.commit_time = new_data.get("commit_time", commit.commit_time)
  272. new_commit.commit_timezone = new_data.get(
  273. "commit_timezone", commit.commit_timezone
  274. )
  275. new_commit.message = new_data.get("message", commit.message)
  276. new_commit.encoding = new_data.get("encoding", commit.encoding)
  277. # Copy extra fields
  278. if hasattr(commit, "_author_timezone_neg_utc"):
  279. new_commit._author_timezone_neg_utc = commit._author_timezone_neg_utc
  280. if hasattr(commit, "_commit_timezone_neg_utc"):
  281. new_commit._commit_timezone_neg_utc = commit._commit_timezone_neg_utc
  282. if hasattr(commit, "_extra"):
  283. new_commit._extra = list(commit._extra)
  284. if hasattr(commit, "_gpgsig"):
  285. new_commit._gpgsig = commit._gpgsig
  286. if hasattr(commit, "_mergetag"):
  287. new_commit._mergetag = list(commit._mergetag)
  288. # Apply commit filter if provided
  289. if self.commit_filter:
  290. # The commit filter can create a completely new commit
  291. new_commit_sha = self.commit_filter(new_commit, new_tree)
  292. if new_commit_sha is None:
  293. # Skip this commit
  294. if len(new_parents) == 1:
  295. self._old_to_new[commit_sha] = new_parents[0]
  296. return new_parents[0]
  297. elif len(new_parents) == 0:
  298. return None
  299. else:
  300. # Multiple parents, can't skip
  301. # Store the new commit anyway
  302. self.object_store.add_object(new_commit)
  303. self._old_to_new[commit_sha] = new_commit.id
  304. return new_commit.id
  305. else:
  306. self._old_to_new[commit_sha] = new_commit_sha
  307. return new_commit_sha
  308. else:
  309. # Store the new commit
  310. self.object_store.add_object(new_commit)
  311. self._old_to_new[commit_sha] = new_commit.id
  312. return new_commit.id
  313. else:
  314. # No changes, keep original
  315. self._old_to_new[commit_sha] = commit_sha
  316. return commit_sha
  317. def get_mapping(self) -> dict[bytes, bytes]:
  318. """Get the mapping of old commit SHAs to new commit SHAs.
  319. Returns:
  320. Dictionary mapping old SHAs to new SHAs
  321. """
  322. return self._old_to_new.copy()
  323. def filter_refs(
  324. refs: RefsContainer,
  325. object_store: BaseObjectStore,
  326. ref_names: Sequence[bytes],
  327. commit_filter: CommitFilter,
  328. *,
  329. keep_original: bool = True,
  330. force: bool = False,
  331. tag_callback: Optional[Callable[[bytes, bytes], None]] = None,
  332. ) -> dict[bytes, bytes]:
  333. """Filter commits reachable from the given refs.
  334. Args:
  335. refs: Repository refs container
  336. object_store: Object store containing commits
  337. ref_names: List of ref names to filter
  338. commit_filter: CommitFilter instance to use
  339. keep_original: Keep original refs under refs/original/
  340. force: Force operation even if refs have been filtered before
  341. tag_callback: Optional callback for processing tags
  342. Returns:
  343. Dictionary mapping old commit SHAs to new commit SHAs
  344. Raises:
  345. ValueError: If refs have already been filtered and force is False
  346. """
  347. # Check if already filtered
  348. if keep_original and not force:
  349. for ref in ref_names:
  350. original_ref = b"refs/original/" + ref
  351. if original_ref in refs:
  352. raise ValueError(
  353. f"Branch {ref.decode()} appears to have been filtered already. "
  354. "Use force=True to force re-filtering."
  355. )
  356. # Process commits starting from refs
  357. for ref in ref_names:
  358. try:
  359. # Get the commit SHA for this ref
  360. if ref in refs:
  361. ref_sha = refs[ref]
  362. if ref_sha:
  363. commit_filter.process_commit(ref_sha)
  364. except KeyError:
  365. # Skip refs that can't be resolved
  366. warnings.warn(f"Could not process ref {ref!r}: ref not found")
  367. continue
  368. # Update refs
  369. mapping = commit_filter.get_mapping()
  370. for ref in ref_names:
  371. try:
  372. if ref in refs:
  373. old_sha = refs[ref]
  374. new_sha = mapping.get(old_sha, old_sha)
  375. if old_sha != new_sha:
  376. # Save original ref if requested
  377. if keep_original:
  378. original_ref = b"refs/original/" + ref
  379. refs[original_ref] = old_sha
  380. # Update ref to new commit
  381. refs[ref] = new_sha
  382. except KeyError:
  383. # Not a valid ref, skip updating
  384. warnings.warn(f"Could not update ref {ref!r}: ref not found")
  385. continue
  386. # Handle tag filtering
  387. if commit_filter.tag_name_filter and tag_callback:
  388. # Process all tags
  389. for ref in refs.allkeys():
  390. if ref.startswith(b"refs/tags/"):
  391. # Get the tag object or commit it points to
  392. tag_sha = refs[ref]
  393. tag_obj = object_store[tag_sha]
  394. tag_name = ref[10:] # Remove 'refs/tags/'
  395. # Check if it's an annotated tag
  396. if isinstance(tag_obj, Tag):
  397. # Get the commit it points to
  398. target_sha = tag_obj.object[1]
  399. # Process tag if:
  400. # 1. It points to a rewritten commit, OR
  401. # 2. We want to rename the tag regardless
  402. if (
  403. target_sha in mapping
  404. or commit_filter.tag_name_filter is not None
  405. ):
  406. new_tag_name = commit_filter.tag_name_filter(tag_name)
  407. if new_tag_name and new_tag_name != tag_name:
  408. # For annotated tags pointing to rewritten commits,
  409. # we need to create a new tag object
  410. if target_sha in mapping:
  411. new_target = mapping[target_sha]
  412. # Create new tag object pointing to rewritten commit
  413. new_tag = Tag()
  414. new_tag.object = (tag_obj.object[0], new_target)
  415. new_tag.name = new_tag_name
  416. new_tag.message = tag_obj.message
  417. new_tag.tagger = tag_obj.tagger
  418. new_tag.tag_time = tag_obj.tag_time
  419. new_tag.tag_timezone = tag_obj.tag_timezone
  420. object_store.add_object(new_tag)
  421. # Update ref to point to new tag object
  422. refs[b"refs/tags/" + new_tag_name] = new_tag.id
  423. # Delete old tag
  424. del refs[ref]
  425. else:
  426. # Just rename the tag
  427. new_ref = b"refs/tags/" + new_tag_name
  428. tag_callback(ref, new_ref)
  429. elif isinstance(tag_obj, Commit):
  430. # Lightweight tag - points directly to a commit
  431. # Process if commit was rewritten or we want to rename
  432. if tag_sha in mapping or commit_filter.tag_name_filter is not None:
  433. new_tag_name = commit_filter.tag_name_filter(tag_name)
  434. if new_tag_name and new_tag_name != tag_name:
  435. new_ref = b"refs/tags/" + new_tag_name
  436. if tag_sha in mapping:
  437. # Point to rewritten commit
  438. refs[new_ref] = mapping[tag_sha]
  439. del refs[ref]
  440. else:
  441. # Just rename
  442. tag_callback(ref, new_ref)
  443. return mapping