Bladeren bron

Honor git configuration when handling pager

Jelmer Vernooij 1 maand geleden
bovenliggende
commit
016ceb31f3
2 gewijzigde bestanden met toevoegingen van 349 en 96 verwijderingen
  1. 158 96
      dulwich/cli.py
  2. 191 0
      tests/test_cli.py

+ 158 - 96
dulwich/cli.py

@@ -149,27 +149,15 @@ class PagerBuffer:
 class Pager:
     """File-like object that pages output through external pager programs."""
 
-    def __init__(self):
+    def __init__(self, pager_cmd="cat"):
         self.pager_process = None
         self.buffer = PagerBuffer(self)
         self._closed = False
+        self.pager_cmd = pager_cmd
 
     def _get_pager_command(self) -> str:
         """Get the pager command to use."""
-        # Priority order: DULWICH_PAGER, GIT_PAGER, PAGER, then fallback
-        for env_var in ["DULWICH_PAGER", "GIT_PAGER", "PAGER"]:
-            pager = os.environ.get(env_var)
-            if pager and pager != "false":
-                return pager
-
-        # Fallback to common pagers
-        for pager in ["less", "more", "cat"]:
-            if shutil.which(pager):
-                if pager == "less":
-                    return "less -FRX"  # -F: quit if one screen, -R: raw control chars, -X: no init/deinit
-                return pager
-
-        return "cat"  # Ultimate fallback
+        return self.pager_cmd
 
     def _ensure_pager_started(self):
         """Start the pager process if not already started."""
@@ -287,9 +275,13 @@ class _StreamContextAdapter:
         return getattr(self.stream, name)
 
 
-def get_pager():
+def get_pager(config=None, cmd_name=None):
     """Get a pager instance if paging should be used, otherwise return sys.stdout.
 
+    Args:
+        config: Optional config instance (e.g., StackedConfig) to read settings from
+        cmd_name: Optional command name for per-command pager settings
+
     Returns:
         Either a wrapped sys.stdout or a Pager instance (both context managers)
     """
@@ -297,17 +289,73 @@ def get_pager():
     if getattr(get_pager, "_disabled", False):
         return _StreamContextAdapter(sys.stdout)
 
-    # Check if paging should be disabled via environment
-    if os.environ.get("DULWICH_PAGER") == "false":
-        return _StreamContextAdapter(sys.stdout)
-    if os.environ.get("GIT_PAGER") == "false":
-        return _StreamContextAdapter(sys.stdout)
-
     # Don't page if stdout is not a terminal
     if not sys.stdout.isatty():
         return _StreamContextAdapter(sys.stdout)
 
-    return Pager()
+    # Priority order for pager command (following git's behavior):
+    # 1. Check pager.<cmd> config (if cmd_name provided)
+    # 2. Check environment variables: DULWICH_PAGER, GIT_PAGER, PAGER
+    # 3. Check core.pager config
+    # 4. Fallback to common pagers
+
+    pager_cmd = None
+
+    # 1. Check per-command pager config (pager.<cmd>)
+    if config and cmd_name:
+        try:
+            pager_value = config.get(
+                ("pager",), cmd_name.encode() if isinstance(cmd_name, str) else cmd_name
+            )
+        except KeyError:
+            pass
+        else:
+            if pager_value == b"false":
+                return _StreamContextAdapter(sys.stdout)
+            elif pager_value != b"true":
+                # It's a custom pager command
+                pager_cmd = (
+                    pager_value.decode()
+                    if isinstance(pager_value, bytes)
+                    else pager_value
+                )
+
+    # 2. Check environment variables
+    if not pager_cmd:
+        for env_var in ["DULWICH_PAGER", "GIT_PAGER", "PAGER"]:
+            pager = os.environ.get(env_var)
+            if pager:
+                if pager == "false":
+                    return _StreamContextAdapter(sys.stdout)
+                pager_cmd = pager
+                break
+
+    # 3. Check core.pager config
+    if not pager_cmd and config:
+        try:
+            core_pager = config.get(("core",), b"pager")
+        except KeyError:
+            pass
+        else:
+            if core_pager == b"false" or core_pager == b"":
+                return _StreamContextAdapter(sys.stdout)
+            pager_cmd = (
+                core_pager.decode() if isinstance(core_pager, bytes) else core_pager
+            )
+
+    # 4. Fallback to common pagers
+    if not pager_cmd:
+        for pager in ["less", "more", "cat"]:
+            if shutil.which(pager):
+                if pager == "less":
+                    pager_cmd = "less -FRX"  # -F: quit if one screen, -R: raw control chars, -X: no init/deinit
+                else:
+                    pager_cmd = pager
+                break
+        else:
+            pager_cmd = "cat"  # Ultimate fallback
+
+    return Pager(pager_cmd)
 
 
 def disable_pager():
