filter_branch.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516
  1. # filter_branch.py - Git filter-branch functionality
  2. # Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
  5. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  6. # General Public License as published by the Free Software Foundation; version 2.0
  7. # or (at your option) any later version. You can redistribute it and/or
  8. # modify it under the terms of either of these two licenses.
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # You should have received a copy of the licenses; if not, see
  17. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  18. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  19. # License, Version 2.0.
  20. #
  21. """Git filter-branch implementation."""
  22. __all__ = [
  23. "CommitData",
  24. "CommitFilter",
  25. "filter_refs",
  26. ]
  27. import os
  28. import tempfile
  29. import warnings
  30. from collections.abc import Callable, Sequence
  31. from typing import TypedDict
  32. from .index import Index, build_index_from_tree
  33. from .object_store import BaseObjectStore
  34. from .objects import Commit, ObjectID, Tag, Tree
  35. from .refs import Ref, RefsContainer, local_tag_name
  36. class CommitData(TypedDict, total=False):
  37. """TypedDict for commit data fields."""
  38. author: bytes
  39. author_time: int
  40. author_timezone: int
  41. committer: bytes
  42. commit_time: int
  43. commit_timezone: int
  44. message: bytes
  45. encoding: bytes
  46. class CommitFilter:
  47. """Filter for rewriting commits during filter-branch operations."""
  48. def __init__(
  49. self,
  50. object_store: BaseObjectStore,
  51. *,
  52. filter_fn: Callable[[Commit], CommitData | None] | None = None,
  53. filter_author: Callable[[bytes], bytes | None] | None = None,
  54. filter_committer: Callable[[bytes], bytes | None] | None = None,
  55. filter_message: Callable[[bytes], bytes | None] | None = None,
  56. tree_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
  57. index_filter: Callable[[ObjectID, str], ObjectID | None] | None = None,
  58. parent_filter: Callable[[Sequence[ObjectID]], list[ObjectID]] | None = None,
  59. commit_filter: Callable[[Commit, ObjectID], ObjectID | None] | None = None,
  60. subdirectory_filter: bytes | None = None,
  61. prune_empty: bool = False,
  62. tag_name_filter: Callable[[bytes], bytes | None] | None = None,
  63. ):
  64. """Initialize a commit filter.
  65. Args:
  66. object_store: Object store to read from and write to
  67. filter_fn: Optional callable that takes a Commit object and returns
  68. a dict of updated fields (author, committer, message, etc.)
  69. filter_author: Optional callable that takes author bytes and returns
  70. updated author bytes or None to keep unchanged
  71. filter_committer: Optional callable that takes committer bytes and returns
  72. updated committer bytes or None to keep unchanged
  73. filter_message: Optional callable that takes commit message bytes
  74. and returns updated message bytes
  75. tree_filter: Optional callable that takes (tree_sha, temp_dir) and returns
  76. new tree SHA after modifying working directory
  77. index_filter: Optional callable that takes (tree_sha, temp_index_path) and
  78. returns new tree SHA after modifying index
  79. parent_filter: Optional callable that takes parent list and returns
  80. modified parent list
  81. commit_filter: Optional callable that takes (Commit, tree_sha) and returns
  82. new commit SHA or None to skip commit
  83. subdirectory_filter: Optional subdirectory path to extract as new root
  84. prune_empty: Whether to prune commits that become empty
  85. tag_name_filter: Optional callable to rename tags
  86. """
  87. self.object_store = object_store
  88. self.filter_fn = filter_fn
  89. self.filter_author = filter_author
  90. self.filter_committer = filter_committer
  91. self.filter_message = filter_message
  92. self.tree_filter = tree_filter
  93. self.index_filter = index_filter
  94. self.parent_filter = parent_filter
  95. self.commit_filter = commit_filter
  96. self.subdirectory_filter = subdirectory_filter
  97. self.prune_empty = prune_empty
  98. self.tag_name_filter = tag_name_filter
  99. self._old_to_new: dict[ObjectID, ObjectID] = {}
  100. self._processed: set[ObjectID] = set()
  101. self._tree_cache: dict[ObjectID, ObjectID] = {} # Cache for filtered trees
  102. def _filter_tree_with_subdirectory(
  103. self, tree_sha: ObjectID, subdirectory: bytes
  104. ) -> ObjectID | None:
  105. """Extract a subdirectory from a tree as the new root.
  106. Args:
  107. tree_sha: SHA of the tree to filter
  108. subdirectory: Path to subdirectory to extract
  109. Returns:
  110. SHA of the new tree containing only the subdirectory, or None if not found
  111. """
  112. try:
  113. tree = self.object_store[tree_sha]
  114. if not isinstance(tree, Tree):
  115. return None
  116. except KeyError:
  117. return None
  118. # Split subdirectory path
  119. parts = subdirectory.split(b"/")
  120. current_tree = tree
  121. # Navigate to subdirectory
  122. for part in parts:
  123. if not part:
  124. continue
  125. found = False
  126. for entry in current_tree.items():
  127. if entry.path == part:
  128. try:
  129. assert entry.sha is not None
  130. obj = self.object_store[entry.sha]
  131. if isinstance(obj, Tree):
  132. current_tree = obj
  133. found = True
  134. break
  135. except KeyError:
  136. return None
  137. if not found:
  138. # Subdirectory not found, return empty tree
  139. empty_tree = Tree()
  140. self.object_store.add_object(empty_tree)
  141. return empty_tree.id
  142. # Return the subdirectory tree
  143. return current_tree.id
  144. def _apply_tree_filter(self, tree_sha: ObjectID) -> ObjectID:
  145. """Apply tree filter by checking out tree and running filter.
  146. Args:
  147. tree_sha: SHA of the tree to filter
  148. Returns:
  149. SHA of the filtered tree
  150. """
  151. if tree_sha in self._tree_cache:
  152. return self._tree_cache[tree_sha]
  153. if not self.tree_filter:
  154. self._tree_cache[tree_sha] = tree_sha
  155. return tree_sha
  156. # Create temporary directory
  157. with tempfile.TemporaryDirectory() as tmpdir:
  158. # Check out tree to temp directory
  159. # We need a proper checkout implementation here
  160. # For now, pass tmpdir to filter and let it handle checkout
  161. new_tree_sha = self.tree_filter(tree_sha, tmpdir)
  162. if new_tree_sha is None:
  163. new_tree_sha = tree_sha
  164. self._tree_cache[tree_sha] = new_tree_sha
  165. return new_tree_sha
  166. def _apply_index_filter(self, tree_sha: ObjectID) -> ObjectID:
  167. """Apply index filter by creating temp index and running filter.
  168. Args:
  169. tree_sha: SHA of the tree to filter
  170. Returns:
  171. SHA of the filtered tree
  172. """
  173. if tree_sha in self._tree_cache:
  174. return self._tree_cache[tree_sha]
  175. if not self.index_filter:
  176. self._tree_cache[tree_sha] = tree_sha
  177. return tree_sha
  178. # Create temporary index file
  179. with tempfile.NamedTemporaryFile(delete=False) as tmp_index:
  180. tmp_index_path = tmp_index.name
  181. try:
  182. # Build index from tree
  183. build_index_from_tree(".", tmp_index_path, self.object_store, tree_sha)
  184. # Run index filter
  185. new_tree_sha = self.index_filter(tree_sha, tmp_index_path)
  186. if new_tree_sha is None:
  187. # Read back the modified index and create new tree
  188. index = Index(tmp_index_path)
  189. new_tree_sha = index.commit(self.object_store)
  190. self._tree_cache[tree_sha] = new_tree_sha
  191. return new_tree_sha
  192. finally:
  193. os.unlink(tmp_index_path)
  194. def process_commit(self, commit_sha: ObjectID) -> ObjectID | None:
  195. """Process a single commit, creating a filtered version.
  196. Args:
  197. commit_sha: SHA of the commit to process
  198. Returns:
  199. SHA of the new commit, or None if object not found
  200. """
  201. if commit_sha in self._processed:
  202. return self._old_to_new.get(commit_sha, commit_sha)
  203. self._processed.add(commit_sha)
  204. try:
  205. commit = self.object_store[commit_sha]
  206. except KeyError:
  207. # Object not found
  208. return None
  209. if not isinstance(commit, Commit):
  210. # Not a commit, return as-is
  211. self._old_to_new[commit_sha] = commit_sha
  212. return commit_sha
  213. # Process parents first
  214. new_parents = []
  215. for parent in commit.parents:
  216. new_parent = self.process_commit(parent)
  217. if new_parent: # Skip None parents
  218. new_parents.append(new_parent)
  219. # Apply parent filter
  220. if self.parent_filter:
  221. new_parents = self.parent_filter(new_parents)
  222. # Apply tree filters
  223. new_tree = commit.tree
  224. # Subdirectory filter takes precedence
  225. if self.subdirectory_filter:
  226. filtered_tree = self._filter_tree_with_subdirectory(
  227. commit.tree, self.subdirectory_filter
  228. )
  229. if filtered_tree:
  230. new_tree = filtered_tree
  231. # Then apply tree filter
  232. if self.tree_filter:
  233. new_tree = self._apply_tree_filter(new_tree)
  234. # Or apply index filter
  235. elif self.index_filter:
  236. new_tree = self._apply_index_filter(new_tree)
  237. # Check if we should prune empty commits
  238. if self.prune_empty and len(new_parents) == 1:
  239. # Check if tree is same as parent's tree
  240. parent_commit = self.object_store[new_parents[0]]
  241. if isinstance(parent_commit, Commit) and parent_commit.tree == new_tree:
  242. # This commit doesn't change anything, skip it
  243. self._old_to_new[commit_sha] = new_parents[0]
  244. return new_parents[0]
  245. # Apply filters
  246. new_data: CommitData = {}
  247. # Custom filter function takes precedence
  248. if self.filter_fn:
  249. filtered = self.filter_fn(commit)
  250. if filtered:
  251. new_data.update(filtered)
  252. # Apply specific filters
  253. if self.filter_author and "author" not in new_data:
  254. new_author = self.filter_author(commit.author)
  255. if new_author is not None:
  256. new_data["author"] = new_author
  257. if self.filter_committer and "committer" not in new_data:
  258. new_committer = self.filter_committer(commit.committer)
  259. if new_committer is not None:
  260. new_data["committer"] = new_committer
  261. if self.filter_message and "message" not in new_data:
  262. new_message = self.filter_message(commit.message)
  263. if new_message is not None:
  264. new_data["message"] = new_message
  265. # Create new commit if anything changed
  266. if new_data or new_parents != commit.parents or new_tree != commit.tree:
  267. new_commit = Commit()
  268. new_commit.tree = new_tree
  269. new_commit.parents = new_parents
  270. new_commit.author = new_data.get("author", commit.author)
  271. new_commit.author_time = new_data.get("author_time", commit.author_time)
  272. new_commit.author_timezone = new_data.get(
  273. "author_timezone", commit.author_timezone
  274. )
  275. new_commit.committer = new_data.get("committer", commit.committer)
  276. new_commit.commit_time = new_data.get("commit_time", commit.commit_time)
  277. new_commit.commit_timezone = new_data.get(
  278. "commit_timezone", commit.commit_timezone
  279. )
  280. new_commit.message = new_data.get("message", commit.message)
  281. new_commit.encoding = new_data.get("encoding", commit.encoding)
  282. # Copy extra fields
  283. if hasattr(commit, "_author_timezone_neg_utc"):
  284. new_commit._author_timezone_neg_utc = commit._author_timezone_neg_utc
  285. if hasattr(commit, "_commit_timezone_neg_utc"):
  286. new_commit._commit_timezone_neg_utc = commit._commit_timezone_neg_utc
  287. if hasattr(commit, "_extra"):
  288. new_commit._extra = list(commit._extra)
  289. if hasattr(commit, "_gpgsig"):
  290. new_commit._gpgsig = commit._gpgsig
  291. if hasattr(commit, "_mergetag"):
  292. new_commit._mergetag = list(commit._mergetag)
  293. # Apply commit filter if provided
  294. if self.commit_filter:
  295. # The commit filter can create a completely new commit
  296. new_commit_sha = self.commit_filter(new_commit, new_tree)
  297. if new_commit_sha is None:
  298. # Skip this commit
  299. if len(new_parents) == 1:
  300. self._old_to_new[commit_sha] = new_parents[0]
  301. return new_parents[0]
  302. elif len(new_parents) == 0:
  303. return None
  304. else:
  305. # Multiple parents, can't skip
  306. # Store the new commit anyway
  307. self.object_store.add_object(new_commit)
  308. self._old_to_new[commit_sha] = new_commit.id
  309. return new_commit.id
  310. else:
  311. self._old_to_new[commit_sha] = new_commit_sha
  312. return new_commit_sha
  313. else:
  314. # Store the new commit
  315. self.object_store.add_object(new_commit)
  316. self._old_to_new[commit_sha] = new_commit.id
  317. return new_commit.id
  318. else:
  319. # No changes, keep original
  320. self._old_to_new[commit_sha] = commit_sha
  321. return commit_sha
  322. def get_mapping(self) -> dict[ObjectID, ObjectID]:
  323. """Get the mapping of old commit SHAs to new commit SHAs.
  324. Returns:
  325. Dictionary mapping old SHAs to new SHAs
  326. """
  327. return self._old_to_new.copy()
  328. def filter_refs(
  329. refs: RefsContainer,
  330. object_store: BaseObjectStore,
  331. ref_names: Sequence[bytes],
  332. commit_filter: CommitFilter,
  333. *,
  334. keep_original: bool = True,
  335. force: bool = False,
  336. tag_callback: Callable[[Ref, Ref], None] | None = None,
  337. ) -> dict[ObjectID, ObjectID]:
  338. """Filter commits reachable from the given refs.
  339. Args:
  340. refs: Repository refs container
  341. object_store: Object store containing commits
  342. ref_names: List of ref names to filter
  343. commit_filter: CommitFilter instance to use
  344. keep_original: Keep original refs under refs/original/
  345. force: Force operation even if refs have been filtered before
  346. tag_callback: Optional callback for processing tags
  347. Returns:
  348. Dictionary mapping old commit SHAs to new commit SHAs
  349. Raises:
  350. ValueError: If refs have already been filtered and force is False
  351. """
  352. # Check if already filtered
  353. if keep_original and not force:
  354. for ref in ref_names:
  355. original_ref = Ref(b"refs/original/" + ref)
  356. if original_ref in refs:
  357. raise ValueError(
  358. f"Branch {ref.decode()} appears to have been filtered already. "
  359. "Use force=True to force re-filtering."
  360. )
  361. # Process commits starting from refs
  362. for ref in ref_names:
  363. try:
  364. # Get the commit SHA for this ref
  365. ref_obj = Ref(ref)
  366. if ref_obj in refs:
  367. ref_sha = refs[ref_obj]
  368. if ref_sha:
  369. commit_filter.process_commit(ref_sha)
  370. except KeyError:
  371. # Skip refs that can't be resolved
  372. warnings.warn(f"Could not process ref {ref!r}: ref not found")
  373. continue
  374. # Update refs
  375. mapping = commit_filter.get_mapping()
  376. for ref in ref_names:
  377. try:
  378. ref_obj = Ref(ref)
  379. if ref_obj in refs:
  380. old_sha = refs[ref_obj]
  381. new_sha = mapping.get(old_sha, old_sha)
  382. if old_sha != new_sha:
  383. # Save original ref if requested
  384. if keep_original:
  385. original_ref = Ref(b"refs/original/" + ref)
  386. refs[original_ref] = old_sha
  387. # Update ref to new commit
  388. refs[ref_obj] = new_sha
  389. except KeyError:
  390. # Not a valid ref, skip updating
  391. warnings.warn(f"Could not update ref {ref!r}: ref not found")
  392. continue
  393. # Handle tag filtering
  394. if commit_filter.tag_name_filter and tag_callback:
  395. # Process all tags
  396. for ref in refs.allkeys():
  397. if ref.startswith(b"refs/tags/"):
  398. # Get the tag object or commit it points to
  399. tag_sha = refs[ref]
  400. tag_obj = object_store[tag_sha]
  401. tag_name = ref[10:] # Remove 'refs/tags/'
  402. # Check if it's an annotated tag
  403. if isinstance(tag_obj, Tag):
  404. # Get the commit it points to
  405. target_sha = tag_obj.object[1]
  406. # Process tag if:
  407. # 1. It points to a rewritten commit, OR
  408. # 2. We want to rename the tag regardless
  409. if (
  410. target_sha in mapping
  411. or commit_filter.tag_name_filter is not None
  412. ):
  413. new_tag_name = commit_filter.tag_name_filter(tag_name)
  414. if new_tag_name and new_tag_name != tag_name:
  415. # For annotated tags pointing to rewritten commits,
  416. # we need to create a new tag object
  417. if target_sha in mapping:
  418. new_target = mapping[target_sha]
  419. # Create new tag object pointing to rewritten commit
  420. new_tag = Tag()
  421. new_tag.object = (tag_obj.object[0], new_target)
  422. new_tag.name = new_tag_name
  423. new_tag.message = tag_obj.message
  424. new_tag.tagger = tag_obj.tagger
  425. new_tag.tag_time = tag_obj.tag_time
  426. new_tag.tag_timezone = tag_obj.tag_timezone
  427. object_store.add_object(new_tag)
  428. # Update ref to point to new tag object
  429. refs[local_tag_name(new_tag_name)] = new_tag.id
  430. # Delete old tag
  431. del refs[ref]
  432. else:
  433. # Just rename the tag
  434. new_ref = local_tag_name(new_tag_name)
  435. tag_callback(ref, new_ref)
  436. elif isinstance(tag_obj, Commit):
  437. # Lightweight tag - points directly to a commit
  438. # Process if commit was rewritten or we want to rename
  439. if tag_sha in mapping or commit_filter.tag_name_filter is not None:
  440. new_tag_name = commit_filter.tag_name_filter(tag_name)
  441. if new_tag_name and new_tag_name != tag_name:
  442. new_ref = local_tag_name(new_tag_name)
  443. if tag_sha in mapping:
  444. # Point to rewritten commit
  445. refs[new_ref] = mapping[tag_sha]
  446. del refs[ref]
  447. else:
  448. # Just rename
  449. tag_callback(ref, new_ref)
  450. return mapping