Kaynağa Gözat

cli: Add support for editing commit message in editor

Jelmer Vernooij 1 ay önce
ebeveyn
işleme
5b153cff4f
5 değiştirilmiş dosya ile 298 ekleme ve 27 silme
  1. 102 8
      dulwich/cli.py
  2. 2 1
      dulwich/porcelain.py
  3. 35 14
      dulwich/repo.py
  4. 77 1
      tests/test_cli.py
  5. 82 3
      tests/test_repository.py

+ 102 - 8
dulwich/cli.py

@@ -34,8 +34,9 @@ import shutil
 import signal
 import subprocess
 import sys
+import tempfile
 from pathlib import Path
-from typing import ClassVar, Optional
+from typing import Callable, ClassVar, Optional, Union
 
 from dulwich import porcelain
 
@@ -49,6 +50,10 @@ from .pack import Pack, sha_to_hex
 from .repo import Repo
 
 
+class CommitMessageError(Exception):
+    """Raised when there's an issue with the commit message."""
+
+
 def signal_int(signal, frame) -> None:
     sys.exit(1)
 
@@ -124,6 +129,37 @@ def format_bytes(bytes):
     return f"{bytes:.1f} TB"
 
 
+def launch_editor(template_content=b""):
+    """Launch an editor for the user to enter text.
+
+    Args:
+        template_content: Initial content for the editor
+
+    Returns:
+        The edited content as bytes
+    """
+    # Determine which editor to use
+    editor = os.environ.get("GIT_EDITOR") or os.environ.get("EDITOR") or "vi"
+
+    # Create a temporary file
+    with tempfile.NamedTemporaryFile(mode="wb", delete=False, suffix=".txt") as f:
+        temp_file = f.name
+        f.write(template_content)
+
+    try:
+        # Launch the editor
+        subprocess.run([editor, temp_file], check=True)
+
+        # Read the edited content
+        with open(temp_file, "rb") as f:
+            content = f.read()
+
+        return content
+    finally:
+        # Clean up the temporary file
+        os.unlink(temp_file)
+
+
 class PagerBuffer:
     """Binary buffer wrapper for Pager to mimic sys.stdout.buffer."""
 
@@ -155,6 +191,7 @@ class Pager:
         self.buffer = PagerBuffer(self)
         self._closed = False
         self.pager_cmd = pager_cmd
+        self._pager_died = False
 
     def _get_pager_command(self) -> str:
         """Get the pager command to use."""
@@ -182,28 +219,33 @@ class Pager:
         if self._closed:
             raise ValueError("I/O operation on closed file")
 
+        # If pager died (user quit), stop writing output
+        if self._pager_died:
+            return len(text)
+
         self._ensure_pager_started()
 
         if self.pager_process and self.pager_process.stdin:
             try:
                 return self.pager_process.stdin.write(text)
             except (OSError, subprocess.SubprocessError, BrokenPipeError):
-                # Pager died, fall back to direct output
-                return sys.stdout.write(text)
+                # Pager died (user quit), stop writing output
+                self._pager_died = True
+                return len(text)
         else:
             # No pager available, write directly to stdout
             return sys.stdout.write(text)
 
     def flush(self):
         """Flush the pager."""
-        if self._closed:
+        if self._closed or self._pager_died:
             return
 
         if self.pager_process and self.pager_process.stdin:
             try:
                 self.pager_process.stdin.flush()
             except (OSError, subprocess.SubprocessError, BrokenPipeError):
-                pass
+                self._pager_died = True
         else:
             sys.stdout.flush()
 
@@ -233,6 +275,8 @@ class Pager:
     # Additional file-like methods for compatibility
     def writelines(self, lines):
         """Write a list of lines to the pager."""
+        if self._pager_died:
+            return
         for line in lines:
             self.write(line)
 
@@ -715,10 +759,47 @@ class cmd_clone(Command):
             print(f"{e}")
 
 