@@ -375,12 +423,14 @@ class cmd_annotate(Command):
         parser.add_argument("committish", nargs="?", help="Commit to start from")
         args = parser.parse_args(argv)
 
-        with get_pager() as outstream:
-            results = porcelain.annotate(".", args.path, args.committish)
-            for (commit, entry), line in results:
-                # Show shortened commit hash and line content
-                commit_hash = commit.id[:8]
-                outstream.write(f"{commit_hash.decode()} {line.decode()}\n")
+        with Repo(".") as repo:
+            config = repo.get_config_stack()
+            with get_pager(config=config, cmd_name="annotate") as outstream:
+                results = porcelain.annotate(repo, args.path, args.committish)
+                for (commit, entry), line in results:
+                    # Show shortened commit hash and line content
+                    commit_hash = commit.id[:8]
+                    outstream.write(f"{commit_hash.decode()} {line.decode()}\n")
 
 
 class cmd_blame(Command):
@@ -486,14 +536,16 @@ class cmd_log(Command):
         parser.add_argument("paths", nargs="*", help="Paths to show log for")
         args = parser.parse_args(args)
 
-        with get_pager() as outstream:
-            porcelain.log(
-                ".",
-                paths=args.paths,
-                reverse=args.reverse,
-                name_status=args.name_status,
-                outstream=outstream,
-            )
+        with Repo(".") as repo:
+            config = repo.get_config_stack()
+            with get_pager(config=config, cmd_name="log") as outstream:
+                porcelain.log(
+                    repo,
+                    paths=args.paths,
+                    reverse=args.reverse,
+                    name_status=args.name_status,
+                    outstream=outstream,
+                )
 
 
 class cmd_diff(Command):
@@ -523,37 +575,39 @@ class cmd_diff(Command):
 
         args = parsed_args
 
-        with get_pager() as outstream:
-            if len(args.committish) == 0:
-                # Show diff for working tree or staged changes
-                porcelain.diff(
-                    ".",
-                    staged=(args.staged or args.cached),
-                    paths=args.paths or None,
-                    outstream=outstream.buffer,
-                )
-            elif len(args.committish) == 1:
-                # Show diff between working tree and specified commit
-                if args.staged or args.cached:
-                    parser.error("--staged/--cached cannot be used with commits")
-                porcelain.diff(
-                    ".",
-                    commit=args.committish[0],
-                    staged=False,
-                    paths=args.paths or None,
-                    outstream=outstream.buffer,
-                )
-            elif len(args.committish) == 2:
-                # Show diff between two commits
-                porcelain.diff(
-                    ".",
-                    commit=args.committish[0],
-                    commit2=args.committish[1],
-                    paths=args.paths or None,
-                    outstream=outstream.buffer,
-                )
-            else:
-                parser.error("Too many arguments - specify at most two commits")
+        with Repo(".") as repo:
+            config = repo.get_config_stack()
+            with get_pager(config=config, cmd_name="diff") as outstream:
+                if len(args.committish) == 0:
+                    # Show diff for working tree or staged changes
+                    porcelain.diff(
+                        repo,
+                        staged=(args.staged or args.cached),
+                        paths=args.paths or None,
+                        outstream=outstream.buffer,
+                    )
+                elif len(args.committish) == 1:
+                    # Show diff between working tree and specified commit
+                    if args.staged or args.cached:
+                        parser.error("--staged/--cached cannot be used with commits")
+                    porcelain.diff(
+                        repo,
+                        commit=args.committish[0],
+                        staged=False,
+                        paths=args.paths or None,
+                        outstream=outstream.buffer,
+                    )
+                elif len(args.committish) == 2:
+                    # Show diff between two commits
+                    porcelain.diff(
+                        repo,
+                        commit=args.committish[0],
+                        commit2=args.committish[1],
+                        paths=args.paths or None,
+                        outstream=outstream.buffer,
+                    )
+                else:
+                    parser.error("Too many arguments - specify at most two commits")
 
 
 class cmd_dump_pack(Command):
