2
0
Эх сурвалжийг харах

client: add support for GIT_SSH_COMMAND

Peter Rowlands 3 жил өмнө
parent
commit
5971aa03ae

+ 15 - 2
dulwich/client.py

@@ -1467,6 +1467,7 @@ class SSHVendor(object):
         port=None,
         password=None,
         key_filename=None,
+        ssh_command=None,
     ):
         """Connect to an SSH server.
 
@@ -1480,6 +1481,7 @@ class SSHVendor(object):
           port: Optional SSH port to use
           password: Optional ssh password for login or private key
           key_filename: Optional path to private keyfile
+          ssh_command: Optional SSH command
 
         Returns:
 
@@ -1505,6 +1507,7 @@ class SubprocessSSHVendor(SSHVendor):
         port=None,
         password=None,
         key_filename=None,
+        ssh_command=None,
     ):
 
         if password is not None:
@@ -1512,7 +1515,7 @@ class SubprocessSSHVendor(SSHVendor):
                 "Setting password not supported by SubprocessSSHVendor."
             )
 
-        args = ["ssh", "-x"]
+        args = [ssh_command or "ssh", "-x"]
 
         if port:
             args.extend(["-p", str(port)])
@@ -1547,9 +1550,12 @@ class PLinkSSHVendor(SSHVendor):
         port=None,
         password=None,
         key_filename=None,
+        ssh_command=None,
     ):
 
-        if sys.platform == "win32":
+        if ssh_command:
+            args = [ssh_command, "-ssh"]
+        elif sys.platform == "win32":
             args = ["plink.exe", "-ssh"]
         else:
             args = ["plink", "-ssh"]
@@ -1611,6 +1617,7 @@ class SSHGitClient(TraditionalGitClient):
         config=None,
         password=None,
         key_filename=None,
+        ssh_command=None,
         **kwargs
     ):
         self.host = host
@@ -1618,6 +1625,9 @@ class SSHGitClient(TraditionalGitClient):
         self.username = username
         self.password = password
         self.key_filename = key_filename
+        self.ssh_command = ssh_command or os.environ.get(
+            "GIT_SSH_COMMAND", os.environ.get("GIT_SSH")
+        )
         super(SSHGitClient, self).__init__(**kwargs)
         self.alternative_paths = {}
         if vendor is not None:
@@ -1667,6 +1677,9 @@ class SSHGitClient(TraditionalGitClient):
             kwargs["password"] = self.password
         if self.key_filename is not None:
             kwargs["key_filename"] = self.key_filename
+        # GIT_SSH_COMMAND takes precendence over GIT_SSH
+        if self.ssh_command is not None:
+            kwargs["ssh_command"] = self.ssh_command
         con = self.ssh_vendor.run_command(
             self.host, argv, port=self.port, username=self.username, **kwargs
         )

+ 53 - 0
dulwich/tests/test_client.py

@@ -705,6 +705,7 @@ class TestSSHVendor(object):
         port=None,
         password=None,
         key_filename=None,
+        ssh_command=None,
     ):
         self.host = host
         self.command = command
@@ -712,6 +713,7 @@ class TestSSHVendor(object):
         self.port = port
         self.password = password
         self.key_filename = key_filename
+        self.ssh_command = ssh_command
 
         class Subprocess:
             pass
@@ -785,6 +787,21 @@ class SSHGitClientTests(TestCase):
         client._connect(b"relative-command", b"/~/path/to/repo")
         self.assertEqual("git-relative-command '~/path/to/repo'", server.command)
 
+    def test_ssh_command_precedence(self):
+        os.environ["GIT_SSH"] = "/path/to/ssh"
+        test_client = SSHGitClient("git.samba.org")
+        self.assertEqual(test_client.ssh_command, "/path/to/ssh")
+
+        os.environ["GIT_SSH_COMMAND"] = "/path/to/ssh -o Option=Value"
+        test_client = SSHGitClient("git.samba.org")
+        self.assertEqual(test_client.ssh_command, "/path/to/ssh -o Option=Value")
+
+        test_client = SSHGitClient("git.samba.org", ssh_command="ssh -o Option1=Value1")
+        self.assertEqual(test_client.ssh_command, "ssh -o Option1=Value1")
+
+        del os.environ["GIT_SSH"]
+        del os.environ["GIT_SSH_COMMAND"]
+
 
 class ReportStatusParserTests(TestCase):
     def test_invalid_pack(self):
@@ -1230,6 +1247,24 @@ class SubprocessSSHVendorTests(TestCase):
 
         self.assertListEqual(expected, args[0])
 
+    def test_run_with_ssh_command(self):
+        expected = [
+            "/path/to/ssh -o Option=Value",
+            "-x",
+            "host",
+            "git-clone-url",
+        ]
+
+        vendor = SubprocessSSHVendor()
+        command = vendor.run_command(
+            "host",
+            "git-clone-url",
+            ssh_command="/path/to/ssh -o Option=Value",
+        )
+
+        args = command.proc.args
+        self.assertListEqual(expected, args[0])
+
 
 class PLinkSSHVendorTests(TestCase):
     def setUp(self):
@@ -1353,6 +1388,24 @@ class PLinkSSHVendorTests(TestCase):
 
         self.assertListEqual(expected, args[0])
 
+    def test_run_with_ssh_command(self):
+        expected = [
+            "/path/to/plink",
+            "-x",
+            "host",
+            "git-clone-url",
+        ]
+
+        vendor = SubprocessSSHVendor()
+        command = vendor.run_command(
+            "host",
+            "git-clone-url",
+            ssh_command="/path/to/plink",
+        )
+
+        args = command.proc.args
+        self.assertListEqual(expected, args[0])
+
 
 class RsyncUrlTests(TestCase):
     def test_simple(self):