filter_branch.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # filter_branch.py - Git filter-branch functionality
  2. # Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
  5. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  6. # General Public License as published by the Free Software Foundation; version 2.0
  7. # or (at your option) any later version. You can redistribute it and/or
  8. # modify it under the terms of either of these two licenses.
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # You should have received a copy of the licenses; if not, see
  17. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  18. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  19. # License, Version 2.0.
  20. #
  21. """Git filter-branch implementation."""
  22. from typing import Callable, Optional
  23. from .object_store import BaseObjectStore
  24. from .objects import Commit
  25. from .refs import RefsContainer
  26. class CommitFilter:
  27. """Filter for rewriting commits during filter-branch operations."""
  28. def __init__(
  29. self,
  30. object_store: BaseObjectStore,
  31. *,
  32. filter_fn: Optional[Callable[[Commit], Optional[dict[str, bytes]]]] = None,
  33. filter_author: Optional[Callable[[bytes], Optional[bytes]]] = None,
  34. filter_committer: Optional[Callable[[bytes], Optional[bytes]]] = None,
  35. filter_message: Optional[Callable[[bytes], Optional[bytes]]] = None,
  36. ):
  37. """Initialize a commit filter.
  38. Args:
  39. object_store: Object store to read from and write to
  40. filter_fn: Optional callable that takes a Commit object and returns
  41. a dict of updated fields (author, committer, message, etc.)
  42. filter_author: Optional callable that takes author bytes and returns
  43. updated author bytes or None to keep unchanged
  44. filter_committer: Optional callable that takes committer bytes and returns
  45. updated committer bytes or None to keep unchanged
  46. filter_message: Optional callable that takes commit message bytes
  47. and returns updated message bytes
  48. """
  49. self.object_store = object_store
  50. self.filter_fn = filter_fn
  51. self.filter_author = filter_author
  52. self.filter_committer = filter_committer
  53. self.filter_message = filter_message
  54. self._old_to_new: dict[bytes, bytes] = {}
  55. self._processed: set[bytes] = set()
  56. def process_commit(self, commit_sha: bytes) -> Optional[bytes]:
  57. """Process a single commit, creating a filtered version.
  58. Args:
  59. commit_sha: SHA of the commit to process
  60. Returns:
  61. SHA of the new commit, or None if object not found
  62. """
  63. if commit_sha in self._processed:
  64. return self._old_to_new.get(commit_sha, commit_sha)
  65. self._processed.add(commit_sha)
  66. try:
  67. commit = self.object_store[commit_sha]
  68. except KeyError:
  69. # Object not found
  70. return None
  71. if not isinstance(commit, Commit):
  72. # Not a commit, return as-is
  73. self._old_to_new[commit_sha] = commit_sha
  74. return commit_sha
  75. # Process parents first
  76. new_parents = []
  77. for parent in commit.parents:
  78. new_parent = self.process_commit(parent)
  79. if new_parent: # Skip None parents
  80. new_parents.append(new_parent)
  81. # Apply filters
  82. new_data = {}
  83. # Custom filter function takes precedence
  84. if self.filter_fn:
  85. filtered = self.filter_fn(commit)
  86. if filtered:
  87. new_data.update(filtered)
  88. # Apply specific filters
  89. if self.filter_author and "author" not in new_data:
  90. new_author = self.filter_author(commit.author)
  91. if new_author is not None:
  92. new_data["author"] = new_author
  93. if self.filter_committer and "committer" not in new_data:
  94. new_committer = self.filter_committer(commit.committer)
  95. if new_committer is not None:
  96. new_data["committer"] = new_committer
  97. if self.filter_message and "message" not in new_data:
  98. new_message = self.filter_message(commit.message)
  99. if new_message is not None:
  100. new_data["message"] = new_message
  101. # Create new commit if anything changed
  102. if new_data or new_parents != commit.parents:
  103. new_commit = Commit()
  104. new_commit.tree = commit.tree
  105. new_commit.parents = new_parents
  106. new_commit.author = new_data.get("author", commit.author)
  107. new_commit.author_time = new_data.get("author_time", commit.author_time)
  108. new_commit.author_timezone = new_data.get("author_timezone", commit.author_timezone)
  109. new_commit.committer = new_data.get("committer", commit.committer)
  110. new_commit.commit_time = new_data.get("commit_time", commit.commit_time)
  111. new_commit.commit_timezone = new_data.get("commit_timezone", commit.commit_timezone)
  112. new_commit.message = new_data.get("message", commit.message)
  113. new_commit.encoding = new_data.get("encoding", commit.encoding)
  114. # Copy extra fields
  115. if hasattr(commit, "_author_timezone_neg_utc"):
  116. new_commit._author_timezone_neg_utc = commit._author_timezone_neg_utc
  117. if hasattr(commit, "_commit_timezone_neg_utc"):
  118. new_commit._commit_timezone_neg_utc = commit._commit_timezone_neg_utc
  119. if hasattr(commit, "_extra"):
  120. new_commit._extra = list(commit._extra)
  121. if hasattr(commit, "_gpgsig"):
  122. new_commit._gpgsig = commit._gpgsig
  123. if hasattr(commit, "_mergetag"):
  124. new_commit._mergetag = list(commit._mergetag)
  125. # Store the new commit
  126. self.object_store.add_object(new_commit)
  127. self._old_to_new[commit_sha] = new_commit.id
  128. return new_commit.id
  129. else:
  130. # No changes, keep original
  131. self._old_to_new[commit_sha] = commit_sha
  132. return commit_sha
  133. def get_mapping(self) -> dict[bytes, bytes]:
  134. """Get the mapping of old commit SHAs to new commit SHAs.
  135. Returns:
  136. Dictionary mapping old SHAs to new SHAs
  137. """
  138. return self._old_to_new.copy()
  139. def filter_refs(
  140. refs: RefsContainer,
  141. object_store: BaseObjectStore,
  142. ref_names: list[bytes],
  143. commit_filter: CommitFilter,
  144. *,
  145. keep_original: bool = True,
  146. force: bool = False,
  147. ) -> dict[bytes, bytes]:
  148. """Filter commits reachable from the given refs.
  149. Args:
  150. refs: Repository refs container
  151. object_store: Object store containing commits
  152. ref_names: List of ref names to filter
  153. commit_filter: CommitFilter instance to use
  154. keep_original: Keep original refs under refs/original/
  155. force: Force operation even if refs have been filtered before
  156. Returns:
  157. Dictionary mapping old commit SHAs to new commit SHAs
  158. Raises:
  159. ValueError: If refs have already been filtered and force is False
  160. """
  161. # Check if already filtered
  162. if keep_original and not force:
  163. for ref in ref_names:
  164. original_ref = b"refs/original/" + ref
  165. if original_ref in refs:
  166. raise ValueError(
  167. f"Branch {ref.decode()} appears to have been filtered already. "
  168. "Use force=True to force re-filtering."
  169. )
  170. # Process commits starting from refs
  171. for ref in ref_names:
  172. try:
  173. # Get the commit SHA for this ref
  174. if ref in refs:
  175. ref_sha = refs[ref]
  176. if ref_sha:
  177. commit_filter.process_commit(ref_sha)
  178. except (KeyError, ValueError) as e:
  179. # Skip refs that can't be resolved
  180. import warnings
  181. warnings.warn(f"Could not process ref {ref!r}: {e}")
  182. continue
  183. # Update refs
  184. mapping = commit_filter.get_mapping()
  185. for ref in ref_names:
  186. try:
  187. if ref in refs:
  188. old_sha = refs[ref]
  189. new_sha = mapping.get(old_sha, old_sha)
  190. if old_sha != new_sha:
  191. # Save original ref if requested
  192. if keep_original:
  193. original_ref = b"refs/original/" + ref
  194. refs[original_ref] = old_sha
  195. # Update ref to new commit
  196. refs[ref] = new_sha
  197. except KeyError as e:
  198. # Not a valid ref, skip updating
  199. import warnings
  200. warnings.warn(f"Could not update ref {ref!r}: {e}")
  201. continue
  202. return mapping