@@ -729,8 +783,10 @@ class cmd_show(Command):
         parser = argparse.ArgumentParser()
         parser.add_argument("objectish", type=str, nargs="*")
         args = parser.parse_args(argv)
-        with get_pager() as outstream:
-            porcelain.show(".", args.objectish or None, outstream=outstream)
+        with Repo(".") as repo:
+            config = repo.get_config_stack()
+            with get_pager(config=config, cmd_name="show") as outstream:
+                porcelain.show(repo, args.objectish or None, outstream=outstream)
 
 
 class cmd_diff_tree(Command):
@@ -787,26 +843,30 @@ class cmd_reflog(Command):
         )
         args = parser.parse_args(args)
 
-        with get_pager() as outstream:
-            if args.all:
-                # Show reflogs for all refs
-                for ref_bytes, entry in porcelain.reflog(".", all=True):
-                    ref_str = ref_bytes.decode("utf-8", "replace")
-                    short_new = entry.new_sha[:8].decode("ascii")
-                    outstream.write(
-                        f"{short_new} {ref_str}: {entry.message.decode('utf-8', 'replace')}\n"
+        with Repo(".") as repo:
+            config = repo.get_config_stack()
+            with get_pager(config=config, cmd_name="reflog") as outstream:
+                if args.all:
+                    # Show reflogs for all refs
+                    for ref_bytes, entry in porcelain.reflog(repo, all=True):
+                        ref_str = ref_bytes.decode("utf-8", "replace")
+                        short_new = entry.new_sha[:8].decode("ascii")
+                        outstream.write(
+                            f"{short_new} {ref_str}: {entry.message.decode('utf-8', 'replace')}\n"
+                        )
+                else:
+                    ref = (
+                        args.ref.encode("utf-8")
+                        if isinstance(args.ref, str)
+                        else args.ref
                     )
-            else:
-                ref = (
-                    args.ref.encode("utf-8") if isinstance(args.ref, str) else args.ref
-                )
 
-                for i, entry in enumerate(porcelain.reflog(".", ref)):
-                    # Format similar to git reflog
-                    short_new = entry.new_sha[:8].decode("ascii")
-                    outstream.write(
-                        f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {entry.message.decode('utf-8', 'replace')}\n"
-                    )
+                    for i, entry in enumerate(porcelain.reflog(repo, ref)):
+                        # Format similar to git reflog
+                        short_new = entry.new_sha[:8].decode("ascii")
+                        outstream.write(
+                            f"{short_new} {ref.decode('utf-8', 'replace')}@{{{i}}}: {entry.message.decode('utf-8', 'replace')}\n"
+                        )
 
 
 class cmd_reset(Command):
@@ -997,14 +1057,16 @@ class cmd_ls_tree(Command):
         )
         parser.add_argument("treeish", nargs="?", help="Tree-ish to list")
         args = parser.parse_args(args)
