@@ -23,10 +23,11 @@
 
 import os
 import tempfile
+import threading
 import unittest
 
 from dulwich import porcelain
-from dulwich.filters import FilterError
+from dulwich.filters import FilterError, ProcessFilterDriver
 from dulwich.repo import Repo
 
 from . import TestCase
@@ -317,3 +318,641 @@ class GitAttributesFilterIntegrationTests(TestCase):
         entry = index[b"test.txt"]
         blob = self.repo.object_store[entry.sha]
         self.assertEqual(blob.data, b"test content\n")
+
+
+class ProcessFilterDriverTests(TestCase):
+    """Tests for ProcessFilterDriver with real process filter."""
+
+    def setUp(self):
+        super().setUp()
+        # Create a temporary test filter process dynamically
+        self.test_filter_path = self._create_test_filter()
+
+    def tearDown(self):
+        # Clean up the test filter
+        if hasattr(self, "test_filter_path") and os.path.exists(self.test_filter_path):
+            os.unlink(self.test_filter_path)
+        super().tearDown()
+
+    def _create_test_filter(self):
+        """Create a simple test filter process that works on all platforms."""
+        import tempfile
+
+        # Create filter script that uppercases on clean, lowercases on smudge
+        filter_script = """import sys
+import os
+
+# Simple filter that doesn't use any external dependencies
+def read_exact(n):
+    data = b""
+    while len(data) < n:
+        chunk = sys.stdin.buffer.read(n - len(data))
+        if not chunk:
+            break
+        data += chunk
+    return data
+
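+# pkt-line framing: a 4-byte hex length prefix (which counts itself) followed by the payload; "0000" is a flush-pkt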
+def write_pkt(data):
+    if data is None:
+        sys.stdout.buffer.write(b"0000")
+    else:
+        length = len(data) + 4
+        sys.stdout.buffer.write(("{:04x}".format(length)).encode())
+        sys.stdout.buffer.write(data)
+    sys.stdout.buffer.flush()
+
+def read_pkt():
+    size_bytes = read_exact(4)
+    if not size_bytes:
+        return None
+    size = int(size_bytes.decode(), 16)
+    if size == 0:
+        return None
+    return read_exact(size - 4)
+
+# Handshake
+client_hello = read_pkt()
+version = read_pkt()
+flush = read_pkt()
+
+write_pkt(b"git-filter-server")
+write_pkt(b"version=2")
+write_pkt(None)
+
+# Read and echo capabilities
+caps = []
+while True:
+    cap = read_pkt()
+    if cap is None:
+        break
+    caps.append(cap)
+
+for cap in caps:
+    write_pkt(cap)
+write_pkt(None)
+
+# Process commands
+while True:
+    headers = {}
+    while True:
+        line = read_pkt()
+        if line is None:
+            break
+        if b"=" in line:
+            k, v = line.split(b"=", 1)
+            headers[k.decode()] = v.decode()
+
+    if not headers:
+        break
+
+    # Read data
+    data_chunks = []
+    while True:
+        chunk = read_pkt()
+        if chunk is None:
+            break
+        data_chunks.append(chunk)
+
+    data = b"".join(data_chunks)
+
+    # Process (uppercase for clean, lowercase for smudge)
+    if headers.get("command") == "clean":
+        result = data.upper()
+    elif headers.get("command") == "smudge":
+        result = data.lower()
+    else:
+        result = data
+
+    # Send response
+    write_pkt(b"status=success")
+    write_pkt(None)
+
+    # Send result
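+    # 65516 is the maximum pkt-line payload: the 65520-byte packet limit minus the 4-byte length prefix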
+    chunk_size = 65516
+    for i in range(0, len(result), chunk_size):
+        write_pkt(result[i:i+chunk_size])
+    write_pkt(None)
+"""
+
+        # Create temporary file
+        fd, path = tempfile.mkstemp(suffix=".py", prefix="test_filter_")
+        try:
+            os.write(fd, filter_script.encode())
+            os.close(fd)
+
+            # Make executable on Unix-like systems
+            if os.name != "nt":  # Not Windows
+                os.chmod(path, 0o755)
+
+            return path
+        except BaseException:
+            if os.path.exists(path):
+                os.unlink(path)
+            raise
+
+    def test_process_filter_clean_operation(self):
+        """Test clean operation using real process filter."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        test_data = b"hello world"
+        result = driver.clean(test_data)
+
+        # Our test filter uppercases on clean
+        self.assertEqual(result, b"HELLO WORLD")
+
+    def test_process_filter_smudge_operation(self):
+        """Test smudge operation using real process filter."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        test_data = b"HELLO WORLD"
+        result = driver.smudge(test_data, b"test.txt")
+
+        # Our test filter lowercases on smudge
+        self.assertEqual(result, b"hello world")
+
+    def test_process_filter_large_data(self):
+        """Test process filter with data larger than single pkt-line."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        # Create data larger than max pkt-line payload (65516 bytes)
+        test_data = b"a" * 70000
+        result = driver.clean(test_data)
+
+        # Should be uppercased
+        self.assertEqual(result, b"A" * 70000)
+
+    def test_fallback_to_individual_commands(self):
+        """Test fallback when process filter fails."""
+        driver = ProcessFilterDriver(
+            clean_cmd="tr '[:lower:]' '[:upper:]'",  # Shell command to uppercase
+            process_cmd="/nonexistent/command",  # This should fail
+            required=False,
+        )
+
+        test_data = b"hello world\n"
+        result = driver.clean(test_data)
+
+        # Should fallback to tr command and uppercase
+        self.assertEqual(result, b"HELLO WORLD\n")
+
+    def test_process_reuse(self):
+        """Test that process is reused across multiple operations."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        # First operation
+        result1 = driver.clean(b"test1")
+        self.assertEqual(result1, b"TEST1")
+
+        # Second operation should reuse the same process
+        result2 = driver.clean(b"test2")
+        self.assertEqual(result2, b"TEST2")
+
+        # Process should still be alive
+        self.assertIsNotNone(driver._process)
+        self.assertIsNone(driver._process.poll())  # None means still running
+
+    def test_error_handling_invalid_command(self):
+        """Test error handling with invalid filter command."""
+        driver = ProcessFilterDriver(process_cmd="/nonexistent/command", required=True)
+
+        with self.assertRaises(FilterError) as cm:
+            driver.clean(b"test data")
+
+        self.assertIn("Failed to start process filter", str(cm.exception))
+
+    def test_thread_safety_with_process_filter(self):
+        """Test thread safety with actual process filter."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        results = []
+        errors = []
+
+        def worker(data):
+            try:
+                result = driver.clean(data)
+                results.append(result)
+            except Exception as e:
+                errors.append(e)
+
+        # Start multiple threads
+        threads = []
+        for i in range(3):
+            data = f"test{i}".encode()
+            t = threading.Thread(target=worker, args=(data,))
+            threads.append(t)
+            t.start()
+
+        # Wait for all threads
+        for t in threads:
+            t.join()
+
+        # Should have no errors and correct results
+        self.assertEqual(len(errors), 0, f"Errors: {errors}")
+        self.assertEqual(len(results), 3)
+
+        # Check results are correct (uppercased)
+        expected = [b"TEST0", b"TEST1", b"TEST2"]
+        self.assertEqual(sorted(results), sorted(expected))
+
+
+class ProcessFilterProtocolTests(TestCase):
+    """Tests for ProcessFilterDriver protocol compliance."""
+
+    def setUp(self):
+        super().setUp()
+        # Create a spec-compliant test filter process dynamically
+        self.test_filter_path = self._create_spec_compliant_filter()
+
+    def tearDown(self):
+        # Clean up the test filter
+        if hasattr(self, "test_filter_path") and os.path.exists(self.test_filter_path):
+            os.unlink(self.test_filter_path)
+        super().tearDown()
+
+    def _create_spec_compliant_filter(self):
+        """Create a spec-compliant test filter that works on all platforms."""
+        import tempfile
+
+        # This filter strictly follows Git spec - no newlines in packets
+        filter_script = """import sys
+
+def read_exact(n):
+    data = b""
+    while len(data) < n:
+        chunk = sys.stdin.buffer.read(n - len(data))
+        if not chunk:
+            break
+        data += chunk
+    return data
+
+def write_pkt(data):
+    if data is None:
+        sys.stdout.buffer.write(b"0000")
+    else:
+        length = len(data) + 4
+        sys.stdout.buffer.write(("{:04x}".format(length)).encode())
+        sys.stdout.buffer.write(data)
+    sys.stdout.buffer.flush()
+
+def read_pkt():
+    size_bytes = read_exact(4)
+    if not size_bytes:
+        return None
+    size = int(size_bytes.decode(), 16)
+    if size == 0:
+        return None
+    return read_exact(size - 4)
+
+# Handshake - exact format, no newlines
+client_hello = read_pkt()
+version = read_pkt()
+flush = read_pkt()
+
+if client_hello != b"git-filter-client":
+    sys.exit(1)
+if version != b"version=2":
+    sys.exit(1)
+
+write_pkt(b"git-filter-server")  # No newline
+write_pkt(b"version=2")  # No newline
+write_pkt(None)
+
+# Read and echo capabilities
+caps = []
+while True:
+    cap = read_pkt()
+    if cap is None:
+        break
+    caps.append(cap)
+
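+# Advertise only the capabilities this test filter actually supports (clean and smudge)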
+for cap in caps:
+    if cap in [b"capability=clean", b"capability=smudge"]:
+        write_pkt(cap)
+write_pkt(None)
+
+# Process commands
+while True:
+    headers = {}
+    while True:
+        line = read_pkt()
+        if line is None:
+            break
+        if b"=" in line:
+            k, v = line.split(b"=", 1)
+            headers[k.decode()] = v.decode()
+
+    if not headers:
+        break
+
+    # Read data
+    data_chunks = []
+    while True:
+        chunk = read_pkt()
+        if chunk is None:
+            break
+        data_chunks.append(chunk)
+
+    data = b"".join(data_chunks)
+
+    # Process
+    if headers.get("command") == "clean":
+        result = data.upper()
+    elif headers.get("command") == "smudge":
+        result = data.lower()
+    else:
+        result = data
+
+    # Send response
+    write_pkt(b"status=success")
+    write_pkt(None)
+
+    # Send result
+    chunk_size = 65516
+    for i in range(0, len(result), chunk_size):
+        write_pkt(result[i:i+chunk_size])
+    write_pkt(None)
+"""
+
+        fd, path = tempfile.mkstemp(suffix=".py", prefix="test_filter_spec_")
+        try:
+            os.write(fd, filter_script.encode())
+            os.close(fd)
+
+            if os.name != "nt":  # Not Windows
+                os.chmod(path, 0o755)
+
+            return path
+        except BaseException:
+            if os.path.exists(path):
+                os.unlink(path)
+            raise
+
+    def test_protocol_handshake_exact_format(self):
+        """Test that handshake uses exact format without newlines."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}",
+            required=True,  # Require success to test protocol compliance
+        )
+
+        # This should work with exact protocol format
+        test_data = b"hello world"
+        result = driver.clean(test_data)
+
+        # Our test filter uppercases on clean
+        self.assertEqual(result, b"HELLO WORLD")
+
+    def test_capability_negotiation_exact_format(self):
+        """Test that capabilities are sent and received in exact format."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=True
+        )
+
+        # Force capability negotiation by using both clean and smudge
+        clean_result = driver.clean(b"test")
+        smudge_result = driver.smudge(b"TEST", b"test.txt")
+
+        self.assertEqual(clean_result, b"TEST")
+        self.assertEqual(smudge_result, b"test")
+
+    def test_binary_data_handling(self):
+        """Test handling of binary data through the protocol."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        # Binary data with null bytes, high bytes, etc.
+        binary_data = bytes(range(256))
+
+        try:
+            result = driver.clean(binary_data)
+            # Should handle binary data without crashing
+            self.assertIsInstance(result, bytes)
+            # Our test filter uppercases, which may not work for all binary data
+            # but should not crash
+        except UnicodeDecodeError:
+            # This might happen with binary data - acceptable
+            pass
+
+    def test_large_file_chunking(self):
+        """Test proper chunking of large files."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=True
+        )
+
+        # Create data larger than max pkt-line payload (65516 bytes)
+        large_data = b"a" * 100000
+        result = driver.clean(large_data)
+
+        # Should be properly processed (uppercased)
+        expected = b"A" * 100000
+        self.assertEqual(result, expected)
+
+    def test_empty_file_handling(self):
+        """Test handling of empty files."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=True
+        )
+
+        result = driver.clean(b"")
+        self.assertEqual(result, b"")
+
+    def test_special_characters_in_pathname(self):
+        """Test paths with special characters are handled correctly."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=True
+        )
+
+        # Test various special characters in paths
+        special_paths = [
+            b"file with spaces.txt",
+            b"path/with/slashes.txt",
+            b"file=with=equals.txt",
+            b"file\nwith\nnewlines.txt",
+        ]
+
+        test_data = b"test data"
+
+        for path in special_paths:
+            result = driver.smudge(test_data, path)
+            self.assertEqual(result, b"test data")
+
+    def test_process_crash_recovery(self):
+        """Test that process is properly restarted after crash."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        # First operation
+        result = driver.clean(b"test1")
+        self.assertEqual(result, b"TEST1")
+
+        # Kill the process
+        if driver._process:
+            driver._process.kill()
+            driver._process.wait()
+            driver._cleanup_process()
+
+        # Should restart and work again
+        result = driver.clean(b"test2")
+        self.assertEqual(result, b"TEST2")
+
+    def test_malformed_process_response_handling(self):
+        """Test handling of malformed responses from process."""
+        # Create a filter that sends malformed responses
+        malformed_filter = """#!/usr/bin/env python3
+import sys
+import os
+sys.path.insert(0, os.path.dirname(__file__))
+from dulwich.protocol import Protocol
+
+protocol = Protocol(
+    lambda n: sys.stdin.buffer.read(n),
+    lambda d: sys.stdout.buffer.write(d) or len(d)
+)
+
+# Read handshake
+protocol.read_pkt_line()
+protocol.read_pkt_line()
+protocol.read_pkt_line()
+
+# Send invalid handshake
+protocol.write_pkt_line(b"invalid-welcome")
+protocol.write_pkt_line(b"version=2")
+protocol.write_pkt_line(None)
+"""
+
+        import tempfile
+
+        fd, script_path = tempfile.mkstemp(suffix=".py")
+        try:
+            os.write(fd, malformed_filter.encode())
+            os.close(fd)
+            os.chmod(script_path, 0o755)
+
+            driver = ProcessFilterDriver(
+                process_cmd=f"python3 {script_path}",
+                clean_cmd="cat",  # Fallback
+                required=False,
+            )
+
+            # Should fallback to clean_cmd when process fails
+            result = driver.clean(b"test data")
+            self.assertEqual(result, b"test data")
+
+        finally:
+            os.unlink(script_path)
+
+    def test_concurrent_filter_operations(self):
+        """Test that concurrent operations work correctly."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=True
+        )
+
+        results = []
+        errors = []
+
+        def worker(data):
+            try:
+                result = driver.clean(data)
+                results.append(result)
+            except Exception as e:
+                errors.append(e)
+
+        # Start 5 concurrent operations
+        threads = []
+        test_data = [f"test{i}".encode() for i in range(5)]
+
+        for data in test_data:
+            t = threading.Thread(target=worker, args=(data,))
+            threads.append(t)
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        # Should have no errors
+        self.assertEqual(len(errors), 0, f"Errors: {errors}")
+        self.assertEqual(len(results), 5)
+
+        # All results should be uppercase versions
+        expected = [data.upper() for data in test_data]
+        self.assertEqual(sorted(results), sorted(expected))
+
+    def test_process_resource_cleanup(self):
+        """Test that process resources are properly cleaned up."""
+        import sys
+
+        driver = ProcessFilterDriver(
+            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        )
+
+        # Use the driver
+        result = driver.clean(b"test")
+        self.assertEqual(result, b"TEST")
+
+        # Process should be running
+        self.assertIsNotNone(driver._process)
+        self.assertIsNone(driver._process.poll())  # None means still running
+
+        # Remember the old process to check it was terminated
+        old_process = driver._process
+
+        # Manually clean up (simulates __del__)
+        driver._cleanup_process()
+
+        # Process reference should be cleared
+        self.assertIsNone(driver._process)
+        self.assertIsNone(driver._protocol)
+
+        # Old process should be terminated
+        self.assertIsNotNone(old_process.poll())  # Not None means terminated
+
+    def test_required_filter_error_propagation(self):
+        """Test that errors are properly propagated when filter is required."""
+        driver = ProcessFilterDriver(
+            process_cmd="/definitely/nonexistent/command", required=True
+        )
+
+        with self.assertRaises(FilterError) as cm:
+            driver.clean(b"test data")
+
+        self.assertIn("Failed to start process filter", str(cm.exception))