Explorar o código

Add FilterContext for managing stateful filter resources

Introduces FilterContext class to properly manage filter driver state,
particularly for long-running process filters. This provides better
separation of concerns between filter registration (FilterRegistry)
and runtime state management (FilterContext).
Jelmer Vernooij hai 4 meses
pai
achega
79782761b5

+ 205 - 20
dulwich/filters.py

@@ -51,6 +51,28 @@ class FilterDriver(TypingProtocol):
         """Apply smudge filter (repository → working tree)."""
         ...
 
+    def cleanup(self) -> None:
+        """Clean up any resources held by this filter driver."""
+        ...
+
+    def reuse(self, config: "StackedConfig", filter_name: str) -> bool:
+        """Check if this filter driver should be reused with the given configuration.
+
+        This method determines whether a cached filter driver instance should continue
+        to be used or if it should be recreated. Only filters that are expensive to
+        create (like long-running process filters) and whose configuration hasn't
+        changed should return True. Lightweight filters should return False to ensure
+        they always use the latest configuration.
+
+        Args:
+            config: The current configuration stack
+            filter_name: The name of the filter in config
+
+        Returns:
+            True if the filter should be reused, False if it should be recreated
+        """
+        ...
+
 
 class ProcessFilterDriver:
     """Filter driver that executes external processes."""
@@ -156,7 +178,7 @@ class ProcessFilterDriver:
                         self._capabilities.add(cap[11:])  # Remove "capability=" prefix
 
             except (OSError, subprocess.SubprocessError, HangupException) as e:
-                self._cleanup_process()
+                self.cleanup()
                 raise FilterError(f"Failed to start process filter: {e}")
         return self._process
 
@@ -214,7 +236,7 @@ class ProcessFilterDriver:
 
             except (OSError, subprocess.SubprocessError, ValueError) as e:
                 # Clean up broken process
-                self._cleanup_process()
+                self.cleanup()
                 raise FilterError(f"Process filter failed: {e}")
 
     def clean(self, data: bytes) -> bytes:
@@ -290,7 +312,7 @@ class ProcessFilterDriver:
             logging.warning(f"Optional smudge filter failed: {e}")
             return data
 
-    def _cleanup_process(self):
+    def cleanup(self):
         """Clean up the process filter."""
         if self._process:
             # Close stdin first to signal the process to quit cleanly
@@ -347,9 +369,136 @@ class ProcessFilterDriver:
         self._process = None
         self._protocol = None
 
+    def reuse(self, config: "StackedConfig", filter_name: str) -> bool:
+        """Check if this filter driver should be reused with the given configuration."""
+        # Only reuse if it's a long-running process filter AND config hasn't changed
+        if self.process_cmd is None:
+            # Not a long-running filter, don't cache
+            return False
+
+        # Check if the filter commands in config match our current commands
+        try:
+            clean_cmd = config.get(("filter", filter_name), "clean")
+        except KeyError:
+            clean_cmd = None
+        if clean_cmd != self.clean_cmd:
+            return False
+
+        try:
+            smudge_cmd = config.get(("filter", filter_name), "smudge")
+        except KeyError:
+            smudge_cmd = None
+        if smudge_cmd != self.smudge_cmd:
+            return False
+
+        try:
+            process_cmd = config.get(("filter", filter_name), "process")
+        except KeyError:
+            process_cmd = None
+        if process_cmd != self.process_cmd:
+            return False
+
+        required = config.get_boolean(("filter", filter_name), "required", False)
+        if required != self.required:
+            return False
+
+        return True
+
     def __del__(self):
         """Clean up the process filter on destruction."""