-        with get_pager() as outstream:
-            porcelain.ls_tree(
-                ".",
-                args.treeish,
-                outstream=outstream,
-                recursive=args.recursive,
-                name_only=args.name_only,
-            )
+        with Repo(".") as repo:
+            config = repo.get_config_stack()
+            with get_pager(config=config, cmd_name="ls-tree") as outstream:
+                porcelain.ls_tree(
+                    repo,
+                    args.treeish,
+                    outstream=outstream,
+                    recursive=args.recursive,
+                    name_only=args.name_only,
+                )
 
 
 class cmd_pack_objects(Command):

+ 191 - 0
tests/test_cli.py

@@ -1582,5 +1582,196 @@ class ParseRelativeTimeTestCase(TestCase):
         )
 
 
+class GetPagerTest(TestCase):
+    """Tests for get_pager function."""
+
+    def setUp(self):
+        super().setUp()
+        # Save original environment
+        self.original_env = os.environ.copy()
+        # Clear pager-related environment variables
+        for var in ["DULWICH_PAGER", "GIT_PAGER", "PAGER"]:
+            os.environ.pop(var, None)
+        # Reset the global pager disable flag
+        cli.get_pager._disabled = False
+
+    def tearDown(self):
+        super().tearDown()
+        # Restore original environment
+        os.environ.clear()
+        os.environ.update(self.original_env)
+        # Reset the global pager disable flag
+        cli.get_pager._disabled = False
+
+    def test_pager_disabled_globally(self):
+        """Test that globally disabled pager returns stdout wrapper."""
+        cli.disable_pager()
+        pager = cli.get_pager()
+        self.assertIsInstance(pager, cli._StreamContextAdapter)
+        self.assertEqual(pager.stream, sys.stdout)
+
+    def test_pager_not_tty(self):
+        """Test that pager is disabled when stdout is not a TTY."""
+        with patch("sys.stdout.isatty", return_value=False):
+            pager = cli.get_pager()
+            self.assertIsInstance(pager, cli._StreamContextAdapter)
+
+    def test_pager_env_dulwich_pager(self):
+        """Test DULWICH_PAGER environment variable."""
+        os.environ["DULWICH_PAGER"] = "custom_pager"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager()
+            self.assertIsInstance(pager, cli.Pager)
+            self.assertEqual(pager.pager_cmd, "custom_pager")
+
+    def test_pager_env_dulwich_pager_false(self):
+        """Test DULWICH_PAGER=false disables pager."""
+        os.environ["DULWICH_PAGER"] = "false"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager()
+            self.assertIsInstance(pager, cli._StreamContextAdapter)
+
+    def test_pager_env_git_pager(self):
+        """Test GIT_PAGER environment variable."""
+        os.environ["GIT_PAGER"] = "git_custom_pager"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager()
+            self.assertIsInstance(pager, cli.Pager)
+            self.assertEqual(pager.pager_cmd, "git_custom_pager")
+
+    def test_pager_env_pager(self):
+        """Test PAGER environment variable."""
+        os.environ["PAGER"] = "my_pager"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager()
+            self.assertIsInstance(pager, cli.Pager)
+            self.assertEqual(pager.pager_cmd, "my_pager")
+
+    def test_pager_env_priority(self):
+        """Test environment variable priority order."""
+        os.environ["PAGER"] = "pager_low"
+        os.environ["GIT_PAGER"] = "pager_medium"
+        os.environ["DULWICH_PAGER"] = "pager_high"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager()
+            self.assertEqual(pager.pager_cmd, "pager_high")
+
+    def test_pager_config_core_pager(self):
+        """Test core.pager configuration."""
+        config = MagicMock()
+        config.get.return_value = b"config_pager"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager(config=config)
+            self.assertIsInstance(pager, cli.Pager)
+            self.assertEqual(pager.pager_cmd, "config_pager")
+            config.get.assert_called_with(("core",), b"pager")
+
+    def test_pager_config_core_pager_false(self):
+        """Test core.pager=false disables pager."""
+        config = MagicMock()
+        config.get.return_value = b"false"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager(config=config)
+            self.assertIsInstance(pager, cli._StreamContextAdapter)
+
+    def test_pager_config_core_pager_empty(self):
+        """Test core.pager="" disables pager."""
+        config = MagicMock()
+        config.get.return_value = b""
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager(config=config)
+            self.assertIsInstance(pager, cli._StreamContextAdapter)
+
+    def test_pager_config_per_command(self):
+        """Test per-command pager configuration."""
+        config = MagicMock()
+        config.get.side_effect = lambda section, key: {
+            (("pager",), b"log"): b"log_pager",
+        }.get((section, key), KeyError())
+
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager(config=config, cmd_name="log")
+            self.assertIsInstance(pager, cli.Pager)
+            self.assertEqual(pager.pager_cmd, "log_pager")
+
+    def test_pager_config_per_command_false(self):
+        """Test per-command pager=false disables pager."""
+        config = MagicMock()
+        config.get.return_value = b"false"
+        with patch("sys.stdout.isatty", return_value=True):
+            pager = cli.get_pager(config=config, cmd_name="log")
+            self.assertIsInstance(pager, cli._StreamContextAdapter)
+
+    def test_pager_config_per_command_true(self):
+        """Test per-command pager=true uses default pager."""
+        config = MagicMock()
+
+        def get_side_effect(section, key):
+            if section == ("pager",) and key == b"log":
+                return b"true"
+            raise KeyError
+
+        config.get.side_effect = get_side_effect
+
+        with patch("sys.stdout.isatty", return_value=True):
+            with patch("shutil.which", side_effect=lambda cmd: cmd == "less"):
+                pager = cli.get_pager(config=config, cmd_name="log")
+                self.assertIsInstance(pager, cli.Pager)
+                self.assertEqual(pager.pager_cmd, "less -FRX")
+
+    def test_pager_priority_order(self):
+        """Test complete priority order."""
+        # Set up all possible configurations
+        os.environ["PAGER"] = "env_pager"
+        os.environ["GIT_PAGER"] = "env_git_pager"
+
+        config = MagicMock()
+
+        def get_side_effect(section, key):
+            if section == ("pager",) and key == b"log":
+                return b"cmd_pager"
+            elif section == ("core",) and key == b"pager":
+                return b"core_pager"
+            raise KeyError
+
+        config.get.side_effect = get_side_effect
+
+        with patch("sys.stdout.isatty", return_value=True):
+            # Per-command config should win
+            pager = cli.get_pager(config=config, cmd_name="log")
+            self.assertEqual(pager.pager_cmd, "cmd_pager")
+
+    def test_pager_fallback_less(self):
+        """Test fallback to less with proper flags."""
+        with patch("sys.stdout.isatty", return_value=True):
+            with patch("shutil.which", side_effect=lambda cmd: cmd == "less"):
+                pager = cli.get_pager()
+                self.assertIsInstance(pager, cli.Pager)
+                self.assertEqual(pager.pager_cmd, "less -FRX")
+
+    def test_pager_fallback_more(self):
+        """Test fallback to more when less is not available."""
+        with patch("sys.stdout.isatty", return_value=True):
+            with patch("shutil.which", side_effect=lambda cmd: cmd == "more"):
+                pager = cli.get_pager()
+                self.assertIsInstance(pager, cli.Pager)
+                self.assertEqual(pager.pager_cmd, "more")
+
+    def test_pager_fallback_cat(self):
+        """Test ultimate fallback to cat."""
+        with patch("sys.stdout.isatty", return_value=True):
+            with patch("shutil.which", return_value=None):
+                pager = cli.get_pager()
+                self.assertIsInstance(pager, cli.Pager)
+                self.assertEqual(pager.pager_cmd, "cat")
+
+    def test_pager_context_manager(self):
+        """Test that pager works as a context manager."""
+        with patch("sys.stdout.isatty", return_value=True):
+            with cli.get_pager() as pager:
+                self.assertTrue(hasattr(pager, "write"))
+                self.assertTrue(hasattr(pager, "flush"))
+
+
 if __name__ == "__main__":
     unittest.main()