Browse Source

Use same rebase state files as C Git

Update rebase implementation to use the exact same file names and
structure as C Git for tracking rebase state in .git/rebase-merge/:

- orig-head: Original HEAD position before rebase
- head-name: Name of the branch being rebased
- onto: SHA of the commit we're rebasing onto
- stopped-sha: Current commit being processed
- msgnum: Current commit number (1-based)
- end: Total number of commits to rebase
Jelmer Vernooij 1 month ago
parent
commit
2c6e0b0ffa
3 changed files with 272 additions and 46 deletions
  1. 237 38
      dulwich/rebase.py
  2. 25 0
      dulwich/repo.py
  3. 10 8
      tests/test_refs.py

+ 237 - 38
dulwich/rebase.py

@@ -21,7 +21,7 @@
 
 """Git rebase implementation."""
 
-from typing import Optional
+from typing import Optional, Protocol
 
 from dulwich.graph import find_merge_base
 from dulwich.merge import three_way_merge
@@ -48,6 +48,214 @@ class RebaseAbort(RebaseError):
     """Raised when rebase is aborted."""
 
 
+class RebaseStateManager(Protocol):
+    """Protocol for managing rebase state."""
+
+    def save(
+        self,
+        original_head: Optional[bytes],
+        rebasing_branch: Optional[bytes],
+        onto: Optional[bytes],
+        todo: list[Commit],
+        done: list[Commit],
+    ) -> None:
+        """Save rebase state."""
+        ...
+
+    def load(
+        self,
+    ) -> tuple[
+        Optional[bytes],  # original_head
+        Optional[bytes],  # rebasing_branch
+        Optional[bytes],  # onto
+        list[Commit],  # todo
+        list[Commit],  # done
+    ]:
+        """Load rebase state."""
+        ...
+
+    def clean(self) -> None:
+        """Clean up rebase state."""
+        ...
+
+    def exists(self) -> bool:
+        """Check if rebase state exists."""
+        ...
+
+
+class DiskRebaseStateManager:
+    """Manages rebase state on disk using same files as C Git."""
+
+    def __init__(self, repo: Repo) -> None:
+        self.repo = repo
+
+    def save(
+        self,
+        original_head: Optional[bytes],
+        rebasing_branch: Optional[bytes],
+        onto: Optional[bytes],
+        todo: list[Commit],
+        done: list[Commit],
+    ) -> None:
+        """Save rebase state to disk."""
+        # For file-based repos, ensure the directory exists
+        if hasattr(self.repo, "controldir"):
+            import os
+
+            rebase_dir = os.path.join(self.repo.controldir(), "rebase-merge")
+            if not os.path.exists(rebase_dir):
+                os.makedirs(rebase_dir)
+
+        # Store the original HEAD ref (e.g. "refs/heads/feature")
+        if original_head:
+            self.repo._put_named_file("rebase-merge/orig-head", original_head)
+
+        # Store the branch name being rebased
+        if rebasing_branch:
+            self.repo._put_named_file("rebase-merge/head-name", rebasing_branch)
+
+        # Store the commit we're rebasing onto
+        if onto:
+            self.repo._put_named_file("rebase-merge/onto", onto)
+
+        # Track progress
+        if todo:
+            # Store the current commit being rebased (same as C Git)
+            current_commit = todo[0]
+            self.repo._put_named_file("rebase-merge/stopped-sha", current_commit.id)
+
+            # Store progress counters
+            msgnum = len(done) + 1  # Current commit number (1-based)
+            end = len(done) + len(todo)  # Total number of commits
+            self.repo._put_named_file("rebase-merge/msgnum", str(msgnum).encode())
+            self.repo._put_named_file("rebase-merge/end", str(end).encode())
+
+        # TODO: Add support for writing git-rebase-todo for interactive rebase
+
+    def load(
+        self,
+    ) -> tuple[
+        Optional[bytes],
+        Optional[bytes],
+        Optional[bytes],
+        list[Commit],
+        list[Commit],
+    ]:
+        """Load rebase state from disk."""
+        original_head = None
+        rebasing_branch = None
+        onto = None
+        todo: list[Commit] = []
+        done: list[Commit] = []
+
+        # Load rebase state files
+        orig_head_file = self.repo.get_named_file("rebase-merge/orig-head")
+        if orig_head_file:
+            original_head = orig_head_file.read().strip()
+            orig_head_file.close()
+
+        head_name_file = self.repo.get_named_file("rebase-merge/head-name")
+        if head_name_file:
+            rebasing_branch = head_name_file.read().strip()
+            head_name_file.close()
+
+        onto_file = self.repo.get_named_file("rebase-merge/onto")
+        if onto_file:
+            onto = onto_file.read().strip()
+            onto_file.close()
+
+        # TODO: Load todo list and done list for resuming rebase
+
+        return original_head, rebasing_branch, onto, todo, done
+
+    def clean(self) -> None:
+        """Clean up rebase state files."""
+        # Clean up rebase state files matching C Git
+        for filename in [
+            "rebase-merge/stopped-sha",
+            "rebase-merge/orig-head",
+            "rebase-merge/onto",
+            "rebase-merge/head-name",
+            "rebase-merge/msgnum",
+            "rebase-merge/end",
+        ]:
+            self.repo._del_named_file(filename)
+
+        # For file-based repos, remove the directory
+        if hasattr(self.repo, "controldir"):
+            import os
+            import shutil
+
+            rebase_dir = os.path.join(self.repo.controldir(), "rebase-merge")
+            try:
+                shutil.rmtree(rebase_dir)
+            except FileNotFoundError:
+                # Directory doesn't exist, that's ok
+                pass
+
+    def exists(self) -> bool:
+        """Check if rebase state exists."""
+        f = self.repo.get_named_file("rebase-merge/orig-head")
+        if f:
+            f.close()
+            return True
+        return False
+
+
+class MemoryRebaseStateManager:
+    """Manages rebase state in memory for MemoryRepo."""
+
+    def __init__(self, repo: Repo) -> None:
+        self.repo = repo
+        self._state: Optional[dict] = None
+
+    def save(
+        self,
+        original_head: Optional[bytes],
+        rebasing_branch: Optional[bytes],
+        onto: Optional[bytes],
+        todo: list[Commit],
+        done: list[Commit],
+    ) -> None:
+        """Save rebase state in memory."""
+        self._state = {
+            "original_head": original_head,
+            "rebasing_branch": rebasing_branch,
+            "onto": onto,
+            "todo": todo[:],  # Copy the lists
+            "done": done[:],
+        }
+
+    def load(
+        self,
+    ) -> tuple[
+        Optional[bytes],
+        Optional[bytes],
+        Optional[bytes],
+        list[Commit],
+        list[Commit],
+    ]:
+        """Load rebase state from memory."""
+        if self._state is None:
+            return None, None, None, [], []
+
+        return (
+            self._state["original_head"],
+            self._state["rebasing_branch"],
+            self._state["onto"],
+            self._state["todo"][:],  # Return copies
+            self._state["done"][:],
+        )
+
+    def clean(self) -> None:
+        """Clean up rebase state."""
+        self._state = None
+
+    def exists(self) -> bool:
+        """Check if rebase state exists."""
+        return self._state is not None
+
+
 class Rebaser:
     """Handles git rebase operations."""
 
