|
@@ -24,6 +24,7 @@ import os
|
|
|
import socket
|
|
|
import tempfile
|
|
|
import threading
|
|
|
+import time
|
|
|
from io import StringIO
|
|
|
from typing import Optional
|
|
|
from unittest import skipIf
|
|
@@ -37,6 +38,8 @@ except ImportError:
|
|
|
has_paramiko = False
|
|
|
else:
|
|
|
has_paramiko = True
|
|
|
+ import paramiko.transport
|
|
|
+
|
|
|
from dulwich.contrib.paramiko_vendor import ParamikoSSHVendor
|
|
|
|
|
|
class Server(paramiko.ServerInterface):
|
|
@@ -69,6 +72,140 @@ else:
|
|
|
def get_allowed_auths(self, username) -> str:
|
|
|
return "password,publickey"
|
|
|
|
|
|
+ class SSHServer:
|
|
|
+ """A real SSH server using Paramiko that listens on a TCP port."""
|
|
|
+
|
|
|
+ def __init__(self):
|
|
|
+ self.commands = []
|
|
|
+ self.server_socket = None
|
|
|
+ self.server_thread = None
|
|
|
+ self.host_key = paramiko.RSAKey.from_private_key(StringIO(SERVER_KEY))
|
|
|
+ self.running = False
|
|
|
+ self.connection_threads = []
|
|
|
+
|
|
|
+ def start(self):
|
|
|
+ """Start the SSH server on a random port."""
|
|
|
+ self.server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
|
|
+ self.server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
|
|
+ self.server_socket.bind(("127.0.0.1", 0))
|
|
|
+ self.port = self.server_socket.getsockname()[1]
|
|
|
+ self.server_socket.listen(5)
|
|
|
+
|
|
|
+ self.running = True
|
|
|
+ self.server_thread = threading.Thread(target=self._run_server)
|
|
|
+ self.server_thread.daemon = True
|
|
|
+ self.server_thread.start()
|
|
|
+
|
|
|
+ # Give the server a moment to start up
|
|
|
+ time.sleep(0.1)
|
|
|
+
|
|
|
+ def stop(self):
|
|
|
+ """Stop the SSH server."""
|
|
|
+ self.running = False
|
|
|
+ if self.server_socket:
|
|
|
+ try:
|
|
|
+ self.server_socket.close()
|
|
|
+ except OSError:
|
|
|
+ pass
|
|
|
+
|
|
|
+ # Clean up connection threads
|
|
|
+ for thread in self.connection_threads:
|
|
|
+ if thread.is_alive():
|
|
|
+ thread.join(timeout=1)
|
|
|
+
|
|
|
+ if self.server_thread and self.server_thread.is_alive():
|
|
|
+ self.server_thread.join(timeout=5)
|
|
|
+
|
|
|
+ def _run_server(self):
|
|
|
+ """Main server loop."""
|
|
|
+ self.server_socket.settimeout(
|
|
|
+ 1.0
|
|
|
+ ) # Allow checking self.running periodically
|
|
|
+
|
|
|
+ while self.running:
|
|
|
+ try:
|
|
|
+ client_socket, addr = self.server_socket.accept()
|
|
|
+
|
|
|
+ # Handle each connection in a separate thread
|
|
|
+ conn_thread = threading.Thread(
|
|
|
+ target=self._handle_connection, args=(client_socket,)
|
|
|
+ )
|
|
|
+ conn_thread.daemon = True
|
|
|
+ conn_thread.start()
|
|
|
+ self.connection_threads.append(conn_thread)
|
|
|
+
|
|
|
+ except socket.timeout:
|
|
|
+ # Normal timeout, continue to check if we should keep running
|
|
|
+ continue
|
|
|
+ except OSError as e:
|
|
|
+ # Socket was closed, exit gracefully
|
|
|
+ if not self.running:
|
|
|
+ break
|
|
|
+ # Otherwise re-raise the error
|
|
|
+ raise e
|
|
|
+
|
|
|
+ def _handle_connection(self, client_socket):
|
|
|
+ """Handle a single SSH connection."""
|
|
|
+ transport = None
|
|
|
+ try:
|
|
|
+ transport = paramiko.Transport(client_socket)
|
|
|
+ transport.add_server_key(self.host_key)
|
|
|
+
|
|
|
+ server = Server(self.commands)
|
|
|
+ transport.start_server(server=server)
|
|
|
+
|
|
|
+ # Wait for channel requests and handle them
|
|
|
+ while self.running and transport.is_active():
|
|
|
+ channel = transport.accept(1)
|
|
|
+ if channel is None:
|
|
|
+ continue
|
|
|
+
|
|
|
+ # Handle channel in a separate thread to allow multiple channels
|
|
|
+ channel_thread = threading.Thread(
|
|
|
+ target=self._handle_channel, args=(channel,)
|
|
|
+ )
|
|
|
+ channel_thread.daemon = True
|
|
|
+ channel_thread.start()
|
|
|
+
|
|
|
+ except paramiko.SSHException as e:
|
|
|
+ print(f"SSH error in connection handler: {e}")
|
|
|
+ except OSError as e:
|
|
|
+ print(f"Socket error in connection handler: {e}")
|
|
|
+ finally:
|
|
|
+ if transport:
|
|
|
+ transport.close()
|
|
|
+ if client_socket:
|
|
|
+ client_socket.close()
|
|
|
+
|
|
|
+ def _handle_channel(self, channel):
|
|
|
+ """Handle a single SSH channel - echo server."""
|
|
|
+ try:
|
|
|
+ # Set channel to blocking mode
|
|
|
+ channel.setblocking(True)
|
|
|
+ channel.settimeout(10.0)
|
|
|
+
|
|
|
+ # Read all available data and echo it back
|
|
|
+ while True:
|
|
|
+ try:
|
|
|
+ data = channel.recv(4096)
|
|
|
+ if not data:
|
|
|
+ break
|
|
|
+ # Echo the data back immediately
|
|
|
+ channel.send(data)
|
|
|
+ except socket.timeout:
|
|
|
+ # No more data available, break
|
|
|
+ break
|
|
|
+
|
|
|
+ except paramiko.SSHException as e:
|
|
|
+ print(f"SSH error in channel handler: {e}")
|
|
|
+ except OSError as e:
|
|
|
+ print(f"Socket error in channel handler: {e}")
|
|
|
+ finally:
|
|
|
+ try:
|
|
|
+ channel.close()
|
|
|
+ except OSError:
|
|
|
+ pass
|
|
|
+
|
|
|
|
|
|
USER = "testuser"
|
|
|
PASSWORD = "test"
|
|
@@ -133,12 +270,6 @@ WxtWBWHwxfSmqgTXilEA3ALJp0kNolLnEttnhENwJpZHlqtes0ZA4w==
|
|
|
@skipIf(not has_paramiko, "paramiko is not installed")
|
|
|
class ParamikoSSHVendorTests(TestCase):
|
|
|
def setUp(self) -> None:
|
|
|
- import paramiko.transport
|
|
|
-
|
|
|
- # re-enable server functionality for tests
|
|
|
- if hasattr(paramiko.transport, "SERVER_DISABLED_BY_GENTOO"):
|
|
|
- paramiko.transport.SERVER_DISABLED_BY_GENTOO = False
|
|
|
-
|
|
|
self.commands = []
|
|
|
socket.setdefaulttimeout(10)
|
|
|
self.addCleanup(socket.setdefaulttimeout, None)
|
|
@@ -267,3 +398,203 @@ Host testserver
|
|
|
|
|
|
finally:
|
|
|
os.unlink(config_path)
|
|
|
+
|
|
|
+
|
|
|
+@skipIf(not has_paramiko, "paramiko is not installed")
|
|
|
+class ParamikoSSHVendorRealServerTests(TestCase):
|
|
|
+ """Tests for ParamikoSSHVendor using a real SSH server listening on TCP."""
|
|
|
+
|
|
|
+ def setUp(self) -> None:
|
|
|
+ self.ssh_server = SSHServer()
|
|
|
+ self.ssh_server.start()
|
|
|
+ socket.setdefaulttimeout(10)
|
|
|
+ self.addCleanup(socket.setdefaulttimeout, None)
|
|
|
+ self.addCleanup(self.ssh_server.stop)
|
|
|
+
|
|
|
+ def _run_command(self, command, **kwargs):
|
|
|
+ """Helper to run a command with default vendor settings."""
|
|
|
+ vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
|
|
|
+ kwargs.setdefault("port", self.ssh_server.port)
|
|
|
+ kwargs.setdefault("username", USER)
|
|
|
+ return vendor.run_command("127.0.0.1", command, **kwargs)
|
|
|
+
|
|
|
+ def _test_echo(self, con, data):
|
|
|
+ """Helper to test echo functionality."""
|
|
|
+ con.write(data)
|
|
|
+ response = con.read(len(data))
|
|
|
+ self.assertEqual(data, response)
|
|
|
+
|
|
|
+ def test_password_authentication_success(self) -> None:
|
|
|
+ """Test successful password authentication."""
|
|
|
+ con = self._run_command("test_password_auth", password=PASSWORD)
|
|
|
+ self.assertIn(b"test_password_auth", self.ssh_server.commands)
|
|
|
+ self._test_echo(con, b"hello\n")
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ def test_key_authentication_success(self) -> None:
|
|
|
+ """Test successful key authentication."""
|
|
|
+ key = paramiko.RSAKey.from_private_key(StringIO(CLIENT_KEY))
|
|
|
+ con = self._run_command("test_key_auth", pkey=key)
|
|
|
+ self.assertIn(b"test_key_auth", self.ssh_server.commands)
|
|
|
+ self._test_echo(con, b"key_test\n")
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ def test_authentication_failures(self) -> None:
|
|
|
+ """Test authentication failures."""
|
|
|
+ # Wrong password
|
|
|
+ with self.assertRaises(paramiko.AuthenticationException):
|
|
|
+ self._run_command("should_fail", password="wrong_password")
|
|
|
+
|
|
|
+ # Wrong key
|
|
|
+ wrong_key = paramiko.RSAKey.generate(2048)
|
|
|
+ with self.assertRaises(paramiko.AuthenticationException):
|
|
|
+ self._run_command("should_fail", pkey=wrong_key)
|
|
|
+
|
|
|
+ def test_connection_errors(self) -> None:
|
|
|
+ """Test various connection errors."""
|
|
|
+ vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
|
|
|
+
|
|
|
+ # Non-existent port
|
|
|
+ with self.assertRaises((OSError, ConnectionRefusedError)):
|
|
|
+ vendor.run_command(
|
|
|
+ "127.0.0.1", "fail", username=USER, port=65432, password=PASSWORD
|
|
|
+ )
|
|
|
+
|
|
|
+ # Invalid hostname
|
|
|
+ with self.assertRaises((socket.gaierror, OSError)):
|
|
|
+ vendor.run_command(
|
|
|
+ "invalid.hostname.example.com", "fail", username=USER, password=PASSWORD
|
|
|
+ )
|
|
|
+
|
|
|
+ def test_data_transfer(self) -> None:
|
|
|
+ """Test various data transfer scenarios."""
|
|
|
+ con = self._run_command("test_data", password=PASSWORD)
|
|
|
+
|
|
|
+ # Large data (10KB)
|
|
|
+ large_data = b"X" * 10240
|
|
|
+ self._test_echo(con, large_data)
|
|
|
+
|
|
|
+ # Binary data with all byte values
|
|
|
+ binary_data = bytes(range(256))
|
|
|
+ self._test_echo(con, binary_data)
|
|
|
+
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ def test_multiple_connections(self) -> None:
|
|
|
+ """Test multiple sequential connections."""
|
|
|
+ # First connection
|
|
|
+ con1 = self._run_command("test_connection_1", password=PASSWORD)
|
|
|
+ self._test_echo(con1, b"first\n")
|
|
|
+ con1.close()
|
|
|
+
|
|
|
+ # Second connection
|
|
|
+ con2 = self._run_command("test_connection_2", password=PASSWORD)
|
|
|
+ self._test_echo(con2, b"second\n")
|
|
|
+ con2.close()
|
|
|
+
|
|
|
+ # Verify both commands were recorded
|
|
|
+ self.assertIn(b"test_connection_1", self.ssh_server.commands)
|
|
|
+ self.assertIn(b"test_connection_2", self.ssh_server.commands)
|
|
|
+
|
|
|
+ def test_key_from_file(self) -> None:
|
|
|
+ """Test authentication using key file."""
|
|
|
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".key", delete=False) as f:
|
|
|
+ f.write(CLIENT_KEY)
|
|
|
+ key_path = f.name
|
|
|
+
|
|
|
+ try:
|
|
|
+ con = self._run_command("test_key_from_file", key_filename=key_path)
|
|
|
+ self.assertIn(b"test_key_from_file", self.ssh_server.commands)
|
|
|
+ self._test_echo(con, b"file_key_test\n")
|
|
|
+ con.close()
|
|
|
+ finally:
|
|
|
+ os.unlink(key_path)
|
|
|
+
|
|
|
+ def test_protocol_versions(self) -> None:
|
|
|
+ """Test protocol version handling."""
|
|
|
+ # Protocol version 2 (default)
|
|
|
+ con = self._run_command(
|
|
|
+ "test_protocol_v2", password=PASSWORD, protocol_version=2
|
|
|
+ )
|
|
|
+ self.assertIn(b"test_protocol_v2", self.ssh_server.commands)
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ # Protocol version 1
|
|
|
+ con = self._run_command(
|
|
|
+ "test_protocol_v1", password=PASSWORD, protocol_version=1
|
|
|
+ )
|
|
|
+ self.assertIn(b"test_protocol_v1", self.ssh_server.commands)
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ def test_vendor_options(self) -> None:
|
|
|
+ """Test vendor initialization options."""
|
|
|
+ # Test with timeout
|
|
|
+ vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False, timeout=1)
|
|
|
+ con = vendor.run_command(
|
|
|
+ "127.0.0.1",
|
|
|
+ "test_timeout",
|
|
|
+ username=USER,
|
|
|
+ port=self.ssh_server.port,
|
|
|
+ password=PASSWORD,
|
|
|
+ )
|
|
|
+ self.assertIn(b"test_timeout", self.ssh_server.commands)
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ def test_can_read(self) -> None:
|
|
|
+ """Test can_read functionality."""
|
|
|
+ con = self._run_command("test_can_read", password=PASSWORD)
|
|
|
+
|
|
|
+ # Check can_read returns bool
|
|
|
+ self.assertIsInstance(con.can_read(), bool)
|
|
|
+
|
|
|
+ # Send data and verify echo
|
|
|
+ self._test_echo(con, b"test_data\n")
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ def test_partial_reads(self) -> None:
|
|
|
+ """Test reading data in small chunks."""
|
|
|
+ con = self._run_command("test_partial", password=PASSWORD)
|
|
|
+
|
|
|
+ test_data = b"0123456789" * 10 # 100 bytes
|
|
|
+ con.write(test_data)
|
|
|
+
|
|
|
+ # Read in 10-byte chunks
|
|
|
+ received_data = b""
|
|
|
+ while len(received_data) < len(test_data):
|
|
|
+ chunk = con.read(10)
|
|
|
+ if not chunk:
|
|
|
+ break
|
|
|
+ received_data += chunk
|
|
|
+
|
|
|
+ self.assertEqual(test_data, received_data)
|
|
|
+ con.close()
|
|
|
+
|
|
|
+ def test_ssh_config_integration(self) -> None:
|
|
|
+ """Test SSH config integration."""
|
|
|
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".config", delete=False) as f:
|
|
|
+ f.write(f"""
|
|
|
+Host testserver
|
|
|
+ HostName 127.0.0.1
|
|
|
+ User {USER}
|
|
|
+ Port {self.ssh_server.port}
|
|
|
+""")
|
|
|
+ config_path = f.name
|
|
|
+
|
|
|
+ try:
|
|
|
+ with patch(
|
|
|
+ "dulwich.contrib.paramiko_vendor.os.path.expanduser"
|
|
|
+ ) as mock_expanduser:
|
|
|
+ mock_expanduser.side_effect = (
|
|
|
+ lambda p: config_path if p == "~/.ssh/config" else p
|
|
|
+ )
|
|
|
+
|
|
|
+ vendor = ParamikoSSHVendor(allow_agent=False, look_for_keys=False)
|
|
|
+ con = vendor.run_command(
|
|
|
+ "testserver", "test_ssh_config", password=PASSWORD
|
|
|
+ )
|
|
|
+
|
|
|
+ self.assertIn(b"test_ssh_config", self.ssh_server.commands)
|
|
|
+ self._test_echo(con, b"config_test\n")
|
|
|
+ con.close()
|
|
|
+ finally:
|
|
|
+ os.unlink(config_path)
|