Ver Fonte

Decouple DiskRebaseStateManager from Repo class

- Changed DiskRebaseStateManager to accept a filesystem path instead
  of a Repo instance
- Added _read_file() and _write_file() helper methods for file I/O
- Simplified clean() and exists() methods to work directly with paths
- Updated Repo.get_rebase_state_manager() to pass the rebase-merge path

This makes DiskRebaseStateManager more focused and testable, as it
now only handles file I/O operations without needing knowledge of
the entire repository interface.
Jelmer Vernooij há 1 mês atrás
pai
commit
0c5203787f
4 ficheiros alterados com 72 adições e 75 exclusões
  1. 8 5
      dulwich/cli.py
  2. 59 66
      dulwich/rebase.py
  3. 4 1
      dulwich/repo.py
  4. 1 3
      tests/test_rebase.py

+ 8 - 5
dulwich/cli.py

@@ -1169,7 +1169,7 @@ class cmd_count_objects(Command):
 
 
 class cmd_rebase(Command):
-    def run(self, args) -> Optional[int]:
+    def run(self, args) -> int:
         parser = argparse.ArgumentParser()
         parser.add_argument(
             "upstream", nargs="?", help="Upstream branch to rebase onto"
@@ -1195,21 +1195,23 @@ class cmd_rebase(Command):
         # Handle abort/continue/skip first
         if args.abort:
             try:
-                porcelain.rebase(".", None, abort=True)
+                porcelain.rebase(".", args.upstream or "HEAD", abort=True)
                 print("Rebase aborted.")
             except porcelain.Error as e:
                 print(f"Error: {e}")
                 return 1
-            return
+            return 0
 
         if args.continue_rebase:
             try:
-                new_shas = porcelain.rebase(".", None, continue_rebase=True)
+                new_shas = porcelain.rebase(
+                    ".", args.upstream or "HEAD", continue_rebase=True
+                )
                 print("Rebase complete.")
             except porcelain.Error as e:
                 print(f"Error: {e}")
                 return 1
-            return
+            return 0
 
         # Normal rebase requires upstream
         if not args.upstream:
@@ -1228,6 +1230,7 @@ class cmd_rebase(Command):
                 print(f"Successfully rebased {len(new_shas)} commits.")
             else:
                 print("Already up to date.")
+            return 0
 
         except porcelain.Error as e:
             print(f"Error: {e}")

+ 59 - 66
dulwich/rebase.py

@@ -86,8 +86,13 @@ class RebaseStateManager(Protocol):
 class DiskRebaseStateManager:
     """Manages rebase state on disk using same files as C Git."""
 
-    def __init__(self, repo: Repo) -> None:
-        self.repo = repo
+    def __init__(self, path: str) -> None:
+        """Initialize disk rebase state manager.
+
+        Args:
+            path: Path to the rebase-merge directory
+        """
+        self.path = path
 
     def save(
         self,
@@ -98,40 +103,44 @@ class DiskRebaseStateManager:
         done: list[Commit],
     ) -> None:
         """Save rebase state to disk."""
-        # For file-based repos, ensure the directory exists
-        if hasattr(self.repo, "controldir"):
-            import os
+        import os
 
-            rebase_dir = os.path.join(self.repo.controldir(), "rebase-merge")
-            if not os.path.exists(rebase_dir):
-                os.makedirs(rebase_dir)
+        # Ensure the directory exists
+        os.makedirs(self.path, exist_ok=True)
 
         # Store the original HEAD ref (e.g. "refs/heads/feature")
         if original_head:
-            self.repo._put_named_file("rebase-merge/orig-head", original_head)
+            self._write_file("orig-head", original_head)
 
         # Store the branch name being rebased
         if rebasing_branch:
-            self.repo._put_named_file("rebase-merge/head-name", rebasing_branch)
+            self._write_file("head-name", rebasing_branch)
 
         # Store the commit we're rebasing onto
         if onto:
-            self.repo._put_named_file("rebase-merge/onto", onto)
+            self._write_file("onto", onto)
 
         # Track progress
         if todo:
             # Store the current commit being rebased (same as C Git)
             current_commit = todo[0]
-            self.repo._put_named_file("rebase-merge/stopped-sha", current_commit.id)
+            self._write_file("stopped-sha", current_commit.id)
 
             # Store progress counters
             msgnum = len(done) + 1  # Current commit number (1-based)
             end = len(done) + len(todo)  # Total number of commits
-            self.repo._put_named_file("rebase-merge/msgnum", str(msgnum).encode())
-            self.repo._put_named_file("rebase-merge/end", str(end).encode())
+            self._write_file("msgnum", str(msgnum).encode())
+            self._write_file("end", str(end).encode())
 
         # TODO: Add support for writing git-rebase-todo for interactive rebase
 
+    def _write_file(self, name: str, content: bytes) -> None:
+        """Write content to a file in the rebase directory."""
+        import os
+
+        with open(os.path.join(self.path, name), "wb") as f:
+            f.write(content)
+
     def load(
         self,
     ) -> tuple[
@@ -149,57 +158,39 @@ class DiskRebaseStateManager:
         done: list[Commit] = []
 
         # Load rebase state files
-        orig_head_file = self.repo.get_named_file("rebase-merge/orig-head")
-        if orig_head_file:
-            original_head = orig_head_file.read().strip()
-            orig_head_file.close()
-
-        head_name_file = self.repo.get_named_file("rebase-merge/head-name")
-        if head_name_file:
-            rebasing_branch = head_name_file.read().strip()
-            head_name_file.close()
-
-        onto_file = self.repo.get_named_file("rebase-merge/onto")
-        if onto_file:
-            onto = onto_file.read().strip()
-            onto_file.close()
+        original_head = self._read_file("orig-head")
+        rebasing_branch = self._read_file("head-name")
+        onto = self._read_file("onto")
 
         # TODO: Load todo list and done list for resuming rebase
 
         return original_head, rebasing_branch, onto, todo, done
 
+    def _read_file(self, name: str) -> Optional[bytes]:
+        """Read content from a file in the rebase directory."""
+        import os
+
+        try:
+            with open(os.path.join(self.path, name), "rb") as f:
+                return f.read().strip()
+        except FileNotFoundError:
+            return None
+
     def clean(self) -> None:
         """Clean up rebase state files."""
-        # Clean up rebase state files matching C Git
-        for filename in [
-            "rebase-merge/stopped-sha",
-            "rebase-merge/orig-head",
-            "rebase-merge/onto",
-            "rebase-merge/head-name",
-            "rebase-merge/msgnum",
-            "rebase-merge/end",
-        ]:
-            self.repo._del_named_file(filename)
-
-        # For file-based repos, remove the directory
-        if hasattr(self.repo, "controldir"):
-            import os
-            import shutil
-
-            rebase_dir = os.path.join(self.repo.controldir(), "rebase-merge")
-            try:
-                shutil.rmtree(rebase_dir)
-            except FileNotFoundError:
-                # Directory doesn't exist, that's ok
-                pass
+        import shutil
+
+        try:
+            shutil.rmtree(self.path)
+        except FileNotFoundError:
+            # Directory doesn't exist, that's ok
+            pass
 
     def exists(self) -> bool:
         """Check if rebase state exists."""
-        f = self.repo.get_named_file("rebase-merge/orig-head")
-        if f:
-            f.close()
-            return True
-        return False
+        import os
+
+        return os.path.exists(os.path.join(self.path, "orig-head"))
 
 
 class MemoryRebaseStateManager:
@@ -270,11 +261,11 @@ class Rebaser:
         self._state_manager = repo.get_rebase_state_manager()
 
         # Initialize state
-        self._original_head = None
+        self._original_head: Optional[bytes] = None
         self._onto = None
-        self._todo = []
-        self._done = []
-        self._rebasing_branch = None
+        self._todo: list[Commit] = []
+        self._done: list[Commit] = []
+        self._rebasing_branch: Optional[bytes] = None
 
         # Load any existing rebase state
         self._load_rebase_state()
@@ -350,7 +341,7 @@ class Rebaser:
         if conflicts:
             # Store merge state for conflict resolution
             self.repo._put_named_file("rebase-merge/stopped-sha", commit.id)
-            return None, conflicts
+            return commit.id, conflicts
 
         # Create new commit
         new_commit = Commit()
@@ -399,7 +390,9 @@ class Rebaser:
                 self._rebasing_branch = b"refs/heads/" + branch
         else:
             # Use current branch
-            if self._original_head.startswith(b"ref: "):
+            if self._original_head is not None and self._original_head.startswith(
+                b"ref: "
+            ):
                 self._rebasing_branch = self._original_head[5:]
             else:
                 self._rebasing_branch = None
@@ -428,7 +421,8 @@ class Rebaser:
             None if rebase is complete, or tuple of (commit_sha, conflicts) for next commit
         """
         if not self._todo:
-            return self._finish_rebase()
+            self._finish_rebase()
+            return None
 
         # Get next commit to rebase
         commit = self._todo.pop(0)
@@ -451,7 +445,8 @@ class Rebaser:
             if self._todo:
                 return self.continue_()
             else:
-                return self._finish_rebase()
+                self._finish_rebase()
+                return None
         else:
             # Conflicts - save state and return
             self._save_rebase_state()
@@ -482,7 +477,7 @@ class Rebaser:
         """Finish rebase by updating HEAD and cleaning up."""
         if not self._done:
             # No commits were rebased
-            return None
+            return
 
         # Update HEAD to point to last rebased commit
         last_commit = self._done[-1]
@@ -509,8 +504,6 @@ class Rebaser:
         self._onto = None
         self._todo = []
 
-        return None
-
     def _save_rebase_state(self) -> None:
         """Save rebase state to allow resuming."""
         self._state_manager.save(

+ 4 - 1
dulwich/repo.py

@@ -1725,9 +1725,12 @@ class Repo(BaseRepo):
 
         Returns: DiskRebaseStateManager instance
         """
+        import os
+
         from .rebase import DiskRebaseStateManager
 
-        return DiskRebaseStateManager(self)
+        path = os.path.join(self.controldir(), "rebase-merge")
+        return DiskRebaseStateManager(path)
 
     def get_description(self):
         """Retrieve the description of this repository.

+ 1 - 3
tests/test_rebase.py

@@ -130,9 +130,7 @@ class RebaserTestCase(TestCase):
 
         # Perform rebase
         rebaser = Rebaser(self.repo)
-        commits = rebaser.start(
-            b"refs/heads/master", branch=b"refs/heads/feature"
-        )
+        commits = rebaser.start(b"refs/heads/master", branch=b"refs/heads/feature")
 
         self.assertEqual(len(commits), 1)
         self.assertEqual(commits[0].id, feature_commit.id)