Răsfoiți Sursa

Add filter-branch support to porcelain

This implements git filter-branch functionality in Dulwich, allowing users
to rewrite commit history by modifying author, committer, or commit messages.

Key features:
- New `dulwich.filter_branch` module with reusable CommitFilter class
- `porcelain.filter_branch()` function for high-level usage
- Support for filtering author, committer, and message fields
- Custom filter functions for advanced transformations
- Preserves original refs under refs/original/
- Force flag for re-filtering already filtered branches

Fixes #745
Jelmer Vernooij 1 lună în urmă
părinte
comite
1e39f5ec05
6 a modificat fișierele cu 859 adăugiri și 0 ștergeri
  1. 5 0
      NEWS
  2. 236 0
      dulwich/filter_branch.py
  3. 93 0
      dulwich/porcelain.py
  4. 157 0
      examples/filter_branch.py
  5. 203 0
      tests/test_filter_branch.py
  6. 165 0
      tests/test_porcelain.py

+ 5 - 0
NEWS

@@ -52,6 +52,11 @@
  * Add support for auto garbage collection, and invoke from
    some porcelain commands. (Jelmer Vernooij, #1600)
 
+ * Add ``filter-branch`` support to ``dulwich.porcelain`` and
+   ``dulwich.filter_branch`` module for rewriting commit history.
+   Supports filtering author, committer, and message fields.
+   (#745, Jelmer Vernooij)
+
 0.23.0	2025-06-21
 
  * Add basic ``rebase`` subcommand. (Jelmer Vernooij)

+ 236 - 0
dulwich/filter_branch.py

@@ -0,0 +1,236 @@
+# filter_branch.py - Git filter-branch functionality
+# Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as published by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Git filter-branch implementation."""
+
+from typing import Callable, Optional
+
+from .object_store import BaseObjectStore
+from .objects import Commit
+from .refs import RefsContainer
+
+
+class CommitFilter:
+    """Filter for rewriting commits during filter-branch operations."""
+    
+    def __init__(
+        self,
+        object_store: BaseObjectStore,
+        *,
+        filter_fn: Optional[Callable[[Commit], Optional[dict[str, bytes]]]] = None,
+        filter_author: Optional[Callable[[bytes], Optional[bytes]]] = None,
+        filter_committer: Optional[Callable[[bytes], Optional[bytes]]] = None,
+        filter_message: Optional[Callable[[bytes], Optional[bytes]]] = None,
+    ):
+        """Initialize a commit filter.
+        
+        Args:
+          object_store: Object store to read from and write to
+          filter_fn: Optional callable that takes a Commit object and returns
+            a dict of updated fields (author, committer, message, etc.)
+          filter_author: Optional callable that takes author bytes and returns
+            updated author bytes or None to keep unchanged
+          filter_committer: Optional callable that takes committer bytes and returns  
+            updated committer bytes or None to keep unchanged
+          filter_message: Optional callable that takes commit message bytes
+            and returns updated message bytes
+        """
+        self.object_store = object_store
+        self.filter_fn = filter_fn
+        self.filter_author = filter_author
+        self.filter_committer = filter_committer
+        self.filter_message = filter_message
+        self._old_to_new: dict[bytes, bytes] = {}
+        self._processed: set[bytes] = set()
+    
+    def process_commit(self, commit_sha: bytes) -> Optional[bytes]:
+        """Process a single commit, creating a filtered version.
+        
+        Args:
+          commit_sha: SHA of the commit to process
+          
+        Returns:
+          SHA of the new commit, or None if object not found
+        """
+        if commit_sha in self._processed:
+            return self._old_to_new.get(commit_sha, commit_sha)
+        
+        self._processed.add(commit_sha)
+        
+        try:
+            commit = self.object_store[commit_sha]
+        except KeyError:
+            # Object not found
+            return None
+        
+        if not isinstance(commit, Commit):
+            # Not a commit, return as-is
+            self._old_to_new[commit_sha] = commit_sha
+            return commit_sha
+        
+        # Process parents first
+        new_parents = []
+        for parent in commit.parents:
+            new_parent = self.process_commit(parent)
+            if new_parent:  # Skip None parents
+                new_parents.append(new_parent)
+        
+        # Apply filters
+        new_data = {}
+        
+        # Custom filter function takes precedence
+        if self.filter_fn:
+            filtered = self.filter_fn(commit)
+            if filtered:
+                new_data.update(filtered)
+        
+        # Apply specific filters
+        if self.filter_author and "author" not in new_data:
+            new_author = self.filter_author(commit.author)
+            if new_author is not None:
+                new_data["author"] = new_author
+        
+        if self.filter_committer and "committer" not in new_data:
+            new_committer = self.filter_committer(commit.committer)
+            if new_committer is not None:
+                new_data["committer"] = new_committer
+        
+        if self.filter_message and "message" not in new_data:
+            new_message = self.filter_message(commit.message)
+            if new_message is not None:
+                new_data["message"] = new_message
+        
+        # Create new commit if anything changed
+        if new_data or new_parents != commit.parents:
+            new_commit = Commit()
+            new_commit.tree = commit.tree
+            new_commit.parents = new_parents
+            new_commit.author = new_data.get("author", commit.author)
+            new_commit.author_time = new_data.get("author_time", commit.author_time)
+            new_commit.author_timezone = new_data.get("author_timezone", commit.author_timezone)
+            new_commit.committer = new_data.get("committer", commit.committer)
+            new_commit.commit_time = new_data.get("commit_time", commit.commit_time)
+            new_commit.commit_timezone = new_data.get("commit_timezone", commit.commit_timezone)
+            new_commit.message = new_data.get("message", commit.message)
+            new_commit.encoding = new_data.get("encoding", commit.encoding)
+            
+            # Copy extra fields
+            if hasattr(commit, "_author_timezone_neg_utc"):
+                new_commit._author_timezone_neg_utc = commit._author_timezone_neg_utc
+            if hasattr(commit, "_commit_timezone_neg_utc"):
+                new_commit._commit_timezone_neg_utc = commit._commit_timezone_neg_utc
+            if hasattr(commit, "_extra"):
+                new_commit._extra = list(commit._extra)
+            if hasattr(commit, "_gpgsig"):
+                new_commit._gpgsig = commit._gpgsig
+            if hasattr(commit, "_mergetag"):
+                new_commit._mergetag = list(commit._mergetag)
+            
+            # Store the new commit
+            self.object_store.add_object(new_commit)
+            self._old_to_new[commit_sha] = new_commit.id
+            return new_commit.id
+        else:
+            # No changes, keep original
+            self._old_to_new[commit_sha] = commit_sha
+            return commit_sha
+    
+    def get_mapping(self) -> dict[bytes, bytes]:
+        """Get the mapping of old commit SHAs to new commit SHAs.
+        
+        Returns:
+          Dictionary mapping old SHAs to new SHAs
+        """
+        return self._old_to_new.copy()
+
+
+def filter_refs(
+    refs: RefsContainer,
+    object_store: BaseObjectStore,
+    ref_names: list[bytes],
+    commit_filter: CommitFilter,
+    *,
+    keep_original: bool = True,
+    force: bool = False,
+) -> dict[bytes, bytes]:
+    """Filter commits reachable from the given refs.
+    
+    Args:
+      refs: Repository refs container
+      object_store: Object store containing commits
+      ref_names: List of ref names to filter
+      commit_filter: CommitFilter instance to use
+      keep_original: Keep original refs under refs/original/
+      force: Force operation even if refs have been filtered before
+      
+    Returns:
+      Dictionary mapping old commit SHAs to new commit SHAs
+      
+    Raises:
+      ValueError: If refs have already been filtered and force is False
+    """
+    # Check if already filtered
+    if keep_original and not force:
+        for ref in ref_names:
+            original_ref = b"refs/original/" + ref
+            if original_ref in refs:
+                raise ValueError(
+                    f"Branch {ref.decode()} appears to have been filtered already. "
+                    "Use force=True to force re-filtering."
+                )
+    
+    # Process commits starting from refs
+    for ref in ref_names:
+        try:
+            # Get the commit SHA for this ref
+            if ref in refs:
+                ref_sha = refs[ref]
+                if ref_sha:
+                    commit_filter.process_commit(ref_sha)
+        except (KeyError, ValueError) as e:
+            # Skip refs that can't be resolved
+            import warnings
+            warnings.warn(f"Could not process ref {ref!r}: {e}")
+            continue
+    
+    # Update refs
+    mapping = commit_filter.get_mapping()
+    for ref in ref_names:
+        try:
+            if ref in refs:
+                old_sha = refs[ref]
+                new_sha = mapping.get(old_sha, old_sha)
+                
+                if old_sha != new_sha:
+                    # Save original ref if requested
+                    if keep_original:
+                        original_ref = b"refs/original/" + ref
+                        refs[original_ref] = old_sha
+                    
+                    # Update ref to new commit
+                    refs[ref] = new_sha
+        except KeyError as e:
+            # Not a valid ref, skip updating
+            import warnings
+            warnings.warn(f"Could not update ref {ref!r}: {e}")
+            continue
+    
+    return mapping

+ 93 - 0
dulwich/porcelain.py

@@ -36,6 +36,7 @@ Currently implemented:
  * describe
  * diff_tree
  * fetch
+ * filter_branch
  * for_each_ref
  * init
  * ls_files
@@ -3654,3 +3655,95 @@ def annotate(repo, path, committish=None):
 
 
 blame = annotate
+
+
+def filter_branch(
+    repo=".",
+    branch="HEAD",
+    *,
+    filter_fn=None,
+    filter_author=None,
+    filter_committer=None,
+    filter_message=None,
+    force=False,
+    keep_original=True,
+    refs=None,
+):
+    """Rewrite branch history by creating new commits with filtered properties.
+
+    This is similar to git filter-branch, allowing you to rewrite commit
+    history by modifying author, committer, or commit messages.
+
+    Args:
+      repo: Path to repository
+      branch: Branch to rewrite (defaults to HEAD)
+      filter_fn: Optional callable that takes a Commit object and returns
+        a dict of updated fields (author, committer, message, etc.)
+      filter_author: Optional callable that takes author bytes and returns
+        updated author bytes or None to keep unchanged
+      filter_committer: Optional callable that takes committer bytes and returns  
+        updated committer bytes or None to keep unchanged
+      filter_message: Optional callable that takes commit message bytes
+        and returns updated message bytes
+      force: Force operation even if branch has been filtered before
+      keep_original: Keep original refs under refs/original/
+      refs: List of refs to rewrite (defaults to [branch])
+
+    Returns:
+      Dict mapping old commit SHAs to new commit SHAs
+
+    Raises:
+      Error: If branch is already filtered and force is False
+    """
+    from .filter_branch import CommitFilter, filter_refs
+    
+    with open_repo_closing(repo) as r:
+        # Parse branch/committish
+        if isinstance(branch, str):
+            branch = branch.encode()
+        
+        # Determine which refs to process
+        if refs is None:
+            if branch == b"HEAD":
+                # Resolve HEAD to actual branch
+                try:
+                    resolved = r.refs.follow(b"HEAD")
+                    if resolved and resolved[0]:
+                        # resolved is a list of (refname, sha) tuples
+                        resolved_ref = resolved[0][0]
+                        if resolved_ref and resolved_ref != b"HEAD":
+                            refs = [resolved_ref]
+                        else:
+                            # HEAD points directly to a commit
+                            refs = [b"HEAD"]
+                    else:
+                        refs = [b"HEAD"]
+                except Exception:
+                    refs = [b"HEAD"]
+            else:
+                # Convert branch name to full ref if needed
+                if not branch.startswith(b"refs/"):
+                    branch = b"refs/heads/" + branch
+                refs = [branch]
+        
+        # Create commit filter
+        commit_filter = CommitFilter(
+            r.object_store,
+            filter_fn=filter_fn,
+            filter_author=filter_author,
+            filter_committer=filter_committer,
+            filter_message=filter_message,
+        )
+        
+        # Filter refs
+        try:
+            return filter_refs(
+                r.refs,
+                r.object_store,
+                refs,
+                commit_filter,
+                keep_original=keep_original,
+                force=force,
+            )
+        except ValueError as e:
+            raise Error(str(e)) from e

+ 157 - 0
examples/filter_branch.py

@@ -0,0 +1,157 @@
+#!/usr/bin/env python3
+"""Example of using filter-branch to rewrite commit history.
+
+This demonstrates how to use dulwich's filter-branch functionality to:
+- Change author/committer information
+- Modify commit messages
+- Apply custom filters
+
+The example shows both the high-level porcelain interface and the 
+lower-level filter_branch module API.
+"""
+
+import sys
+
+from dulwich import porcelain
+from dulwich.filter_branch import CommitFilter, filter_refs
+from dulwich.repo import Repo
+
+
+def example_change_author(repo_path):
+    """Example: Change all commits to have a new author."""
+    print("Changing author for all commits...")
+    
+    def new_author(old_author):
+        # Change any commit by "Old Author" to "New Author"
+        if b"Old Author" in old_author:
+            return b"New Author <new@example.com>"
+        return old_author
+    
+    result = porcelain.filter_branch(
+        repo_path,
+        "HEAD",
+        filter_author=new_author
+    )
+    
+    print(f"Rewrote {len(result)} commits")
+    return result
+
+
+def example_prefix_messages(repo_path):
+    """Example: Add a prefix to all commit messages."""
+    print("Adding prefix to commit messages...")
+    
+    def add_prefix(message):
+        return b"[PROJECT-123] " + message
+    
+    result = porcelain.filter_branch(
+        repo_path,
+        "HEAD",
+        filter_message=add_prefix
+    )
+    
+    print(f"Rewrote {len(result)} commits")
+    return result
+
+
+def example_custom_filter(repo_path):
+    """Example: Custom filter that changes multiple fields."""
+    print("Applying custom filter...")
+    
+    def custom_filter(commit):
+        # This filter:
+        # - Standardizes author format
+        # - Adds issue number to message if missing
+        # - Updates committer to match author
+        
+        changes = {}
+        
+        # Standardize author format
+        if b"<" not in commit.author:
+            changes["author"] = commit.author + b" <unknown@example.com>"
+        
+        # Add issue number if missing
+        if not commit.message.startswith(b"[") and not commit.message.startswith(b"Merge"):
+            changes["message"] = b"[LEGACY] " + commit.message
+        
+        # Make committer match author
+        if commit.author != commit.committer:
+            changes["committer"] = commit.author
+            
+        return changes if changes else None
+    
+    result = porcelain.filter_branch(
+        repo_path,
+        "HEAD",
+        filter_fn=custom_filter
+    )
+    
+    print(f"Rewrote {len(result)} commits")
+    return result
+
+
+def example_low_level_api(repo_path):
+    """Example: Using the low-level filter_branch module API."""
+    print("Using low-level filter_branch API...")
+    
+    with Repo(repo_path) as repo:
+        # Create a custom filter
+        def transform_message(msg):
+            # Add timestamp and uppercase first line
+            lines = msg.split(b'\n')
+            if lines:
+                lines[0] = lines[0].upper()
+            return b'[TRANSFORMED] ' + b'\n'.join(lines)
+        
+        # Create the commit filter
+        commit_filter = CommitFilter(
+            repo.object_store,
+            filter_message=transform_message,
+            filter_author=lambda a: b"Transformed Author <transformed@example.com>"
+        )
+        
+        # Filter the master branch
+        result = filter_refs(
+            repo.refs,
+            repo.object_store,
+            [b"refs/heads/master"],
+            commit_filter,
+            keep_original=True,
+            force=False,
+        )
+        
+        print(f"Rewrote {len(result)} commits using low-level API")
+        return result
+
+
+def main():
+    if len(sys.argv) < 2:
+        print("Usage: filter_branch.py <repo_path> [example]")
+        print("Examples: change_author, prefix_messages, custom_filter, low_level")
+        sys.exit(1)
+    
+    repo_path = sys.argv[1]
+    example = sys.argv[2] if len(sys.argv) > 2 else "change_author"
+    
+    examples = {
+        "change_author": example_change_author,
+        "prefix_messages": example_prefix_messages,
+        "custom_filter": example_custom_filter,
+        "low_level": example_low_level_api,
+    }
+    
+    if example not in examples:
+        print(f"Unknown example: {example}")
+        print(f"Available examples: {', '.join(examples.keys())}")
+        sys.exit(1)
+    
+    try:
+        examples[example](repo_path)
+        print("Filter-branch completed successfully!")
+    except Exception as e:
+        print(f"Error: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    main()

+ 203 - 0
tests/test_filter_branch.py

@@ -0,0 +1,203 @@
+# test_filter_branch.py -- Tests for filter_branch module
+# Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as published by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Tests for dulwich.filter_branch."""
+
+import unittest
+
+from dulwich.filter_branch import CommitFilter, filter_refs
+from dulwich.object_store import MemoryObjectStore
+from dulwich.objects import Commit, Tree
+from dulwich.refs import DictRefsContainer
+
+
+class CommitFilterTests(unittest.TestCase):
+    """Tests for CommitFilter class."""
+    
+    def setUp(self):
+        self.store = MemoryObjectStore()
+        self.refs = DictRefsContainer({})
+        
+        # Create test commits
+        tree = Tree()
+        self.store.add_object(tree)
+        
+        self.c1 = Commit()
+        self.c1.tree = tree.id
+        self.c1.author = self.c1.committer = b"Test User <test@example.com>"
+        self.c1.author_time = self.c1.commit_time = 1000
+        self.c1.author_timezone = self.c1.commit_timezone = 0
+        self.c1.message = b"First commit"
+        self.store.add_object(self.c1)
+        
+        self.c2 = Commit()
+        self.c2.tree = tree.id
+        self.c2.parents = [self.c1.id]
+        self.c2.author = self.c2.committer = b"Test User <test@example.com>"
+        self.c2.author_time = self.c2.commit_time = 2000
+        self.c2.author_timezone = self.c2.commit_timezone = 0
+        self.c2.message = b"Second commit"
+        self.store.add_object(self.c2)
+    
+    def test_filter_author(self):
+        """Test filtering author."""
+        def new_author(old):
+            return b"New Author <new@example.com>"
+        
+        filter = CommitFilter(self.store, filter_author=new_author)
+        new_sha = filter.process_commit(self.c2.id)
+        
+        self.assertNotEqual(new_sha, self.c2.id)
+        new_commit = self.store[new_sha]
+        self.assertEqual(new_commit.author, b"New Author <new@example.com>")
+        self.assertEqual(new_commit.committer, self.c2.committer)
+        
+    def test_filter_message(self):
+        """Test filtering message."""
+        def prefix_message(msg):
+            return b"[PREFIX] " + msg
+        
+        filter = CommitFilter(self.store, filter_message=prefix_message)
+        new_sha = filter.process_commit(self.c2.id)
+        
+        self.assertNotEqual(new_sha, self.c2.id)
+        new_commit = self.store[new_sha]
+        self.assertEqual(new_commit.message, b"[PREFIX] Second commit")
+        
+    def test_filter_fn(self):
+        """Test custom filter function."""
+        def custom_filter(commit):
+            return {
+                "author": b"Custom <custom@example.com>",
+                "message": b"Custom: " + commit.message,
+            }
+        
+        filter = CommitFilter(self.store, filter_fn=custom_filter)
+        new_sha = filter.process_commit(self.c2.id)
+        
+        self.assertNotEqual(new_sha, self.c2.id)
+        new_commit = self.store[new_sha]
+        self.assertEqual(new_commit.author, b"Custom <custom@example.com>")
+        self.assertEqual(new_commit.message, b"Custom: Second commit")
+        
+    def test_no_changes(self):
+        """Test commit with no changes."""
+        filter = CommitFilter(self.store)
+        new_sha = filter.process_commit(self.c2.id)
+        
+        self.assertEqual(new_sha, self.c2.id)
+        
+    def test_parent_rewriting(self):
+        """Test that parent commits are rewritten."""
+        def new_author(old):
+            return b"New Author <new@example.com>"
+        
+        filter = CommitFilter(self.store, filter_author=new_author)
+        new_sha = filter.process_commit(self.c2.id)
+        
+        # Check that parent was also rewritten
+        new_commit = self.store[new_sha]
+        self.assertEqual(len(new_commit.parents), 1)
+        new_parent_sha = new_commit.parents[0]
+        self.assertNotEqual(new_parent_sha, self.c1.id)
+        
+        new_parent = self.store[new_parent_sha]
+        self.assertEqual(new_parent.author, b"New Author <new@example.com>")
+
+
+class FilterRefsTests(unittest.TestCase):
+    """Tests for filter_refs function."""
+    
+    def setUp(self):
+        self.store = MemoryObjectStore()
+        self.refs = DictRefsContainer({})
+        
+        # Create test commits
+        tree = Tree()
+        self.store.add_object(tree)
+        
+        c1 = Commit()
+        c1.tree = tree.id
+        c1.author = c1.committer = b"Test User <test@example.com>"
+        c1.author_time = c1.commit_time = 1000
+        c1.author_timezone = c1.commit_timezone = 0
+        c1.message = b"First commit"
+        self.store.add_object(c1)
+        
+        self.refs[b"refs/heads/master"] = c1.id
+        self.c1_id = c1.id
+    
+    def test_filter_refs_basic(self):
+        """Test basic ref filtering."""
+        def new_author(old):
+            return b"New Author <new@example.com>"
+        
+        filter = CommitFilter(self.store, filter_author=new_author)
+        result = filter_refs(
+            self.refs,
+            self.store,
+            [b"refs/heads/master"],
+            filter,
+        )
+        
+        # Check mapping
+        self.assertEqual(len(result), 1)
+        self.assertIn(self.c1_id, result)
+        self.assertNotEqual(result[self.c1_id], self.c1_id)
+        
+        # Check ref was updated
+        new_sha = self.refs[b"refs/heads/master"]
+        self.assertEqual(new_sha, result[self.c1_id])
+        
+        # Check original was saved
+        original_sha = self.refs[b"refs/original/refs/heads/master"]
+        self.assertEqual(original_sha, self.c1_id)
+        
+    def test_filter_refs_already_filtered(self):
+        """Test error when refs already filtered."""
+        # Set up an "already filtered" state
+        self.refs[b"refs/original/refs/heads/master"] = b"0" * 40
+        
+        filter = CommitFilter(self.store)
+        with self.assertRaises(ValueError) as cm:
+            filter_refs(
+                self.refs,
+                self.store,
+                [b"refs/heads/master"],
+                filter,
+            )
+        self.assertIn("filtered already", str(cm.exception))
+        
+    def test_filter_refs_force(self):
+        """Test force filtering."""
+        # Set up an "already filtered" state
+        self.refs[b"refs/original/refs/heads/master"] = b"0" * 40
+        
+        filter = CommitFilter(self.store)
+        # Should not raise with force=True
+        result = filter_refs(
+            self.refs,
+            self.store,
+            [b"refs/heads/master"],
+            filter,
+            force=True,
+        )
+        self.assertEqual(len(result), 1)

+ 165 - 0
tests/test_porcelain.py

@@ -5575,3 +5575,168 @@ class PruneTests(PorcelainTestCase):
 
         # Verify the file was NOT removed (dry run)
         self.assertTrue(os.path.exists(tmp_pack_path))
+
+
+class FilterBranchTests(PorcelainTestCase):
+    def setUp(self):
+        super().setUp()
+        # Create initial commits with different authors
+        from dulwich.objects import Commit, Tree
+        
+        # Create actual tree and blob objects
+        tree = Tree()
+        self.repo.object_store.add_object(tree)
+        
+        c1 = Commit()
+        c1.tree = tree.id
+        c1.parents = []
+        c1.author = b"Old Author <old@example.com>"
+        c1.author_time = 1000
+        c1.author_timezone = 0
+        c1.committer = b"Old Committer <old@example.com>"
+        c1.commit_time = 1000
+        c1.commit_timezone = 0
+        c1.message = b"Initial commit"
+        self.repo.object_store.add_object(c1)
+        
+        c2 = Commit()
+        c2.tree = tree.id
+        c2.parents = [c1.id]
+        c2.author = b"Another Author <another@example.com>"
+        c2.author_time = 2000
+        c2.author_timezone = 0
+        c2.committer = b"Another Committer <another@example.com>"
+        c2.commit_time = 2000
+        c2.commit_timezone = 0
+        c2.message = b"Second commit\n\nWith body"
+        self.repo.object_store.add_object(c2)
+        
+        c3 = Commit()
+        c3.tree = tree.id
+        c3.parents = [c2.id]
+        c3.author = b"Third Author <third@example.com>"
+        c3.author_time = 3000
+        c3.author_timezone = 0
+        c3.committer = b"Third Committer <third@example.com>"
+        c3.commit_time = 3000
+        c3.commit_timezone = 0
+        c3.message = b"Third commit"
+        self.repo.object_store.add_object(c3)
+        
+        self.repo.refs[b"refs/heads/master"] = c3.id
+        self.repo.refs.set_symbolic_ref(b"HEAD", b"refs/heads/master")
+        
+        # Store IDs for test assertions
+        self.c1_id = c1.id
+        self.c2_id = c2.id
+        self.c3_id = c3.id
+        
+    def test_filter_branch_author(self):
+        """Test filtering branch with author changes."""
+        def filter_author(author):
+            # Change all authors to "New Author"
+            return b"New Author <new@example.com>"
+        
+        result = porcelain.filter_branch(
+            self.repo_path,
+            "master",
+            filter_author=filter_author
+        )
+        
+        # Check that we have mappings for all commits
+        self.assertEqual(len(result), 3)
+        
+        # Verify the branch ref was updated
+        new_head = self.repo.refs[b"refs/heads/master"]
+        self.assertNotEqual(new_head, self.c3_id)
+        
+        # Verify the original ref was saved
+        original_ref = self.repo.refs[b"refs/original/refs/heads/master"]
+        self.assertEqual(original_ref, self.c3_id)
+        
+        # Check that authors were updated
+        new_commit = self.repo[new_head]
+        self.assertEqual(new_commit.author, b"New Author <new@example.com>")
+        
+        # Check parent chain
+        parent = self.repo[new_commit.parents[0]]
+        self.assertEqual(parent.author, b"New Author <new@example.com>")
+        
+    def test_filter_branch_message(self):
+        """Test filtering branch with message changes."""
+        def filter_message(message):
+            # Add prefix to all messages
+            return b"[FILTERED] " + message
+        
+        porcelain.filter_branch(
+            self.repo_path,
+            "master",
+            filter_message=filter_message
+        )
+        
+        # Verify messages were updated
+        new_head = self.repo.refs[b"refs/heads/master"]
+        new_commit = self.repo[new_head]
+        self.assertTrue(new_commit.message.startswith(b"[FILTERED] "))
+        
+    def test_filter_branch_custom_filter(self):
+        """Test filtering branch with custom filter function."""
+        def custom_filter(commit):
+            # Change both author and message
+            return {
+                "author": b"Custom Author <custom@example.com>",
+                "message": b"Custom: " + commit.message
+            }
+        
+        porcelain.filter_branch(
+            self.repo_path,
+            "master", 
+            filter_fn=custom_filter
+        )
+        
+        # Verify custom filter was applied
+        new_head = self.repo.refs[b"refs/heads/master"]
+        new_commit = self.repo[new_head]
+        self.assertEqual(new_commit.author, b"Custom Author <custom@example.com>")
+        self.assertTrue(new_commit.message.startswith(b"Custom: "))
+        
+    def test_filter_branch_no_changes(self):
+        """Test filtering branch with no changes."""
+        result = porcelain.filter_branch(self.repo_path, "master")
+        
+        # All commits should map to themselves
+        for old_sha, new_sha in result.items():
+            self.assertEqual(old_sha, new_sha)
+            
+        # HEAD should be unchanged
+        self.assertEqual(self.repo.refs[b"refs/heads/master"], self.c3_id)
+        
+    def test_filter_branch_force(self):
+        """Test force filtering a previously filtered branch."""
+        # First filter
+        porcelain.filter_branch(
+            self.repo_path,
+            "master",
+            filter_message=lambda m: b"First: " + m
+        )
+        
+        # Try again without force - should fail
+        with self.assertRaises(porcelain.Error):
+            porcelain.filter_branch(
+                self.repo_path,
+                "master",
+                filter_message=lambda m: b"Second: " + m
+            )
+            
+        # Try again with force - should succeed
+        porcelain.filter_branch(
+            self.repo_path,
+            "master",
+            filter_message=lambda m: b"Second: " + m,
+            force=True
+        )
+        
+        # Verify second filter was applied
+        new_head = self.repo.refs[b"refs/heads/master"]
+        new_commit = self.repo[new_head]
+        self.assertTrue(new_commit.message.startswith(b"Second: First: "))