Bläddra i källkod

add --column to branch command (#1876)

Implements the --column flag for the dulwich branch command, 
which display branch list in columns.
https://github.com/jelmer/dulwich/issues/1847
Jelmer Vernooij 4 månader sedan
förälder
incheckning
be9830c51c
3 ändrade filer med 299 tillägg och 22 borttagningar
  1. 19 0
      CONTRIBUTING.rst
  2. 106 21
      dulwich/cli.py
  3. 174 1
      tests/test_cli.py

+ 19 - 0
CONTRIBUTING.rst

@@ -97,6 +97,25 @@ dulwich package. This will ensure that the deprecation is handled correctly:
 * Users can use `dissolve migrate` to automatically replace deprecated
   functionality in their code
 
+Tests
+~~~~~
+
+Dulwich has two kinds of tests:
+
+* Unit tests, which test individual functions and classes
+* Compatibility tests, which test that Dulwich behaves in a way that is
+    compatible with C Git
+
+The former should never invoke C Git, while the latter may do so. This is
+to ensure that it is possible to run the unit tests in an environment
+where C Git is not available.
+
+Tests should not depend on the internet, or any other external services.
+
+Avoid using mocks if at all possible; rather, design your code to be easily
+testable without them. If you do need to use mocks, please use the
+``unittest.mock`` module.
+
 Running the tests
 -----------------
 To run the testsuite, you should be able to run ``dulwich.tests.test_suite``.

+ 106 - 21
dulwich/cli.py

@@ -36,8 +36,9 @@ import signal
 import subprocess
 import sys
 import tempfile
+from collections.abc import Iterator
 from pathlib import Path
-from typing import BinaryIO, Callable, ClassVar, Optional, Union
+from typing import BinaryIO, Callable, ClassVar, Optional, TextIO, Union
 
 from dulwich import porcelain
 
@@ -190,6 +191,90 @@ def launch_editor(template_content: bytes = b"") -> bytes:
         os.unlink(temp_file)
 
 
+def detect_terminal_width() -> int:
+    """Detect the width of the terminal.
+
+    Returns:
+        Width of the terminal in characters, or 80 if it cannot be determined
+    """
+    try:
+        return os.get_terminal_size().columns
+    except OSError:
+        return 80
+
+
+def write_columns(
+    items: Union[Iterator[bytes], list[bytes]], out: TextIO, width: Optional[int] = None
+) -> None:
+    """Display items in formatted columns based on terminal width.
+
+    Args:
+        items: List or iterator of bytes objects to display in columns
+        out: Output stream to write to
+        width: Optional width of the terminal (if None, auto-detect)
+
+    The function calculates the optimal number of columns to fit the terminal
+    width and displays the items in a formatted column layout with proper
+    padding and alignment.
+    """
+    if width is None:
+        ter_width = detect_terminal_width()
+    else:
+        ter_width = width
+
+    item_names = [item.decode() for item in items]
+
+    def columns(names, width, num_cols):
+        if num_cols <= 0:
+            return False, []
+
+        num_rows = (len(names) + num_cols - 1) // num_cols
+        col_widths = []
+
+        for col in range(num_cols):
+            max_width = 0
+            for row in range(num_rows):
+                idx = row + col * num_rows
+                if idx < len(names):
+                    max_width = max(max_width, len(names[idx]))
+            col_widths.append(max_width + 2)  # add padding
+
+        total_width = sum(col_widths)
+        if total_width <= width:
+            return True, col_widths
+        return False, []
+
+    best_cols = 1
+    best_widths = []
+
+    for num_cols in range(min(8, len(item_names)), 0, -1):
+        fits, widths = columns(item_names, ter_width, num_cols)
+        if fits:
+            best_cols = num_cols
+            best_widths = widths
+            break
+
+    if not best_widths:
+        best_cols = 1
+        best_widths = [max(len(name) for name in item_names) + 2]
+
+    num_rows = (len(item_names) + best_cols - 1) // best_cols
+
+    for row in range(num_rows):
+        lines = []
+        for col in range(best_cols):
+            idx = row + col * num_rows
+            if idx < len(item_names):
+                branch_name = item_names[idx]
+                if col < len(best_widths):
+                    lines.append(branch_name.ljust(best_widths[col]))
+                else:
+                    lines.append(branch_name)
+
+        if lines:
+            out.write("".join(lines).rstrip() + "\n")
+
+
 class PagerBuffer:
     """Binary buffer wrapper for Pager to mimic sys.stdout.buffer."""
 
@@ -2076,17 +2161,26 @@ class cmd_branch(Command):
             const="HEAD",
             help="List branches that contain a specific commit",
         )
