2
0
Эх сурвалжийг харах

write_columns: Refactor to take arguments for output stream and width

Also, assert the *exact* output matches
Jelmer Vernooij 4 сар өмнө
parent
commit
8f2db3d743
2 өөрчлөгдсөн 129 нэмэгдсэн , 131 устгасан
  1. 24 8
      dulwich/cli.py
  2. 105 123
      tests/test_cli.py

+ 24 - 8
dulwich/cli.py

@@ -38,7 +38,7 @@ import sys
 import tempfile
 from collections.abc import Iterator
 from pathlib import Path
-from typing import BinaryIO, Callable, ClassVar, Optional, Union
+from typing import BinaryIO, Callable, ClassVar, Optional, TextIO, Union
 
 from dulwich import porcelain
 
@@ -191,20 +191,36 @@ def launch_editor(template_content: bytes = b"") -> bytes:
         os.unlink(temp_file)
 
 
-def write_columns(items: Union[Iterator[bytes], list[bytes]]) -> None:
+def detect_terminal_width() -> int:
+    """Detect the width of the terminal.
+
+    Returns:
+        Width of the terminal in characters, or 80 if it cannot be determined
+    """
+    try:
+        return os.get_terminal_size().columns
+    except OSError:
+        return 80
+
+
+def write_columns(
+    items: Union[Iterator[bytes], list[bytes]], out: TextIO, width: Optional[int] = None
+) -> None:
     """Display items in formatted columns based on terminal width.
 
     Args:
         items: List or iterator of bytes objects to display in columns
+        out: Output stream to write to
+        width: Optional width of the terminal (if None, auto-detect)
 
     The function calculates the optimal number of columns to fit the terminal
     width and displays the items in a formatted column layout with proper
     padding and alignment.
     """
-    try:
-        ter_width = os.get_terminal_size().columns
-    except OSError:
-        ter_width = 80
+    if width is None:
+        ter_width = detect_terminal_width()
+    else:
+        ter_width = width
 
     item_names = [item.decode() for item in items]
 
@@ -256,7 +272,7 @@ def write_columns(items: Union[Iterator[bytes], list[bytes]]) -> None:
                     lines.append(branch_name)
 
         if lines:
-            sys.stdout.write("".join(lines).rstrip() + "\n")
+            out.write("".join(lines).rstrip() + "\n")
 
 
 class PagerBuffer:
@@ -2154,7 +2170,7 @@ class cmd_branch(Command):
             branches: Union[Iterator[bytes], list[bytes]], use_columns=False
         ) -> None:
             if use_columns:
-                write_columns(branches)
+                write_columns(branches, sys.stdout)
             else:
                 for branch in branches:
                     sys.stdout.write(f"{branch.decode()}\n")

+ 105 - 123
tests/test_cli.py

@@ -33,7 +33,13 @@ from unittest import skipIf
 from unittest.mock import MagicMock, patch
 
 from dulwich import cli
