Просмотр исходного кода

Fix ParamikoSSHVendor interface compatibility with SSHVendor (#2029)

Fixes #2028
Jelmer Vernooij 1 месяц назад
Родитель
Сommit
09d3bc87e2
3 измененных файлов с 32 добавлено и 22 удалено
  1. 2 0
      NEWS
  2. 8 3
      dulwich/contrib/paramiko_vendor.py
  3. 22 19
      tests/contrib/test_paramiko_vendor.py

+ 2 - 0
NEWS

@@ -4,6 +4,8 @@
 is ahead of a 1.0 release, after which API changes will be kept backwards
 compatible.
 
+ * Fix ParamikoSSHVendor interface compatibility with SSHVendor. (Jelmer Vernooij, #2028)
+
  * Fix UTF-8 decode error in process filter protocol when handling binary files.
    (Jelmer Vernooij, #2023)
 

+ 8 - 3
dulwich/contrib/paramiko_vendor.py

@@ -144,12 +144,13 @@ class ParamikoSSHVendor:
     def run_command(
         self,
         host: str,
-        command: str,
+        command: bytes,
         username: str | None = None,
         port: int | None = None,
         password: str | None = None,
         pkey: paramiko.PKey | None = None,
         key_filename: str | None = None,
+        ssh_command: str | None = None,
         protocol_version: int | None = None,
         **kwargs: object,
     ) -> _ParamikoWrapper:
@@ -157,18 +158,22 @@ class ParamikoSSHVendor:
 
         Args:
             host: Hostname to connect to
-            command: Command to execute
+            command: Command to execute (as bytes)
             username: SSH username (optional)
             port: SSH port (optional)
             password: SSH password (optional)
             pkey: Private key for authentication (optional)
             key_filename: Path to private key file (optional)
+            ssh_command: SSH command (ignored - Paramiko doesn't use external SSH)
             protocol_version: SSH protocol version (optional)
             **kwargs: Additional keyword arguments
 
         Returns:
             _ParamikoWrapper instance for the SSH channel
         """
+        # Convert bytes command to str for paramiko
+        command_str = command.decode("utf-8")
+
         client = paramiko.SSHClient()
 
         # Get SSH config for this host
@@ -220,6 +225,6 @@ class ParamikoSSHVendor:
             channel.set_environment_variable(name="GIT_PROTOCOL", value="version=2")
 
         # Run commands
-        channel.exec_command(command)
+        channel.exec_command(command_str)
 
         return _ParamikoWrapper(client, channel)

+ 22 - 19
tests/contrib/test_paramiko_vendor.py

@@ -302,7 +302,7 @@ class ParamikoSSHVendorTests(TestCase):
         )
         vendor.run_command(
             "127.0.0.1",
-            "test_run_command_password",
+            b"test_run_command_password",
             username=USER,
             port=self.port,
             password=PASSWORD,
@@ -319,7 +319,7 @@ class ParamikoSSHVendorTests(TestCase):
         )
         vendor.run_command(
             "127.0.0.1",
-            "test_run_command_with_privkey",
+            b"test_run_command_with_privkey",
             username=USER,
             port=self.port,
             pkey=key,
@@ -334,7 +334,7 @@ class ParamikoSSHVendorTests(TestCase):
         )
         con = vendor.run_command(
             "127.0.0.1",
-            "test_run_command_data_transfer",
+            b"test_run_command_data_transfer",
             username=USER,
             port=self.port,
             password=PASSWORD,
@@ -425,7 +425,7 @@ class ParamikoSSHVendorRealServerTests(TestCase):
 
     def test_password_authentication_success(self) -> None:
         """Test successful password authentication."""
-        con = self._run_command("test_password_auth", password=PASSWORD)
+        con = self._run_command(b"test_password_auth", password=PASSWORD)
         self.assertIn(b"test_password_auth", self.ssh_server.commands)
         self._test_echo(con, b"hello\n")
         con.close()
@@ -433,7 +433,7 @@ class ParamikoSSHVendorRealServerTests(TestCase):
     def test_key_authentication_success(self) -> None:
         """Test successful key authentication."""
         key = paramiko.RSAKey.from_private_key(StringIO(CLIENT_KEY))
-        con = self._run_command("test_key_auth", pkey=key)
+        con = self._run_command(b"test_key_auth", pkey=key)
         self.assertIn(b"test_key_auth", self.ssh_server.commands)
         self._test_echo(con, b"key_test\n")
         con.close()
@@ -442,12 +442,12 @@ class ParamikoSSHVendorRealServerTests(TestCase):
         """Test authentication failures."""
         # Wrong password
         with self.assertRaises(paramiko.AuthenticationException):
-            self._run_command("should_fail", password="wrong_password")
+            self._run_command(b"should_fail", password="wrong_password")
 
         # Wrong key
         wrong_key = paramiko.RSAKey.generate(2048)
         with self.assertRaises(paramiko.AuthenticationException):
-            self._run_command("should_fail", pkey=wrong_key)
+            self._run_command(b"should_fail", pkey=wrong_key)
 
     def test_connection_errors(self) -> None:
         """Test various connection errors."""
@@ -456,18 +456,21 @@ class ParamikoSSHVendorRealServerTests(TestCase):
         # Non-existent port
         with self.assertRaises((OSError, ConnectionRefusedError)):
             vendor.run_command(
-                "127.0.0.1", "fail", username=USER, port=65432, password=PASSWORD
+                "127.0.0.1", b"fail", username=USER, port=65432, password=PASSWORD
             )
 
         # Invalid hostname
         with self.assertRaises((socket.gaierror, OSError)):
             vendor.run_command(
-                "invalid.hostname.example.com", "fail", username=USER, password=PASSWORD
+                "invalid.hostname.example.com",
+                b"fail",
+                username=USER,
+                password=PASSWORD,
             )
 
     def test_data_transfer(self) -> None:
         """Test various data transfer scenarios."""
-        con = self._run_command("test_data", password=PASSWORD)
+        con = self._run_command(b"test_data", password=PASSWORD)
 
         # Large data (10KB)
         large_data = b"X" * 10240
@@ -482,12 +485,12 @@ class ParamikoSSHVendorRealServerTests(TestCase):
     def test_multiple_connections(self) -> None:
         """Test multiple sequential connections."""
         # First connection
-        con1 = self._run_command("test_connection_1", password=PASSWORD)
+        con1 = self._run_command(b"test_connection_1", password=PASSWORD)
         self._test_echo(con1, b"first\n")
         con1.close()
 
         # Second connection
-        con2 = self._run_command("test_connection_2", password=PASSWORD)
+        con2 = self._run_command(b"test_connection_2", password=PASSWORD)
         self._test_echo(con2, b"second\n")
         con2.close()
 
@@ -502,7 +505,7 @@ class ParamikoSSHVendorRealServerTests(TestCase):
             key_path = f.name
 
         try:
-            con = self._run_command("test_key_from_file", key_filename=key_path)
+            con = self._run_command(b"test_key_from_file", key_filename=key_path)
             self.assertIn(b"test_key_from_file", self.ssh_server.commands)
             self._test_echo(con, b"file_key_test\n")
             con.close()
@@ -513,14 +516,14 @@ class ParamikoSSHVendorRealServerTests(TestCase):
         """Test protocol version handling."""
         # Protocol version 2 (default)
         con = self._run_command(
-            "test_protocol_v2", password=PASSWORD, protocol_version=2
+            b"test_protocol_v2", password=PASSWORD, protocol_version=2
         )
         self.assertIn(b"test_protocol_v2", self.ssh_server.commands)
         con.close()
 
         # Protocol version 1
         con = self._run_command(
-            "test_protocol_v1", password=PASSWORD, protocol_version=1
+            b"test_protocol_v1", password=PASSWORD, protocol_version=1
         )
         self.assertIn(b"test_protocol_v1", self.ssh_server.commands)
         con.close()
@@ -531,7 +534,7 @@ class ParamikoSSHVendorRealServerTests(TestCase):
         vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False, timeout=1)
         con = vendor.run_command(
             "127.0.0.1",
-            "test_timeout",
+            b"test_timeout",
             username=USER,
             port=self.ssh_server.port,
             password=PASSWORD,
@@ -541,7 +544,7 @@ class ParamikoSSHVendorRealServerTests(TestCase):
 
     def test_can_read(self) -> None:
         """Test can_read functionality."""
-        con = self._run_command("test_can_read", password=PASSWORD)
+        con = self._run_command(b"test_can_read", password=PASSWORD)
 
         # Check can_read returns bool
         self.assertIsInstance(con.can_read(), bool)
@@ -552,7 +555,7 @@ class ParamikoSSHVendorRealServerTests(TestCase):
 
     def test_partial_reads(self) -> None:
         """Test reading data in small chunks."""
-        con = self._run_command("test_partial", password=PASSWORD)
+        con = self._run_command(b"test_partial", password=PASSWORD)
 
         test_data = b"0123456789" * 10  # 100 bytes
         con.write(test_data)
@@ -589,7 +592,7 @@ Host testserver
 
                 vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
                 con = vendor.run_command(
-                    "testserver", "test_ssh_config", password=PASSWORD
+                    "testserver", b"test_ssh_config", password=PASSWORD
                 )
 
                 self.assertIn(b"test_ssh_config", self.ssh_server.commands)