@@ -59,12 +267,18 @@ class Rebaser:
         """
         self.repo = repo
         self.object_store = repo.object_store
+        self._state_manager = repo.get_rebase_state_manager()
+
+        # Initialize state
         self._original_head = None
         self._onto = None
         self._todo = []
         self._done = []
         self._rebasing_branch = None
 
+        # Load any existing rebase state
+        self._load_rebase_state()
+
     def _get_commits_to_rebase(
         self, upstream: bytes, branch: Optional[bytes] = None
     ) -> list[Commit]:
@@ -243,9 +457,13 @@ class Rebaser:
             self._save_rebase_state()
             return (commit.id, conflicts)
 
+    def is_in_progress(self) -> bool:
+        """Check if a rebase is currently in progress."""
+        return self._state_manager.exists()
+
     def abort_rebase(self) -> None:
         """Abort an in-progress rebase and restore original state."""
-        if self._original_head is None:
+        if not self.is_in_progress():
             raise RebaseError("No rebase in progress")
 
         # Restore original HEAD
@@ -295,46 +513,27 @@ class Rebaser:
 
     def _save_rebase_state(self) -> None:
         """Save rebase state to allow resuming."""
-        # Store rebase state in named files
-        # Real git uses .git/rebase-merge/ directory
-
-        # For file-based repos, ensure the directory exists
-        if hasattr(self.repo, "controldir"):
-            import os
-
-            rebase_dir = os.path.join(self.repo.controldir(), "rebase-merge")
-            if not os.path.exists(rebase_dir):
-                os.makedirs(rebase_dir)
-
-        if self._todo:
-            # Store the current commit being rebased
-            current_commit = self._todo[0]
-            self.repo._put_named_file("rebase-merge/stopped-sha", current_commit.id)
+        self._state_manager.save(
+            self._original_head,
+            self._rebasing_branch,
+            self._onto,
+            self._todo,
+            self._done,
+        )
 
-        # Store other rebase state
-        if self._original_head:
-            self.repo._put_named_file("rebase-merge/orig-head", self._original_head)
-        if self._onto:
-            self.repo._put_named_file("rebase-merge/onto", self._onto)
+    def _load_rebase_state(self) -> None:
+        """Load existing rebase state if present."""
+        (
+            self._original_head,
+            self._rebasing_branch,
+            self._onto,
+            self._todo,
+            self._done,
+        ) = self._state_manager.load()
 
     def _clean_rebase_state(self) -> None:
         """Clean up rebase state files."""
-        # Clean up rebase state files
-        for filename in [
-            "rebase-merge/stopped-sha",
-            "rebase-merge/orig-head",
-            "rebase-merge/onto",
-        ]:
-            self.repo._del_named_file(filename)
-
-        # For file-based repos, remove the directory
-        if hasattr(self.repo, "controldir"):
-            import os
-            import shutil
-
-            rebase_dir = os.path.join(self.repo.controldir(), "rebase-merge")
-            if os.path.exists(rebase_dir):
-                shutil.rmtree(rebase_dir)
+        self._state_manager.clean()
 
 
 def rebase(

+ 25 - 0
dulwich/repo.py

@@ -736,6 +736,13 @@ class BaseRepo:
         """
         raise NotImplementedError(self.set_description)
 