-from dulwich.cli import format_bytes, launch_editor, parse_relative_time, write_columns
+from dulwich.cli import (
+    detect_terminal_width,
+    format_bytes,
+    launch_editor,
+    parse_relative_time,
+    write_columns,
+)
 from dulwich.repo import Repo
 from dulwich.tests.utils import (
     build_commit_graph,
@@ -658,167 +664,143 @@ class BranchCommandTest(DulwichCliTestCase):
         self.assertTrue(multiple_columns)
 
 
-class TestWriteColumns(TestCase):
-    """Tests for write_columns function"""
-
-    def setUp(self) -> None:
-        super().setUp()
-        self.original_stdout = sys.stdout
-        self.original_get_terminal_size = os.get_terminal_size
-
-    def tearDown(self):
-        super().tearDown()
-        sys.stdout = self.original_stdout
-        os.get_terminal_size = self.original_get_terminal_size
+class TestTerminalWidth(TestCase):
+    @patch("os.get_terminal_size")
+    def test_terminal_size(self, mock_get_terminal_size):
+        """Test os.get_terminal_size mocking."""
+        mock_get_terminal_size.return_value.columns = 100
+        width = detect_terminal_width()
+        self.assertEqual(width, 100)
 
     @patch("os.get_terminal_size")
-    def test_basic_functionality(self, mock_terminal_size):
-        """Test basic functionality with default terminal width."""
-        mock_terminal_size.return_value.columns = 80
+    def test_terminal_size_os_error(self, mock_get_terminal_size):
+        """Test os.get_terminal_size raising OSError."""
+        mock_get_terminal_size.side_effect = OSError("No terminal")
+        width = detect_terminal_width()
+        self.assertEqual(width, 80)
+
 
-        with patch("sys.stdout.write") as mock_write:
-            items = [b"main", b"dev", b"feature/branch-1"]
-            write_columns(items)
+class TestWriteColumns(TestCase):
+    """Tests for write_columns function"""
 
-            self.assertGreater(mock_write.call_count, 0)
+    def test_basic_functionality(self):
+        """Test basic functionality with default terminal width."""
+        out = io.StringIO()
+        items = [b"main", b"dev", b"feature/branch-1"]
+        write_columns(items, out, width=80)
 
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
-            self.assertIn("main", output_text)
-            self.assertIn("dev", output_text)
-            self.assertIn("feature/branch-1", output_text)
+        output_text = out.getvalue()
+        self.assertEqual(output_text, "main  dev  feature/branch-1\n")
 
-    @patch("os.get_terminal_size")
-    def test_narrow_terminal_single_column(self, mock_terminal_size):
+    def test_narrow_terminal_single_column(self):
         """Test with narrow terminal forcing single column."""
-        mock_terminal_size.return_value.columns = 20
+        out = io.StringIO()
 
-        with patch("sys.stdout.write") as mock_write:
-            items = [b"main", b"dev", b"feature/branch-1"]
-            write_columns(items)
+        items = [b"main", b"dev", b"feature/branch-1"]
+        write_columns(items, out, 20)
 
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
-            for item in items:
-                self.assertIn(item.decode(), output_text)
+        self.assertEqual(out.getvalue(), "main\ndev\nfeature/branch-1\n")
 
-    @patch("os.get_terminal_size")
-    def test_wide_terminal_multiple_columns(self, mock_terminal_size):
+    def test_wide_terminal_multiple_columns(self):
         """Test with wide terminal allowing multiple columns."""
-        mock_terminal_size.return_value.columns = 120
-
-        with patch("sys.stdout.write") as mock_write:
-            items = [
-                b"main",
-                b"dev",
-                b"feature/branch-1",
-                b"feature/branch-2",
-                b"feature/branch-3",
-            ]
-            write_columns(items)
-
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
-            for item in items:
-                self.assertIn(item.decode(), output_text)
+        out = io.StringIO()
+        items = [
+            b"main",
+            b"dev",
+            b"feature/branch-1",
+            b"feature/branch-2",
+            b"feature/branch-3",
+        ]
+        write_columns(items, out, 120)
 
-    @patch("os.get_terminal_size")
-    def test_single_item(self, mock_terminal_size):
-        """Test with single item."""
-        mock_terminal_size.return_value.columns = 80
+        output_text = out.getvalue()
+        self.assertEqual(
+            output_text,
+            "main  dev  feature/branch-1  feature/branch-2  feature/branch-3\n",
+        )
 
-        with patch("sys.stdout.write") as mock_write:
-            write_columns([b"single"])
+    def test_single_item(self):
+        """Test with single item."""
+        out = io.StringIO()
+        write_columns([b"single"], out, 80)
 
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
-            self.assertIn("single", output_text)
-            self.assertTrue(output_text.endswith("\n"))
+        output_text = out.getvalue()
+        self.assertEqual("single\n", output_text)
+        self.assertTrue(output_text.endswith("\n"))
 
     def test_os_error_fallback(self):
         """Test fallback behavior when os.get_terminal_size raises OSError."""
         with patch("os.get_terminal_size", side_effect=OSError("No terminal")):
-            with patch("sys.stdout.write") as mock_write:
-                items = [b"main", b"dev"]
-                write_columns(items)
+            out = io.StringIO()
+            items = [b"main", b"dev"]
+            write_columns(items, out)
 
-                output_text = "".join(
-                    call.args[0] for call in mock_write.call_args_list
-                )
-                self.assertIn("main", output_text)
-                self.assertIn("dev", output_text)
+            output_text = out.getvalue()
+            # With default width (80), should display in columns
+            self.assertEqual(output_text, "main  dev\n")
 
-    @patch("os.get_terminal_size")
-    def test_iterator_input(self, mock_terminal_size):
+    def test_iterator_input(self):
         """Test with iterator input instead of list."""
-        mock_terminal_size.return_value.columns = 80
+        out = io.StringIO()
+        items = [b"main", b"dev", b"feature/branch-1"]
+        items_iterator = iter(items)
+        write_columns(items_iterator, out, 80)
 
-        with patch("sys.stdout.write") as mock_write:
-            items = [b"main", b"dev", b"feature/branch-1"]
-            items_iterator = iter(items)
-            write_columns(items_iterator)
+        output_text = out.getvalue()
+        self.assertEqual(output_text, "main  dev  feature/branch-1\n")
 
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
-            for item in items:
-                self.assertIn(item.decode(), output_text)
-
-    @patch("os.get_terminal_size")
-    def test_column_alignment(self, mock_terminal_size):
+    def test_column_alignment(self):
         """Test that columns are properly aligned."""
-        mock_terminal_size.return_value.columns = 50
+        out = io.StringIO()
+        items = [b"short", b"medium_length", b"very_long______name"]
+        write_columns(items, out, 50)
 
-        with patch("sys.stdout.write") as mock_write:
-            items = [b"short", b"medium_length", b"very_long______name"]
-            write_columns(items)
+        output_text = out.getvalue()
+        self.assertEqual(output_text, "short  medium_length  very_long______name\n")
 
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
-            for item in items:
-                self.assertIn(item.decode(), output_text)
-
-    @patch("os.get_terminal_size")
-    def test_columns_formatting(self, mock_terminal_size):
+    def test_columns_formatting(self):
         """Test that items are formatted in columns within single line."""
-        mock_terminal_size.return_value.columns = 80
+        out = io.StringIO()
+        items = [b"branch-1", b"branch-2", b"branch-3", b"branch-4", b"branch-5"]
+        write_columns(items, out, 80)
 
-        with patch("sys.stdout.write") as mock_write:
-            items = [b"branch-1", b"branch-2", b"branch-3", b"branch-4", b"branch-5"]
-            write_columns(items)
+        output_text = out.getvalue()
 
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
+        self.assertEqual(output_text.count("\n"), 1)
+        self.assertTrue(output_text.endswith("\n"))
 
-            self.assertEqual(output_text.count("\n"), 1)
-            self.assertTrue(output_text.endswith("\n"))
+        line = output_text.strip()
+        for item in items:
+            self.assertIn(item.decode(), line)
 
-            line = output_text.strip()
-            for item in items:
-                self.assertIn(item.decode(), line)
-
-    @patch("os.get_terminal_size")
-    def test_column_alignment_multiple_lines(self, mock_terminal_size):
+    def test_column_alignment_multiple_lines(self):
         """Test that columns are properly aligned across multiple lines."""
-        mock_terminal_size.return_value.columns = 60
+        items = [
+            b"short",
+            b"medium_length",
+            b"very_long_branch_name",
+            b"another",
+            b"more",
+            b"even_longer_branch_name_here",
+        ]
 
-        with patch("sys.stdout.write") as mock_write:
-            items = [
-                b"short",
-                b"medium_length",
-                b"very_long_branch_name",
-                b"another",
-                b"more",
-                b"even_longer_branch_name_here",
-            ]
+        out = io.StringIO()
 
-            write_columns(items)
+        write_columns(items, out, width=60)
 
-            output_text = "".join(call.args[0] for call in mock_write.call_args_list)
-            lines = output_text.strip().split("\n")
+        output_text = out.getvalue()
+        lines = output_text.strip().split("\n")
 
-            self.assertGreater(len(lines), 1)
+        self.assertGreater(len(lines), 1)
 
-            line_lengths = [len(line) for line in lines if line.strip()]
+        line_lengths = [len(line) for line in lines if line.strip()]
 
-            for length in line_lengths:
-                self.assertLessEqual(length, mock_terminal_size.return_value.columns)
+        for length in line_lengths:
+            self.assertLessEqual(length, 60)
 
-            all_output = " ".join(lines)
-            for item in items:
-                self.assertIn(item.decode(), all_output)
+        all_output = " ".join(lines)
+        for item in items:
+            self.assertIn(item.decode(), all_output)
 
 
 class CheckoutCommandTest(DulwichCliTestCase):