+def _get_commit_message(repo, commit):
+    # Prepare a template
+    template = b"\n"
+    template += b"# Please enter the commit message for your changes. Lines starting\n"
+    template += b"# with '#' will be ignored, and an empty message aborts the commit.\n"
+    template += b"#\n"
+    try:
+        ref_names, ref_sha = repo.refs.follow(b"HEAD")
+        ref_path = ref_names[-1]  # Get the final reference
+        if ref_path.startswith(b"refs/heads/"):
+            branch = ref_path[11:]  # Remove 'refs/heads/' prefix
+        else:
+            branch = ref_path
+        template += b"# On branch %s\n" % branch
+    except (KeyError, IndexError):
+        template += b"# On branch (unknown)\n"
+    template += b"#\n"
+    template += b"# Changes to be committed:\n"
+
+    # Launch editor
+    content = launch_editor(template)
+
+    # Check if content was unchanged
+    if content == template:
+        raise CommitMessageError("Aborting commit due to unchanged commit message")
+
+    # Remove comment lines and strip
+    lines = content.split(b"\n")
+    message_lines = [line for line in lines if not line.strip().startswith(b"#")]
+    message = b"\n".join(message_lines).strip()
+
+    if not message:
+        raise CommitMessageError("Aborting commit due to empty commit message")
+
+    return message
+
+
 class cmd_commit(Command):
-    def run(self, args) -> None:
+    def run(self, args) -> Optional[int]:
         parser = argparse.ArgumentParser()
-        parser.add_argument("--message", "-m", required=True, help="Commit message")
+        parser.add_argument("--message", "-m", help="Commit message")
         parser.add_argument(
             "-a",
             "--all",
@@ -726,7 +807,20 @@ class cmd_commit(Command):
             help="Automatically stage all tracked files that have been modified",
         )
         args = parser.parse_args(args)
-        porcelain.commit(".", message=args.message, all=args.all)
+
+        message: Union[bytes, str, Callable]
+
+        if args.message:
+            message = args.message
+        else:
+            message = _get_commit_message
+
+        try:
+            porcelain.commit(".", message=message, all=args.all)
+        except CommitMessageError as e:
+            print(f"error: {e}", file=sys.stderr)
+            return 1
+        return None
 
 
 class cmd_commit_tree(Command):

+ 2 - 1
dulwich/porcelain.py

