Преглед изворни кода

Add --amend support to dulwich.cli.commit and dulwich.porcelain.commit

Jelmer Vernooij пре 1 месец
родитељ
комит
e24822349f
4 измењених фајлова са 153 додато и 30 уклоњено
  1. 3 0
      NEWS
  2. 48 20
      dulwich/cli.py
  3. 42 10
      dulwich/porcelain.py
  4. 60 0
      tests/test_porcelain.py

+ 3 - 0
NEWS

@@ -3,6 +3,9 @@
  * Add support for ``-a`` argument to
    ``dulwich.cli.commit``. (Jelmer Vernooij)
 
+ * Add support for ``--amend`` argument to
+   ``dulwich.cli.commit`` and ``dulwich.porcelain.commit``. (Jelmer Vernooij)
+
  * Add support for merge drivers.
    (Jelmer Vernooij)
 

+ 48 - 20
dulwich/cli.py

@@ -759,32 +759,37 @@ class cmd_clone(Command):
             print(f"{e}")
 
 
-def _get_commit_message(repo, commit):
-    # Prepare a template
-    template = b"\n"
+def _get_commit_message_with_template(initial_message, repo=None, commit=None):
+    """Get commit message with an initial message template."""
+    # Start with the initial message
+    template = initial_message
+    if template and not template.endswith(b"\n"):
+        template += b"\n"
+
+    template += b"\n"
     template += b"# Please enter the commit message for your changes. Lines starting\n"
     template += b"# with '#' will be ignored, and an empty message aborts the commit.\n"
     template += b"#\n"
-    try:
-        ref_names, ref_sha = repo.refs.follow(b"HEAD")
-        ref_path = ref_names[-1]  # Get the final reference
-        if ref_path.startswith(b"refs/heads/"):
-            branch = ref_path[11:]  # Remove 'refs/heads/' prefix
-        else:
-            branch = ref_path
-        template += b"# On branch %s\n" % branch
-    except (KeyError, IndexError):
-        template += b"# On branch (unknown)\n"
-    template += b"#\n"
+
+    # Add branch info if repo is provided
+    if repo:
+        try:
+            ref_names, ref_sha = repo.refs.follow(b"HEAD")
+            ref_path = ref_names[-1]  # Get the final reference
+            if ref_path.startswith(b"refs/heads/"):
+                branch = ref_path[11:]  # Remove 'refs/heads/' prefix
+            else:
+                branch = ref_path
+            template += b"# On branch %s\n" % branch
+        except (KeyError, IndexError):
+            template += b"# On branch (unknown)\n"
+        template += b"#\n"
+
     template += b"# Changes to be committed:\n"
 
     # Launch editor
     content = launch_editor(template)
 
-    # Check if content was unchanged
-    if content == template:
-        raise CommitMessageError("Aborting commit due to unchanged commit message")
-
     # Remove comment lines and strip
     lines = content.split(b"\n")
     message_lines = [line for line in lines if not line.strip().startswith(b"#")]
@@ -806,17 +811,40 @@ class cmd_commit(Command):
             action="store_true",
             help="Automatically stage all tracked files that have been modified",
         )
+        parser.add_argument(
+            "--amend",
+            action="store_true",
+            help="Replace the tip of the current branch by creating a new commit",
+        )
         args = parser.parse_args(args)
 
         message: Union[bytes, str, Callable]
 
         if args.message:
             message = args.message
+        elif args.amend:
+            # For amend, create a callable that opens editor with original message pre-populated
+            def get_amend_message(repo, commit):
+                # Get the original commit message from current HEAD
+                try:
+                    head_commit = repo[repo.head()]
+                    original_message = head_commit.message
+                except KeyError:
+                    original_message = b""
+
+                # Open editor with original message
+                return _get_commit_message_with_template(original_message, repo, commit)
+
+            message = get_amend_message
         else:
-            message = _get_commit_message
+            # For regular commits, use empty template
+            def get_regular_message(repo, commit):
+                return _get_commit_message_with_template(b"", repo, commit)
+
+            message = get_regular_message
 
         try:
-            porcelain.commit(".", message=message, all=args.all)
+            porcelain.commit(".", message=message, all=args.all, amend=args.amend)
         except CommitMessageError as e:
             print(f"error: {e}", file=sys.stderr)
             return 1

+ 42 - 10
dulwich/porcelain.py

@@ -475,6 +475,7 @@ def commit(
     no_verify=False,
     signoff=False,
     all=False,
+    amend=False,
 ):
     """Create a new commit.
 
@@ -491,6 +492,7 @@ def commit(
         pass True to use default GPG key,
         pass a str containing Key ID to use a specific GPG key)
       all: Automatically stage all tracked files that have been modified
+      amend: Replace the tip of the current branch by creating a new commit
     Returns: SHA1 of the new commit
     """
     if getattr(message, "encode", None):