+        parser.add_argument(
+            "--column", action="store_true", help="Display branch list in columns"
+        )
         args = parser.parse_args(args)
 
+        def print_branches(
+            branches: Union[Iterator[bytes], list[bytes]], use_columns=False
+        ) -> None:
+            if use_columns:
+                write_columns(branches, sys.stdout)
+            else:
+                for branch in branches:
+                    sys.stdout.write(f"{branch.decode()}\n")
+
         if args.all:
             try:
                 branches = porcelain.branch_list(".") + porcelain.branch_remotes_list(
                     "."
                 )
-
-                for branch in branches:
-                    sys.stdout.write(f"{branch.decode()}\n")
-
+                print_branches(branches, args.column)
                 return 0
 
             except porcelain.Error as e:
@@ -2096,11 +2190,9 @@ class cmd_branch(Command):
         if args.merged:
             try:
                 branches_iter = porcelain.merged_branches(".")
-
-                for branch in branches_iter:
-                    sys.stdout.write(f"{branch.decode()}\n")
-
+                print_branches(branches_iter, args.column)
                 return 0
+
             except porcelain.Error as e:
                 sys.stderr.write(f"{e}")
                 return 1
@@ -2108,11 +2200,9 @@ class cmd_branch(Command):
         if args.no_merged:
             try:
                 branches_iter = porcelain.no_merged_branches(".")
-
-                for branch in branches_iter:
-                    sys.stdout.write(f"{branch.decode()}\n")
-
+                print_branches(branches_iter, args.column)
                 return 0
+
             except porcelain.Error as e:
                 sys.stderr.write(f"{e}")
                 return 1
@@ -2120,10 +2210,7 @@ class cmd_branch(Command):
         if args.contains:
             try:
                 branches_iter = porcelain.branches_containing(".", commit=args.contains)
-
-                for branch in branches_iter:
-                    sys.stdout.write(f"{branch.decode()}\n")
-
+                print_branches(branches_iter, args.column)
                 return 0
 
             except KeyError as e:
@@ -2137,11 +2224,9 @@ class cmd_branch(Command):
         if args.remotes:
             try:
                 branches = porcelain.branch_remotes_list(".")
-
-                for branch in branches:
-                    sys.stdout.write(f"{branch.decode()}\n")
-
+                print_branches(branches, args.column)
                 return 0
+
             except porcelain.Error as e:
                 sys.stderr.write(f"{e}")
                 return 1

+ 174 - 1
tests/test_cli.py

@@ -33,7 +33,13 @@ from unittest import skipIf
 from unittest.mock import MagicMock, patch
 
 from dulwich import cli
