소스 검색

Add some more tests for the paramiko vendor (#1684)

Jelmer Vernooij 1 개월 전
부모
커밋
9bf62f7e7b
2개의 변경된 파일338개의 추가작업 그리고 7개의 파일을 삭제
  1. 1 1
      dulwich/contrib/paramiko_vendor.py
  2. 337 6
      tests/contrib/test_paramiko_vendor.py

+ 1 - 1
dulwich/contrib/paramiko_vendor.py

@@ -28,7 +28,7 @@ the dulwich.client.get_ssh_vendor attribute:
   >>> from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor
   >>> _mod_client.get_ssh_vendor = ParamikoSSHVendor
 
-This implementation is experimental and does not have any tests.
+This implementation has comprehensive tests in tests/contrib/test_paramiko_vendor.py.
 """
 
 import os

+ 337 - 6
tests/contrib/test_paramiko_vendor.py

@@ -24,6 +24,7 @@ import os
 import socket
 import tempfile
 import threading
+import time
 from io import StringIO
 from typing import Optional
 from unittest import skipIf
@@ -37,6 +38,8 @@ except ImportError:
     has_paramiko = False
 else:
     has_paramiko = True
+    import paramiko.transport
+
     from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor
 
     class Server(paramiko.ServerInterface):
@@ -69,6 +72,140 @@ else:
         def get_allowed_auths(self, username) -> str:
             return "password,publickey"
 
+    class SSHServer:
+        """A real SSH server using Paramiko that listens on a TCP port."""
+
+        def __init__(self):
+            self.commands = []
+            self.server_socket = None
+            self.server_thread = None
+            self.host_key = paramiko.RSAKey.from_private_key(StringIO(SERVER_KEY))
+            self.running = False
+            self.connection_threads = []
+
+        def start(self):
+            """Start the SSH server on a random port."""
+            self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+            self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+            self.server_socket.bind(("127.0.0.1", 0))
+            self.port = self.server_socket.getsockname()[1]
+            self.server_socket.listen(5)
+
+            self.running = True
+            self.server_thread = threading.Thread(target=self._run_server)
+            self.server_thread.daemon = True
+            self.server_thread.start()
+
+            # Give the server a moment to start up
+            time.sleep(0.1)
+
+        def stop(self):
+            """Stop the SSH server."""
+            self.running = False
+            if self.server_socket:
+                try:
+                    self.server_socket.close()
+                except OSError:
+                    pass
+
+            # Clean up connection threads
+            for thread in self.connection_threads:
+                if thread.is_alive():
+                    thread.join(timeout=1)
+
+            if self.server_thread and self.server_thread.is_alive():
+                self.server_thread.join(timeout=5)
+
+        def _run_server(self):
+            """Main server loop."""
+            self.server_socket.settimeout(
+                1.0
+            )  # Allow checking self.running periodically
+
+            while self.running:
+                try:
+                    client_socket, addr = self.server_socket.accept()
+
+                    # Handle each connection in a separate thread
+                    conn_thread = threading.Thread(
+                        target=self._handle_connection, args=(client_socket,)
+                    )
+                    conn_thread.daemon = True
+                    conn_thread.start()
+                    self.connection_threads.append(conn_thread)
+
+                except socket.timeout:
+                    # Normal timeout, continue to check if we should keep running
+                    continue
+                except OSError as e:
+                    # Socket was closed, exit gracefully
+                    if not self.running:
+                        break
+                    # Otherwise re-raise the error
+                    raise e
+
+        def _handle_connection(self, client_socket):
+            """Handle a single SSH connection."""
+            transport = None
+            try:
+                transport = paramiko.Transport(client_socket)
+                transport.add_server_key(self.host_key)
+
+                server = Server(self.commands)
+                transport.start_server(server=server)
+
+                # Wait for channel requests and handle them
+                while self.running and transport.is_active():
+                    channel = transport.accept(1)
+                    if channel is None:
+                        continue
+
+                    # Handle channel in a separate thread to allow multiple channels
+                    channel_thread = threading.Thread(
+                        target=self._handle_channel, args=(channel,)
+                    )
+                    channel_thread.daemon = True
+                    channel_thread.start()
+
+            except paramiko.SSHException as e:
+                print(f"SSH error in connection handler: {e}")
+            except OSError as e:
+                print(f"Socket error in connection handler: {e}")
+            finally:
+                if transport:
+                    transport.close()
+                if client_socket:
+                    client_socket.close()
+
+        def _handle_channel(self, channel):
+            """Handle a single SSH channel - echo server."""
+            try:
+                # Set channel to blocking mode
+                channel.setblocking(True)
+                channel.settimeout(10.0)
+
+                # Read all available data and echo it back
+                while True:
+                    try:
+                        data = channel.recv(4096)
+                        if not data:
+                            break
+                        # Echo the data back immediately
+                        channel.send(data)
+                    except socket.timeout:
+                        # No more data available, break
+                        break
+
+            except paramiko.SSHException as e:
+                print(f"SSH error in channel handler: {e}")
+            except OSError as e:
+                print(f"Socket error in channel handler: {e}")
+            finally:
+                try:
+                    channel.close()
+                except OSError:
+                    pass
+
 
 USER = "testuser"
 PASSWORD = "test"
@@ -133,12 +270,6 @@ WxtWBWHwxfSmqgTXilEA3ALJp0kNolLnEttnhENwJpZHlqtes0ZA4w==
 @skipIf(not has_paramiko, "paramiko is not installed")
 class ParamikoSSHVendorTests(TestCase):
     def setUp(self) -> None:
-        import paramiko.transport
-
-        # re-enable server functionality for tests
-        if hasattr(paramiko.transport, "SERVER_DISABLED_BY_GENTOO"):
-            paramiko.transport.SERVER_DISABLED_BY_GENTOO = False
-
         self.commands = []
         socket.setdefaulttimeout(10)
         self.addCleanup(socket.setdefaulttimeout, None)
@@ -267,3 +398,203 @@ Host testserver
 
         finally:
             os.unlink(config_path)
+
+
+@skipIf(not has_paramiko, "paramiko is not installed")
+class ParamikoSSHVendorRealServerTests(TestCase):
+    """Tests for ParamikoSSHVendor using a real SSH server listening on TCP."""
+
+    def setUp(self) -> None:
+        self.ssh_server = SSHServer()
+        self.ssh_server.start()
+        socket.setdefaulttimeout(10)
+        self.addCleanup(socket.setdefaulttimeout, None)
+        self.addCleanup(self.ssh_server.stop)
+
+    def _run_command(self, command, **kwargs):
+        """Helper to run a command with default vendor settings."""
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
+        kwargs.setdefault("port", self.ssh_server.port)
+        kwargs.setdefault("username", USER)
+        return vendor.run_command("127.0.0.1", command, **kwargs)
+
+    def _test_echo(self, con, data):
+        """Helper to test echo functionality."""
+        con.write(data)
+        response = con.read(len(data))
+        self.assertEqual(data, response)
+
+    def test_password_authentication_success(self) -> None:
+        """Test successful password authentication."""
+        con = self._run_command("test_password_auth", password=PASSWORD)
+        self.assertIn(b"test_password_auth", self.ssh_server.commands)
+        self._test_echo(con, b"hello\n")
+        con.close()
+
+    def test_key_authentication_success(self) -> None:
+        """Test successful key authentication."""
+        key = paramiko.RSAKey.from_private_key(StringIO(CLIENT_KEY))
+        con = self._run_command("test_key_auth", pkey=key)
+        self.assertIn(b"test_key_auth", self.ssh_server.commands)
+        self._test_echo(con, b"key_test\n")
+        con.close()
+
+    def test_authentication_failures(self) -> None:
+        """Test authentication failures."""
+        # Wrong password
+        with self.assertRaises(paramiko.AuthenticationException):
+            self._run_command("should_fail", password="wrong_password")
+
+        # Wrong key
+        wrong_key = paramiko.RSAKey.generate(2048)
+        with self.assertRaises(paramiko.AuthenticationException):
+            self._run_command("should_fail", pkey=wrong_key)
+
+    def test_connection_errors(self) -> None:
+        """Test various connection errors."""
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
+
+        # Non-existent port
+        with self.assertRaises((OSError, ConnectionRefusedError)):
+            vendor.run_command(
+                "127.0.0.1", "fail", username=USER, port=65432, password=PASSWORD
+            )
+
+        # Invalid hostname
+        with self.assertRaises((socket.gaierror, OSError)):
+            vendor.run_command(
+                "invalid.hostname.example.com", "fail", username=USER, password=PASSWORD
+            )
+
+    def test_data_transfer(self) -> None:
+        """Test various data transfer scenarios."""
+        con = self._run_command("test_data", password=PASSWORD)
+
+        # Large data (10KB)
+        large_data = b"X" * 10240
+        self._test_echo(con, large_data)
+
+        # Binary data with all byte values
+        binary_data = bytes(range(256))
+        self._test_echo(con, binary_data)
+
+        con.close()
+
+    def test_multiple_connections(self) -> None:
+        """Test multiple sequential connections."""
+        # First connection
+        con1 = self._run_command("test_connection_1", password=PASSWORD)
+        self._test_echo(con1, b"first\n")
+        con1.close()
+
+        # Second connection
+        con2 = self._run_command("test_connection_2", password=PASSWORD)
+        self._test_echo(con2, b"second\n")
+        con2.close()
+
+        # Verify both commands were recorded
+        self.assertIn(b"test_connection_1", self.ssh_server.commands)
+        self.assertIn(b"test_connection_2", self.ssh_server.commands)
+
+    def test_key_from_file(self) -> None:
+        """Test authentication using key file."""
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".key", delete=False) as f:
+            f.write(CLIENT_KEY)
+            key_path = f.name
+
+        try:
+            con = self._run_command("test_key_from_file", key_filename=key_path)
+            self.assertIn(b"test_key_from_file", self.ssh_server.commands)
+            self._test_echo(con, b"file_key_test\n")
+            con.close()
+        finally:
+            os.unlink(key_path)
+
+    def test_protocol_versions(self) -> None:
+        """Test protocol version handling."""
+        # Protocol version 2 (default)
+        con = self._run_command(
+            "test_protocol_v2", password=PASSWORD, protocol_version=2
+        )
+        self.assertIn(b"test_protocol_v2", self.ssh_server.commands)
+        con.close()
+
+        # Protocol version 1
+        con = self._run_command(
+            "test_protocol_v1", password=PASSWORD, protocol_version=1
+        )
+        self.assertIn(b"test_protocol_v1", self.ssh_server.commands)
+        con.close()
+
+    def test_vendor_options(self) -> None:
+        """Test vendor initialization options."""
+        # Test with timeout
+        vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False, timeout=1)
+        con = vendor.run_command(
+            "127.0.0.1",
+            "test_timeout",
+            username=USER,
+            port=self.ssh_server.port,
+            password=PASSWORD,
+        )
+        self.assertIn(b"test_timeout", self.ssh_server.commands)
+        con.close()
+
+    def test_can_read(self) -> None:
+        """Test can_read functionality."""
+        con = self._run_command("test_can_read", password=PASSWORD)
+
+        # Check can_read returns bool
+        self.assertIsInstance(con.can_read(), bool)
+
+        # Send data and verify echo
+        self._test_echo(con, b"test_data\n")
+        con.close()
+
+    def test_partial_reads(self) -> None:
+        """Test reading data in small chunks."""
+        con = self._run_command("test_partial", password=PASSWORD)
+
+        test_data = b"0123456789" * 10  # 100 bytes
+        con.write(test_data)
+
+        # Read in 10-byte chunks
+        received_data = b""
+        while len(received_data) < len(test_data):
+            chunk = con.read(10)
+            if not chunk:
+                break
+            received_data += chunk
+
+        self.assertEqual(test_data, received_data)
+        con.close()
+
+    def test_ssh_config_integration(self) -> None:
+        """Test SSH config integration."""
+        with tempfile.NamedTemporaryFile(mode="w", suffix=".config", delete=False) as f:
+            f.write(f"""
+Host testserver
+    HostName 127.0.0.1
+    User {USER}
+    Port {self.ssh_server.port}
+""")
+            config_path = f.name
+
+        try:
+            with patch(
+                "dulwich.contrib.paramiko_vendor.os.path.expanduser"
+            ) as mock_expanduser:
+                mock_expanduser.side_effect = (
+                    lambda p: config_path if p == "~/.ssh/config" else p
+                )
+
+                vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
+                con = vendor.run_command(
+                    "testserver", "test_ssh_config", password=PASSWORD
+                )
+
+                self.assertIn(b"test_ssh_config", self.ssh_server.commands)
+                self._test_echo(con, b"config_test\n")
+                con.close()
+        finally:
+            os.unlink(config_path)