-        self._cleanup_process()
+        self.cleanup()
+
+
+class FilterContext:
+    """Context for managing stateful filter resources.
+
+    This class manages the runtime state for filters, including:
+    - Cached filter driver instances that maintain long-running state
+    - Resource lifecycle management
+
+    It works in conjunction with FilterRegistry to provide complete
+    filter functionality while maintaining proper separation of concerns.
+    """
+
+    def __init__(self, filter_registry: "FilterRegistry") -> None:
+        """Initialize FilterContext.
+
+        Args:
+            filter_registry: The filter registry to use for driver lookups
+        """
+        self.filter_registry = filter_registry
+        self._active_drivers: dict[str, FilterDriver] = {}
+
+    def get_driver(self, name: str) -> Optional[FilterDriver]:
+        """Get a filter driver by name, managing stateful instances.
+
+        This method handles driver instantiation and caching. Only drivers
+        that should be reused are cached.
+
+        Args:
+            name: The filter name
+
+        Returns:
+            FilterDriver instance or None
+        """
+        driver: Optional[FilterDriver] = None
+        # Check if we have a cached instance that should be reused
+        if name in self._active_drivers:
+            driver = self._active_drivers[name]
+            # Check if the cached driver should still be reused
+            if self.filter_registry.config and driver.reuse(
+                self.filter_registry.config, name
+            ):
+                return driver
+            else:
+                # Driver shouldn't be reused, clean it up and remove from cache
+                driver.cleanup()
+                del self._active_drivers[name]
+
+        # Get driver from registry
+        driver = self.filter_registry.get_driver(name)
+        if driver is not None and self.filter_registry.config:
+            # Only cache drivers that should be reused
+            if driver.reuse(self.filter_registry.config, name):
+                self._active_drivers[name] = driver
+
+        return driver
+
+    def close(self) -> None:
+        """Close all active filter resources."""
+        # Clean up active drivers
+        for driver in self._active_drivers.values():
+            driver.cleanup()
+        self._active_drivers.clear()
+
+        # Also close the registry
+        self.filter_registry.close()
+
+    def refresh_config(self, config: "StackedConfig") -> None:
+        """Refresh the configuration used by the filter registry.
+
+        This should be called when the configuration has changed to ensure
+        filters use the latest settings.
+
+        Args:
+            config: The new configuration stack
+        """
+        # Update the registry's config
+        self.filter_registry.config = config
+
+        # Re-setup line ending filter with new config
+        # This will update the text filter factory to use new autocrlf settings
+        self.filter_registry._setup_line_ending_filter()
+
+        # The get_driver method will now handle checking reuse() for cached drivers
+
+    def __del__(self) -> None:
+        """Clean up on destruction."""
+        try:
+            self.close()
+        except Exception:
+            # Don't raise exceptions in __del__
+            pass
 
 
 class FilterRegistry:
@@ -412,8 +561,7 @@ class FilterRegistry:
     def close(self) -> None:
         """Close all filter drivers, ensuring process cleanup."""
         for driver in self._drivers.values():
-            if isinstance(driver, ProcessFilterDriver):
-                driver._cleanup_process()
+            driver.cleanup()
         self._drivers.clear()
 
     def __del__(self) -> None:
@@ -577,18 +725,30 @@ class FilterRegistry:
 def get_filter_for_path(
     path: bytes,
     gitattributes: "GitAttributes",
-    filter_registry: FilterRegistry,
+    filter_registry: Optional[FilterRegistry] = None,
+    filter_context: Optional[FilterContext] = None,
 ) -> Optional[FilterDriver]:
     """Get the appropriate filter driver for a given path.
 
     Args:
         path: Path to check
         gitattributes: GitAttributes object with parsed patterns
-        filter_registry: Registry of filter drivers
+        filter_registry: Registry of filter drivers (deprecated, use filter_context)
+        filter_context: Context for managing filter state
 
     Returns:
         FilterDriver instance or None
     """
+    # Use filter_context if provided, otherwise fall back to registry
+    if filter_context is not None:
+        registry = filter_context.filter_registry
+        get_driver = filter_context.get_driver
+    elif filter_registry is not None:
+        registry = filter_registry
+        get_driver = filter_registry.get_driver
+    else:
+        raise ValueError("Either filter_registry or filter_context must be provided")
+
     # Get all attributes for this path
     attributes = gitattributes.match_path(path)
 