-from dulwich.cli import format_bytes, launch_editor, parse_relative_time
+from dulwich.cli import (
+    detect_terminal_width,
+    format_bytes,
+    launch_editor,
+    parse_relative_time,
+    write_columns,
+)
 from dulwich.repo import Repo
 from dulwich.tests.utils import (
     build_commit_graph,
@@ -629,6 +635,173 @@ class BranchCommandTest(DulwichCliTestCase):
         self.assertNotEqual(result, 0)
         self.assertIn("error: object name invalid123 not found", stderr)
 
+    def test_branch_list_column(self):
+        """Test branch --column formatting"""
+        # Create initial commit
+        test_file = os.path.join(self.repo_path, "test.txt")
+        with open(test_file, "w") as f:
+            f.write("test")
+        self._run_cli("add", "test.txt")
+        self._run_cli("commit", "--message=Initial")
+
+        self._run_cli("branch", "feature-1")
+        self._run_cli("branch", "feature-2")
+        self._run_cli("branch", "feature-3")
+
+        # Run branch --column
+        result, stdout, stderr = self._run_cli("branch", "--all", "--column")
+        self.assertEqual(result, 0)
+
+        expected = ["feature-1", "feature-2", "feature-3"]
+
+        for branch in expected:
+            self.assertIn(branch, stdout)
+
+        multiple_columns = any(
+            sum(branch in line for branch in expected) > 1
+            for line in stdout.strip().split("\n")
+        )
+        self.assertTrue(multiple_columns)
+
+
+class TestTerminalWidth(TestCase):
+    @patch("os.get_terminal_size")
+    def test_terminal_size(self, mock_get_terminal_size):
+        """Test os.get_terminal_size mocking."""
+        mock_get_terminal_size.return_value.columns = 100
+        width = detect_terminal_width()
+        self.assertEqual(width, 100)
+
+    @patch("os.get_terminal_size")
+    def test_terminal_size_os_error(self, mock_get_terminal_size):
+        """Test os.get_terminal_size raising OSError."""
+        mock_get_terminal_size.side_effect = OSError("No terminal")
+        width = detect_terminal_width()
+        self.assertEqual(width, 80)
+
+
+class TestWriteColumns(TestCase):
+    """Tests for write_columns function"""
+
+    def test_basic_functionality(self):
+        """Test basic functionality with default terminal width."""
+        out = io.StringIO()
+        items = [b"main", b"dev", b"feature/branch-1"]
+        write_columns(items, out, width=80)
+
+        output_text = out.getvalue()
+        self.assertEqual(output_text, "main  dev  feature/branch-1\n")
+
+    def test_narrow_terminal_single_column(self):
+        """Test with narrow terminal forcing single column."""
+        out = io.StringIO()
+
+        items = [b"main", b"dev", b"feature/branch-1"]
+        write_columns(items, out, 20)
+
+        self.assertEqual(out.getvalue(), "main\ndev\nfeature/branch-1\n")
+
+    def test_wide_terminal_multiple_columns(self):
+        """Test with wide terminal allowing multiple columns."""
+        out = io.StringIO()
+        items = [
+            b"main",
+            b"dev",
+            b"feature/branch-1",
+            b"feature/branch-2",
+            b"feature/branch-3",
+        ]
+        write_columns(items, out, 120)
+
+        output_text = out.getvalue()
+        self.assertEqual(
+            output_text,
+            "main  dev  feature/branch-1  feature/branch-2  feature/branch-3\n",
+        )
+
+    def test_single_item(self):
+        """Test with single item."""
+        out = io.StringIO()
+        write_columns([b"single"], out, 80)
+
+        output_text = out.getvalue()
+        self.assertEqual("single\n", output_text)
+        self.assertTrue(output_text.endswith("\n"))
+
+    def test_os_error_fallback(self):
+        """Test fallback behavior when os.get_terminal_size raises OSError."""
+        with patch("os.get_terminal_size", side_effect=OSError("No terminal")):
+            out = io.StringIO()
+            items = [b"main", b"dev"]
+            write_columns(items, out)
+
+            output_text = out.getvalue()
+            # With default width (80), should display in columns
+            self.assertEqual(output_text, "main  dev\n")
+
+    def test_iterator_input(self):
+        """Test with iterator input instead of list."""
+        out = io.StringIO()
+        items = [b"main", b"dev", b"feature/branch-1"]
+        items_iterator = iter(items)
+        write_columns(items_iterator, out, 80)
+
+        output_text = out.getvalue()
+        self.assertEqual(output_text, "main  dev  feature/branch-1\n")
+
+    def test_column_alignment(self):
+        """Test that columns are properly aligned."""
+        out = io.StringIO()
+        items = [b"short", b"medium_length", b"very_long______name"]
+        write_columns(items, out, 50)
+
+        output_text = out.getvalue()
+        self.assertEqual(output_text, "short  medium_length  very_long______name\n")
+
+    def test_columns_formatting(self):
+        """Test that items are formatted in columns within single line."""
+        out = io.StringIO()
+        items = [b"branch-1", b"branch-2", b"branch-3", b"branch-4", b"branch-5"]
+        write_columns(items, out, 80)
+
+        output_text = out.getvalue()
+
+        self.assertEqual(output_text.count("\n"), 1)
+        self.assertTrue(output_text.endswith("\n"))
+
+        line = output_text.strip()
+        for item in items:
+            self.assertIn(item.decode(), line)
+
+    def test_column_alignment_multiple_lines(self):
+        """Test that columns are properly aligned across multiple lines."""
+        items = [
+            b"short",
+            b"medium_length",
+            b"very_long_branch_name",
+            b"another",
+            b"more",
+            b"even_longer_branch_name_here",
+        ]
+
+        out = io.StringIO()
+
+        write_columns(items, out, width=60)
+
+        output_text = out.getvalue()
+        lines = output_text.strip().split("\n")
+
+        self.assertGreater(len(lines), 1)
+
+        line_lengths = [len(line) for line in lines if line.strip()]
+
+        for length in line_lengths:
+            self.assertLessEqual(length, 60)
+
+        all_output = " ".join(lines)
+        for item in items:
+            self.assertIn(item.decode(), all_output)
+
 
 class CheckoutCommandTest(DulwichCliTestCase):
     """Tests for checkout command."""