@@ -506,6 +508,25 @@ def commit(
         commit_timezone = local_timezone[1]
 
     with open_repo_closing(repo) as r:
+        # Handle amend logic
+        merge_heads = None
+        if amend:
+            try:
+                head_commit = r[r.head()]
+            except KeyError:
+                raise ValueError("Cannot amend: no existing commit found")
+
+            # If message not provided, use the message from the current HEAD
+            if message is None:
+                message = head_commit.message
+            # If author not provided, use the author from the current HEAD
+            if author is None:
+                author = head_commit.author
+                if author_timezone is None:
+                    author_timezone = head_commit.author_timezone
+            # Use the parent(s) of the current HEAD as our parent(s)
+            merge_heads = list(head_commit.parents)
+
         # If -a flag is used, stage all modified tracked files
         if all:
             index = r.open_index()
@@ -525,16 +546,27 @@ def commit(
 
                 add(r, paths=modified_files)
 
-        return r.do_commit(
-            message=message,
-            author=author,
-            author_timezone=author_timezone,
-            committer=committer,
-            commit_timezone=commit_timezone,
-            encoding=encoding,
-            no_verify=no_verify,
-            sign=signoff if isinstance(signoff, (str, bool)) else None,
-        )
+        commit_kwargs = {
+            "message": message,
+            "author": author,
+            "author_timezone": author_timezone,
+            "committer": committer,
+            "commit_timezone": commit_timezone,
+            "encoding": encoding,
+            "no_verify": no_verify,
+            "sign": signoff if isinstance(signoff, (str, bool)) else None,
+            "merge_heads": merge_heads,
+        }
+
+        # For amend, create dangling commit to avoid adding current HEAD as parent
+        if amend:
+            commit_kwargs["ref"] = None
+            commit_sha = r.do_commit(**commit_kwargs)
+            # Update HEAD to point to the new commit
+            r.refs[b"HEAD"] = commit_sha
+            return commit_sha
+        else:
+            return r.do_commit(**commit_kwargs)
 
 
 def commit_tree(repo, tree, message=None, author=None, committer=None):

+ 60 - 0
tests/test_porcelain.py

@@ -545,6 +545,66 @@ class CommitTests(PorcelainTestCase):
         self.assertIn(b"file1.txt", tree)
         self.assertIn(b"file2.txt", tree)
 
+    def test_commit_amend_message(self) -> None:
+        # Create initial commit
+        filename = os.path.join(self.repo.path, "test.txt")
+        with open(filename, "wb") as f:
+            f.write(b"initial content")
+        porcelain.add(self.repo.path, paths=["test.txt"])
+        original_sha = porcelain.commit(self.repo.path, message=b"Original commit")
+
+        # Amend with new message
+        amended_sha = porcelain.commit(
+            self.repo.path, message=b"Amended commit", amend=True
+        )
+
+        self.assertIsInstance(amended_sha, bytes)
+        self.assertEqual(len(amended_sha), 40)
+        self.assertNotEqual(amended_sha, original_sha)
+
+        # Check that the amended commit has the new message
+        amended_commit = self.repo.get_object(amended_sha)
+        assert isinstance(amended_commit, Commit)
+        self.assertEqual(amended_commit.message, b"Amended commit")
+
+        # Check that the amended commit uses the original commit's parents
+        original_commit = self.repo.get_object(original_sha)
+        assert isinstance(original_commit, Commit)
+        # Since this was the first commit, it should have no parents,
+        # and the amended commit should also have no parents
+        self.assertEqual(amended_commit.parents, original_commit.parents)
+
+    def test_commit_amend_no_message(self) -> None:
+        # Create initial commit
+        filename = os.path.join(self.repo.path, "test.txt")
+        with open(filename, "wb") as f:
+            f.write(b"initial content")
+        porcelain.add(self.repo.path, paths=["test.txt"])
+        original_sha = porcelain.commit(self.repo.path, message=b"Original commit")
+
+        # Modify file and stage it
+        with open(filename, "wb") as f:
+            f.write(b"modified content")
+        porcelain.add(self.repo.path, paths=["test.txt"])
+
+        # Amend without providing message (should reuse original message)
+        amended_sha = porcelain.commit(self.repo.path, amend=True)
+
+        self.assertIsInstance(amended_sha, bytes)
+        self.assertEqual(len(amended_sha), 40)
+        self.assertNotEqual(amended_sha, original_sha)
+
+        # Check that the amended commit has the original message
+        amended_commit = self.repo.get_object(amended_sha)
+        assert isinstance(amended_commit, Commit)
+        self.assertEqual(amended_commit.message, b"Original commit")
+
+    def test_commit_amend_no_existing_commit(self) -> None:
+        # Try to amend when there's no existing commit
+        with self.assertRaises(ValueError) as cm:
+            porcelain.commit(self.repo.path, message=b"Should fail", amend=True)
+        self.assertIn("Cannot amend: no existing commit found", str(cm.exception))
+
 
 @skipIf(
     platform.python_implementation() == "PyPy" or sys.platform == "win32",