@@ -599,11 +759,11 @@ def get_filter_for_path(
             return None
         if isinstance(filter_name, bytes):
             filter_name_str = filter_name.decode("utf-8")
-            driver = filter_registry.get_driver(filter_name_str)
+            driver = get_driver(filter_name_str)
 
             # Check if filter is required but missing
-            if driver is None and filter_registry.config is not None:
-                required = filter_registry.config.get_boolean(
+            if driver is None and registry.config is not None:
+                required = registry.config.get_boolean(
                     ("filter", filter_name_str), "required", False
                 )
                 if required:
@@ -618,16 +778,16 @@ def get_filter_for_path(
     text_attr = attributes.get(b"text")
     if text_attr is True:
         # Use the text filter for line ending conversion
-        return filter_registry.get_driver("text")
+        return get_driver("text")
     elif text_attr is False:
         # -text means binary, no conversion
         return None
 
     # If no explicit text attribute, check if autocrlf is enabled
     # When autocrlf is true/input, files are treated as text by default
-    if filter_registry.config is not None:
+    if registry.config is not None:
         try:
-            autocrlf_raw = filter_registry.config.get("core", "autocrlf")
+            autocrlf_raw = registry.config.get("core", "autocrlf")
             autocrlf: bytes = (
                 autocrlf_raw.lower()
                 if isinstance(autocrlf_raw, bytes)
@@ -635,7 +795,7 @@ def get_filter_for_path(
             )
             if autocrlf in (b"true", b"input"):
                 # Use text filter for files without explicit attributes
-                return filter_registry.get_driver("text")
+                return get_driver("text")
         except KeyError:
             pass
 
@@ -654,24 +814,47 @@ class FilterBlobNormalizer:
         gitattributes: GitAttributes,
         filter_registry: Optional[FilterRegistry] = None,
         repo: Optional["BaseRepo"] = None,
+        filter_context: Optional[FilterContext] = None,
     ) -> None:
         """Initialize FilterBlobNormalizer.
 
         Args:
           config_stack: Git configuration stack
           gitattributes: GitAttributes instance
-          filter_registry: Optional filter registry to use
+          filter_registry: Optional filter registry to use (deprecated, use filter_context)
           repo: Optional repository instance
+          filter_context: Optional filter context to use for managing filter state
         """
         self.config_stack = config_stack
         self.gitattributes = gitattributes
-        self.filter_registry = filter_registry or FilterRegistry(config_stack, repo)
+        self._owns_context = False  # Track if we created our own context
+
+        # Support both old and new API
+        if filter_context is not None:
+            self.filter_context = filter_context
+            self.filter_registry = filter_context.filter_registry
+            self._owns_context = False  # We're using an external context
+        else:
+            if filter_registry is not None:
+                import warnings
+
+                warnings.warn(
+                    "Passing filter_registry to FilterBlobNormalizer is deprecated. "
+                    "Pass a FilterContext instead.",
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+                self.filter_registry = filter_registry
+            else:
+                self.filter_registry = FilterRegistry(config_stack, repo)
+            self.filter_context = FilterContext(self.filter_registry)
+            self._owns_context = True  # We created our own context
 
     def checkin_normalize(self, blob: Blob, path: bytes) -> Blob:
         """Apply clean filter during checkin (working tree -> repository)."""
         # Get filter for this path
         filter_driver = get_filter_for_path(
-            path, self.gitattributes, self.filter_registry
+            path, self.gitattributes, filter_context=self.filter_context
         )
         if filter_driver is None:
             return blob
@@ -690,7 +873,7 @@ class FilterBlobNormalizer:
         """Apply smudge filter during checkout (repository -> working tree)."""
         # Get filter for this path
         filter_driver = get_filter_for_path(
-            path, self.gitattributes, self.filter_registry
+            path, self.gitattributes, filter_context=self.filter_context
         )
         if filter_driver is None:
             return blob
@@ -707,7 +890,9 @@ class FilterBlobNormalizer:
 
     def close(self) -> None:
         """Close all filter drivers, ensuring process cleanup."""
-        self.filter_registry.close()
+        # Only close the filter context if we created it ourselves
+        if self._owns_context:
+            self.filter_context.close()
 
     def __del__(self) -> None:
         """Clean up filter drivers on destruction."""

+ 9 - 0
dulwich/lfs.py

@@ -324,6 +324,15 @@ class LFSFilterDriver:
 
         return content
 
+    def cleanup(self) -> None:
+        """Clean up any resources held by this filter driver."""
+        # LFSFilterDriver doesn't hold any resources that need cleanup
+
+    def reuse(self, config, filter_name: str) -> bool:
+        """Check if this filter driver should be reused with the given configuration."""
+        # LFSFilterDriver is stateless and lightweight, no need to cache
+        return False
+
 
 def _get_lfs_user_agent(config: Optional["Config"]) -> str:
     """Get User-Agent string for LFS requests, respecting git config."""

+ 10 - 0
dulwich/line_ending.py

@@ -190,6 +190,16 @@ class LineEndingFilter(FilterDriver):
 
         return self.smudge_conversion(data)
 
+    def cleanup(self) -> None:
+        """Clean up any resources held by this filter driver."""
+        # LineEndingFilter doesn't hold any resources that need cleanup
+
+    def reuse(self, config, filter_name: str) -> bool:
+        """Check if this filter driver should be reused with the given configuration."""
+        # LineEndingFilter is lightweight and should always be recreated
+        # to ensure it uses the latest configuration
+        return False
+
 
 def convert_crlf_to_lf(text_hunk: bytes) -> bytes:
     """Convert CRLF in text hunk into LF.

+ 45 - 20
dulwich/repo.py

@@ -1301,6 +1301,9 @@ class Repo(BaseRepo):
         self.hooks["post-commit"] = PostCommitShellHook(self.controldir())
         self.hooks["post-receive"] = PostReceiveShellHook(self.controldir())
 
+        # Initialize filter context as None, will be created lazily
+        self.filter_context = None
+
     def get_worktree(self) -> "WorkTree":
         """Get the working tree for this repository.
 
@@ -1969,10 +1972,10 @@ class Repo(BaseRepo):
     def close(self) -> None:
         """Close any files opened by this repository."""
         self.object_store.close()
-        # Clean up cached blob normalizer
-        if hasattr(self, '_blob_normalizer'):
-            self._blob_normalizer.close()
-            del self._blob_normalizer
+        # Clean up filter context if it was created
+        if self.filter_context is not None:
+            self.filter_context.close()
+            self.filter_context = None
 
     def __enter__(self):
         """Enter context manager."""
@@ -2023,17 +2026,24 @@ class Repo(BaseRepo):
 
     def get_blob_normalizer(self):
         """Return a BlobNormalizer object."""
-        from .filters import FilterBlobNormalizer, FilterRegistry
+        from .filters import FilterBlobNormalizer, FilterContext, FilterRegistry
+
+        # Get fresh configuration and GitAttributes
+        config_stack = self.get_config_stack()
+        git_attributes = self.get_gitattributes()
 
-        # Cache FilterBlobNormalizer per repository to maintain registered drivers
-        if not hasattr(self, '_blob_normalizer'):
-            # Get proper GitAttributes object
-            git_attributes = self.get_gitattributes()
-            config_stack = self.get_config_stack()
+        # Lazily create FilterContext if needed
+        if self.filter_context is None:
             filter_registry = FilterRegistry(config_stack, self)
-            self._blob_normalizer = FilterBlobNormalizer(config_stack, git_attributes, filter_registry, self)
+            self.filter_context = FilterContext(filter_registry)
+        else:
+            # Refresh the context with current config to handle config changes
+            self.filter_context.refresh_config(config_stack)
 
-        return self._blob_normalizer
+        # Return a new FilterBlobNormalizer with the context
+        return FilterBlobNormalizer(
+            config_stack, git_attributes, filter_context=self.filter_context
+        )
 
     def get_gitattributes(self, tree: Optional[bytes] = None) -> "GitAttributes":
         """Read gitattributes for the repository.
@@ -2166,6 +2176,7 @@ class MemoryRepo(BaseRepo):
         self.bare = True
         self._config = ConfigFile()
         self._description = None
+        self.filter_context = None
 
     def _append_reflog(self, *args) -> None:
         self._reflog.append(args)
@@ -2258,17 +2269,24 @@ class MemoryRepo(BaseRepo):
 
     def get_blob_normalizer(self):
         """Return a BlobNormalizer object for checkin/checkout operations."""
-        from .filters import FilterBlobNormalizer, FilterRegistry
+        from .filters import FilterBlobNormalizer, FilterContext, FilterRegistry
+
+        # Get fresh configuration and GitAttributes
+        config_stack = self.get_config_stack()
+        git_attributes = self.get_gitattributes()
 
-        # Cache FilterBlobNormalizer per repository to maintain registered drivers
-        if not hasattr(self, '_blob_normalizer'):
-            # Get GitAttributes object
-            git_attributes = self.get_gitattributes()
-            config_stack = self.get_config_stack()
+        # Lazily create FilterContext if needed
+        if self.filter_context is None:
             filter_registry = FilterRegistry(config_stack, self)
-            self._blob_normalizer = FilterBlobNormalizer(config_stack, git_attributes, filter_registry, self)
+            self.filter_context = FilterContext(filter_registry)
+        else:
+            # Refresh the context with current config to handle config changes
+            self.filter_context.refresh_config(config_stack)
 
-        return self._blob_normalizer
+        # Return a new FilterBlobNormalizer with the context
+        return FilterBlobNormalizer(
+            config_stack, git_attributes, filter_context=self.filter_context
+        )
 
     def get_gitattributes(self, tree: Optional[bytes] = None) -> "GitAttributes":
         """Read gitattributes for the repository."""
@@ -2278,6 +2296,13 @@ class MemoryRepo(BaseRepo):
         # Return empty GitAttributes
         return GitAttributes([])
 
+    def close(self) -> None:
+        """Close any resources opened by this repository."""
+        # Clean up filter context if it was created
+        if self.filter_context is not None:
+            self.filter_context.close()
+            self.filter_context = None
+
     def do_commit(
         self,
         message: Optional[bytes] = None,

+ 3 - 1
dulwich/sparse_patterns.py

@@ -167,7 +167,9 @@ def apply_included_paths(
         # Create a temporary blob for normalization
         temp_blob = Blob()
         temp_blob.data = disk_data
-        norm_blob = normalizer.checkin_normalize(temp_blob, os.path.relpath(full_path, repo.path).encode())
+        norm_blob = normalizer.checkin_normalize(
+            temp_blob, os.path.relpath(full_path, repo.path).encode()
+        )
         norm_data = norm_blob.data
         if not isinstance(blob_obj, Blob):
             return True

+ 181 - 31
tests/test_filters.py

@@ -27,7 +27,12 @@ import threading
 import unittest
 
 from dulwich import porcelain
-from dulwich.filters import FilterError, ProcessFilterDriver
+from dulwich.filters import (
+    FilterContext,
+    FilterError,
+    FilterRegistry,
+    ProcessFilterDriver,
+)
 from dulwich.repo import Repo
 
 from . import TestCase
@@ -536,43 +541,188 @@ while True:
 
         self.assertIn("Failed to start process filter", str(cm.exception))
 
-    def test_thread_safety_with_process_filter(self):
-        """Test thread safety with actual process filter."""
+
+class FilterContextTests(TestCase):
+    """Tests for FilterContext class."""
+
+    def test_filter_context_caches_long_running_drivers(self):
+        """Test that FilterContext caches only long-running drivers."""
+
+        # Create real filter drivers
+        class UppercaseFilter:
+            def clean(self, data):
+                return data.upper()
+
+            def smudge(self, data, path=b""):
+                return data.lower()
+
+            def cleanup(self):
+                pass
+
+            def reuse(self, config, filter_name):
+                # Pretend it's a long-running filter that should be cached
+                return True
+
+        class IdentityFilter:
+            def clean(self, data):
+                return data
+
+            def smudge(self, data, path=b""):
+                return data
+
+            def cleanup(self):
+                pass
+
+            def reuse(self, config, filter_name):
+                # Lightweight filter, don't cache
+                return False
+
+        # Create registry and context
+        registry = FilterRegistry()
+        context = FilterContext(registry)
+
+        # Register drivers
+        long_running = UppercaseFilter()
+        stateless = IdentityFilter()
+        registry.register_driver("uppercase", long_running)
+        registry.register_driver("identity", stateless)
+
+        # Get drivers through context
+        driver1 = context.get_driver("uppercase")
+        driver2 = context.get_driver("uppercase")
+
+        # Long-running driver should be cached
+        self.assertIs(driver1, driver2)
+        self.assertIs(driver1, long_running)
+
+        # Get stateless driver
+        stateless1 = context.get_driver("identity")
+        stateless2 = context.get_driver("identity")
+
+        # Stateless driver comes from registry but isn't cached in context
+        self.assertIs(stateless1, stateless)
+        self.assertIs(stateless2, stateless)
+        self.assertNotIn("identity", context._active_drivers)
+        self.assertIn("uppercase", context._active_drivers)
+
+    def test_filter_context_cleanup(self):
+        """Test that FilterContext properly cleans up resources."""
+        cleanup_called = []
+
+        class TrackableFilter:
+            def __init__(self, name):
+                self.name = name
+
+            def clean(self, data):
+                return data
+
+            def smudge(self, data, path=b""):
+                return data
+
+            def cleanup(self):
+                cleanup_called.append(self.name)
+
+            def is_long_running(self):
+                return True
+
+        # Create registry and context
+        registry = FilterRegistry()
+        context = FilterContext(registry)
+
+        # Register and use drivers
+        filter1 = TrackableFilter("filter1")
+        filter2 = TrackableFilter("filter2")
+        filter3 = TrackableFilter("filter3")
+        registry.register_driver("filter1", filter1)
+        registry.register_driver("filter2", filter2)
+        registry.register_driver("filter3", filter3)
+
+        # Get only some drivers to cache them
+        context.get_driver("filter1")
+        context.get_driver("filter2")
+        # Don't get filter3
+
+        # Close context
+        context.close()
+
+        # Verify cleanup was called for all drivers (context closes registry too)
+        self.assertEqual(set(cleanup_called), {"filter1", "filter2", "filter3"})
+
+    def test_filter_context_get_driver_returns_none_for_missing(self):
+        """Test that get_driver returns None for non-existent drivers."""
+        registry = FilterRegistry()
+        context = FilterContext(registry)
+
+        result = context.get_driver("nonexistent")
+        self.assertIsNone(result)
+
+    def test_filter_context_with_real_process_filter(self):
+        """Test FilterContext with real ProcessFilterDriver instances."""
         import sys
 
-        driver = ProcessFilterDriver(
-            process_cmd=f"{sys.executable} {self.test_filter_path}", required=False
+        # Use existing test filter from ProcessFilterDriverTests
+        test_dir = tempfile.mkdtemp()
+        self.addCleanup(lambda: __import__("shutil").rmtree(test_dir))
+
+        # Create a simple test filter that just passes data through
+        filter_script = """import sys
+while True:
+    line = sys.stdin.buffer.read()
+    if not line:
+        break
+    sys.stdout.buffer.write(line)
+    sys.stdout.buffer.flush()
+"""
+        filter_path = os.path.join(test_dir, "simple_filter.py")
+        with open(filter_path, "w") as f:
+            f.write(filter_script)
+
+        # Create ProcessFilterDriver instances
+        # One with process_cmd (long-running)
+        process_driver = ProcessFilterDriver(
+            process_cmd=None,  # Don't use actual process to avoid complexity
+            clean_cmd=f"{sys.executable} {filter_path}",
+            smudge_cmd=f"{sys.executable} {filter_path}",
         )
 
-        results = []
-        errors = []
+        # Register in context
+        registry = FilterRegistry()
+        context = FilterContext(registry)
+        registry.register_driver("process", process_driver)
 
-        def worker(data):
-            try:
-                result = driver.clean(data)
-                results.append(result)
-            except Exception as e:
-                errors.append(e)
+        # Get driver - should not be cached since it's not long-running
+        driver1 = context.get_driver("process")
+        self.assertIsNotNone(driver1)
+        self.assertFalse(driver1.is_long_running())
+        self.assertNotIn("process", context._active_drivers)
 
-        # Start multiple threads
-        threads = []
-        for i in range(3):
-            data = f"test{i}".encode()
-            t = threading.Thread(target=worker, args=(data,))
-            threads.append(t)
-            t.start()
+        # Test with a long-running driver (has process_cmd)
+        long_process_driver = ProcessFilterDriver()
+        long_process_driver.process_cmd = "dummy"  # Just to make it long-running
+        registry.register_driver("long_process", long_process_driver)
 
-        # Wait for all threads
-        for t in threads:
-            t.join()
+        driver2 = context.get_driver("long_process")
+        self.assertTrue(driver2.is_long_running())
+        self.assertIn("long_process", context._active_drivers)
 
-        # Should have no errors and correct results
-        self.assertEqual(len(errors), 0, f"Errors: {errors}")
-        self.assertEqual(len(results), 3)
+        context.close()
 
-        # Check results are correct (uppercased)
-        expected = [b"TEST0", b"TEST1", b"TEST2"]
-        self.assertEqual(sorted(results), sorted(expected))
+    def test_filter_context_closes_registry(self):
+        """Test that closing FilterContext also closes the registry."""
+        # Track if registry.close() is called
+        registry_closed = []
+
+        class TrackingRegistry(FilterRegistry):
+            def close(self):
+                registry_closed.append(True)
+                super().close()
+
+        registry = TrackingRegistry()
+        context = FilterContext(registry)
+
+        # Close context should also close registry
+        context.close()
+        self.assertTrue(registry_closed)
 
 
 class ProcessFilterProtocolTests(TestCase):
@@ -824,7 +974,7 @@ while True:
         if driver._process:
             driver._process.kill()
             driver._process.wait()
-        driver._cleanup_process()
+        driver.cleanup()
 
         # Should restart and work again
         result = driver.clean(b"test2")
@@ -934,7 +1084,7 @@ protocol.write_pkt_line(None)
         old_process = driver._process
 
         # Manually clean up (simulates __del__)
-        driver._cleanup_process()
+        driver.cleanup()
 
         # Process reference should be cleared
         self.assertIsNone(driver._process)

+ 6 - 0
tests/test_line_ending.py

@@ -519,6 +519,12 @@ class LineEndingIntegrationTests(TestCase):
             def smudge(self, data):
                 return b"LFS content"
 
+            def cleanup(self):
+                pass
+
+            def reuse(self, config, filter_name):
+                return False
+
         self.registry.register_driver("lfs", MockLFSFilter())
 
         # Different files use different filters

+ 19 - 8
tests/test_sparse_patterns.py

@@ -27,7 +27,6 @@ import shutil
 import tempfile
 import time
 
-from dulwich.filters import FilterBlobNormalizer, FilterRegistry
 from dulwich.index import IndexEntry
 from dulwich.objects import Blob
 from dulwich.repo import Repo
@@ -544,7 +543,7 @@ class ApplyIncludedPathsTests(TestCase):
 
     def test_checkout_normalization_applied(self):
         """Test that checkout normalization is applied when materializing files during sparse checkout."""
-        
+
         # Create a simple filter that converts content to uppercase
         class UppercaseFilter:
             def smudge(self, input_bytes, path=b""):
@@ -553,18 +552,29 @@ class ApplyIncludedPathsTests(TestCase):
             def clean(self, input_bytes):
                 return input_bytes.lower()
 
+            def cleanup(self):
+                pass
+
+            def reuse(self, config, filter_name):
+                return False
+
         # Create .gitattributes file
         gitattributes_path = os.path.join(self.temp_dir, ".gitattributes")
         with open(gitattributes_path, "w") as f:
             f.write("*.txt filter=uppercase\n")
-        
+
         # Add and commit .gitattributes
-        self.repo.stage([b".gitattributes"])
+        self.repo.get_worktree().stage([b".gitattributes"])
         self.repo.do_commit(b"Add gitattributes", committer=b"Test <test@example.com>")
 
-        # Register the filter with the repo's cached filter registry
-        normalizer = self.repo.get_blob_normalizer()
-        normalizer.filter_registry.register_driver("uppercase", UppercaseFilter())
+        # Initialize the filter context and register the filter
+        _ = self.repo.get_blob_normalizer()
+
+        # Register the filter with the cached filter context
+        uppercase_filter = UppercaseFilter()
+        self.repo.filter_context.filter_registry.register_driver(
+            "uppercase", uppercase_filter
+        )
 
         # Commit a file with lowercase content
         self._commit_blob("test.txt", b"hello world")
@@ -572,7 +582,8 @@ class ApplyIncludedPathsTests(TestCase):
         # Remove the file from working tree to force materialization
         os.remove(os.path.join(self.temp_dir, "test.txt"))
 
-        # Apply sparse checkout
+        # Apply sparse checkout - this will call get_blob_normalizer() internally
+        # which will use the cached filter_context with our registered filter
         apply_included_paths(self.repo, included_paths={"test.txt"}, force=False)
 
         # Verify file was materialized with uppercase content (checkout normalization applied)