@@ -480,7 +480,8 @@ def commit(
 
     Args:
       repo: Path to repository
-      message: Optional commit message
+      message: Optional commit message (string/bytes or callable that takes
+        (repo, commit) and returns bytes)
       author: Optional author name and email
       author_timezone: Author timestamp timezone
       committer: Optional committer name and email

+ 35 - 14
dulwich/repo.py

@@ -1010,7 +1010,8 @@ class BaseRepo:
         and get_user_identity(..., 'AUTHOR') respectively.
 
         Args:
-          message: Commit message
+          message: Commit message (bytes or callable that takes (repo, commit)
+            and returns bytes)
           committer: Committer fullname
           author: Author fullname
           commit_timestamp: Commit timestamp (defaults to now)
@@ -1083,6 +1084,39 @@ class BaseRepo:
                 pass  # No dice
         if encoding is not None:
             c.encoding = encoding
+        # Store original message (might be callable)
+        original_message = message
+        message = None  # Will be set later after parents are set
+
+        # Check if we should sign the commit
+        should_sign = sign
+        if sign is None:
+            # Check commit.gpgSign configuration when sign is not explicitly set
+            config = self.get_config_stack()
+            try:
+                should_sign = config.get_boolean((b"commit",), b"gpgSign")
+            except KeyError:
+                should_sign = False  # Default to not signing if no config
+        keyid = sign if isinstance(sign, str) else None
+
+        if ref is None:
+            # Create a dangling commit
+            c.parents = merge_heads
+        else:
+            try:
+                old_head = self.refs[ref]
+                c.parents = [old_head, *merge_heads]
+            except KeyError:
+                c.parents = merge_heads
+
+        # Handle message after parents are set
+        if callable(original_message):
+            message = original_message(self, c)
+            if message is None:
+                raise ValueError("Message callback returned None")
+        else:
+            message = original_message
+
         if message is None:
             # FIXME: Try to read commit message from .git/MERGE_MSG
             raise ValueError("No commit message specified")
@@ -1099,27 +1133,14 @@ class BaseRepo:
         except KeyError:  # no hook defined, message not modified
             c.message = message
 
-        # Check if we should sign the commit
-        should_sign = sign
-        if sign is None:
-            # Check commit.gpgSign configuration when sign is not explicitly set
-            config = self.get_config_stack()
-            try:
-                should_sign = config.get_boolean((b"commit",), b"gpgSign")
-            except KeyError:
-                should_sign = False  # Default to not signing if no config
-        keyid = sign if isinstance(sign, str) else None
-
         if ref is None:
             # Create a dangling commit
-            c.parents = merge_heads
             if should_sign:
                 c.sign(keyid)
             self.object_store.add_object(c)
         else:
             try:
                 old_head = self.refs[ref]
-                c.parents = [old_head, *merge_heads]
                 if should_sign:
                     c.sign(keyid)
                 self.object_store.add_object(c)

+ 77 - 1
tests/test_cli.py

@@ -33,7 +33,7 @@ from unittest import skipIf
 from unittest.mock import MagicMock, patch
 
 from dulwich import cli
-from dulwich.cli import format_bytes, parse_relative_time
+from dulwich.cli import format_bytes, launch_editor, parse_relative_time
 from dulwich.repo import Repo
 from dulwich.tests.utils import (
     build_commit_graph,
@@ -118,6 +118,25 @@ class InitCommandTest(DulwichCliTestCase):
         self.assertFalse(os.path.exists(os.path.join(bare_repo_path, ".git")))
 
 
+class HelperFunctionsTest(TestCase):
+    """Tests for CLI helper functions."""
+
+    def test_format_bytes(self):
+        self.assertEqual("0.0 B", format_bytes(0))
+        self.assertEqual("100.0 B", format_bytes(100))
+        self.assertEqual("1.0 KB", format_bytes(1024))
+        self.assertEqual("1.5 KB", format_bytes(1536))
+        self.assertEqual("1.0 MB", format_bytes(1024 * 1024))
+        self.assertEqual("1.0 GB", format_bytes(1024 * 1024 * 1024))
+        self.assertEqual("1.0 TB", format_bytes(1024 * 1024 * 1024 * 1024))
+
+    def test_launch_editor_with_cat(self):
+        """Test launch_editor by using cat as the editor."""
+        self.overrideEnv("GIT_EDITOR", "cat")
+        result = launch_editor(b"Test template content")
+        self.assertEqual(b"Test template content", result)
+
+
 class AddCommandTest(DulwichCliTestCase):
     """Tests for add command."""
 
@@ -265,6 +284,63 @@ class CommitCommandTest(DulwichCliTestCase):
         # Verify untracked file still exists
         self.assertTrue(os.path.exists(untracked_file))
 
+    @patch("dulwich.cli.launch_editor")
+    def test_commit_editor_success(self, mock_editor):
+        """Test commit with editor when user provides a message."""
+        # Create and add a file
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("test content")
+        self._run_cli("add", "test.txt")
+
+        # Mock editor to return a commit message
+        mock_editor.return_value = b"My commit message\n\n# This is a comment\n"
+
+        # Commit without --message flag
+        result, stdout, stderr = self._run_cli("commit")
+
+        # Check that HEAD points to a commit
+        commit = self.repo[self.repo.head()]
+        self.assertEqual(commit.message, b"My commit message")
+
+        # Verify editor was called
+        mock_editor.assert_called_once()
+
+    @patch("dulwich.cli.launch_editor")
+    def test_commit_editor_empty_message(self, mock_editor):
+        """Test commit with editor when user provides empty message."""
+        # Create and add a file
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("test content")
+        self._run_cli("add", "test.txt")
+
+        # Mock editor to return only comments
+        mock_editor.return_value = b"# All lines are comments\n# No actual message\n"
+
+        # Commit without --message flag should fail with exit code 1
+        result, stdout, stderr = self._run_cli("commit")
+        self.assertEqual(result, 1)
+
+    @patch("dulwich.cli.launch_editor")
+    def test_commit_editor_unchanged_template(self, mock_editor):
+        """Test commit with editor when user doesn't change the template."""
+        # Create and add a file
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("test content")
+        self._run_cli("add", "test.txt")
+
+        # Mock editor to return the exact template that was passed to it
+        def return_unchanged_template(template):
+            return template
+
+        mock_editor.side_effect = return_unchanged_template
+
+        # Commit without --message flag should fail with exit code 1
+        result, stdout, stderr = self._run_cli("commit")
+        self.assertEqual(result, 1)
+
 
 class LogCommandTest(DulwichCliTestCase):
     """Tests for log command."""

+ 82 - 3
tests/test_repository.py

@@ -1447,9 +1447,88 @@ class BuildRepoRootTests(TestCase):
         new_shas = set(r.object_store) - old_shas
         self.assertEqual(1, len(new_shas))
         # Check that the new commit (now garbage) was added.
-        new_commit = r[new_shas.pop()]
-        self.assertEqual(r[self._root_commit].tree, new_commit.tree)
-        self.assertEqual(b"failed commit", new_commit.message)
+
+    def test_commit_message_callback(self) -> None:
+        """Test commit with a callable message."""
+        r = self._repo
+
+        # Define a callback that generates message based on repo and commit
+        def message_callback(repo, commit):
+            # Verify we get the right objects
+            self.assertEqual(repo, r)
+            self.assertIsNotNone(commit.tree)
+            self.assertIsNotNone(commit.author)
+            self.assertIsNotNone(commit.committer)
+
+            # Generate a message
+            return b"Generated commit for tree " + commit.tree[:8]
+
+        commit_sha = r.do_commit(
+            message_callback,  # Pass the callback as message
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+            commit_timestamp=12345,
+            commit_timezone=0,
+            author_timestamp=12345,
+            author_timezone=0,
+        )
+
+        commit = r[commit_sha]
+        self.assertTrue(commit.message.startswith(b"Generated commit for tree "))
+        self.assertIn(commit.tree[:8], commit.message)
+
+    def test_commit_message_callback_returns_none(self) -> None:
+        """Test commit with callback that returns None."""
+        r = self._repo
+
+        def message_callback(repo, commit):
+            return None
+
+        self.assertRaises(
+            ValueError,
+            r.do_commit,
+            message_callback,
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+            commit_timestamp=12345,
+            commit_timezone=0,
+            author_timestamp=12345,
+            author_timezone=0,
+        )
+
+    def test_commit_message_callback_with_merge_heads(self) -> None:
+        """Test commit with callback for merge commits."""
+        r = self._repo
+
+        # Create two parent commits first
+        parent1 = r.do_commit(
+            b"Parent 1",
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+        )
+
+        parent2 = r.do_commit(
+            b"Parent 2",
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+            ref=None,  # Dangling commit
+        )
+
+        def message_callback(repo, commit):
+            # Verify the commit object has parents set
+            self.assertEqual(2, len(commit.parents))
+            return b"Merge commit with %d parents" % len(commit.parents)
+
+        merge_sha = r.do_commit(
+            message_callback,
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+            merge_heads=[parent2],
+        )
+
+        merge_commit = r[merge_sha]
+        self.assertEqual(b"Merge commit with 2 parents", merge_commit.message)
+        self.assertEqual([parent1, parent2], merge_commit.parents)
 
     def test_commit_branch(self) -> None:
         r = self._repo