+    def get_rebase_state_manager(self):
+        """Get the appropriate rebase state manager for this repository.
+
+        Returns: RebaseStateManager instance
+        """
+        raise NotImplementedError(self.get_rebase_state_manager)
+
     def get_config_stack(self) -> "StackedConfig":
         """Return a config stack for this repository.
 
@@ -1713,6 +1720,15 @@ class Repo(BaseRepo):
             ret.path = path
             return ret
 
+    def get_rebase_state_manager(self):
+        """Get the appropriate rebase state manager for this repository.
+
+        Returns: DiskRebaseStateManager instance
+        """
+        from .rebase import DiskRebaseStateManager
+
+        return DiskRebaseStateManager(self)
+
     def get_description(self):
         """Retrieve the description of this repository.
 
@@ -2074,6 +2090,15 @@ class MemoryRepo(BaseRepo):
         """
         return self._config
 
+    def get_rebase_state_manager(self):
+        """Get the appropriate rebase state manager for this repository.
+
+        Returns: MemoryRebaseStateManager instance
+        """
+        from .rebase import MemoryRebaseStateManager
+
+        return MemoryRebaseStateManager(self)
+
     @classmethod
     def init_bare(cls, objects, refs, format: Optional[int] = None):
         """Create a new bare repository in memory.

+ 10 - 8
tests/test_refs.py

@@ -384,33 +384,35 @@ class DictRefsContainerTests(RefsContainerTests, TestCase):
     def test_set_if_equals_with_symbolic_ref(self) -> None:
         # Test that set_if_equals only updates the requested ref,
         # not all refs in a symbolic reference chain
-        
+
         # The bug in the original implementation was that when follow()
         # was called on a ref, it would return all refs in the chain,
         # and set_if_equals would update ALL of them instead of just the
         # requested ref.
-        
+
         # Set up refs
         master_sha = b"1" * 40
-        feature_sha = b"2" * 40  
+        feature_sha = b"2" * 40
         new_sha = b"3" * 40
-        
+
         self._refs[b"refs/heads/master"] = master_sha
         self._refs[b"refs/heads/feature"] = feature_sha
         # Create a second symbolic ref pointing to feature
         self._refs.set_symbolic_ref(b"refs/heads/other", b"refs/heads/feature")
-        
+
         # Update refs/heads/other through set_if_equals
         # With the bug, this would update BOTH refs/heads/other AND refs/heads/feature
         # Without the bug, only refs/heads/other should be updated
         # Note: old_ref needs to be the actual stored value (the symref)
         self.assertTrue(
-            self._refs.set_if_equals(b"refs/heads/other", b"ref: refs/heads/feature", new_sha)
+            self._refs.set_if_equals(
+                b"refs/heads/other", b"ref: refs/heads/feature", new_sha
+            )
         )
-        
+
         # refs/heads/other should now directly point to new_sha
         self.assertEqual(self._refs.read_ref(b"refs/heads/other"), new_sha)
-        
+
         # refs/heads/feature should remain unchanged
         # With the bug, refs/heads/feature would also be incorrectly updated to new_sha
         self.assertEqual(self._refs[b"refs/heads/feature"], feature_sha)