
Add sparse:oid filter integration with sparse checkout patterns

Jelmer Vernooij 3 weeks ago
Parent
Commit
270f08258d
3 changed files with 203 additions and 37 deletions
  1. dulwich/partial_clone.py (+86, -18)
  2. dulwich/server.py (+2, -2)
  3. tests/test_partial_clone.py (+115, -17)

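As a rough sketch of what this commit enables (assuming the names introduced in the diff below: parse_filter_spec, SparseOidFilter.should_include_path, plus dulwich's MemoryObjectStore and Blob), a sparse:oid filter spec can now be parsed together with an object store so the pattern blob is resolvable later:

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob
    from dulwich.partial_clone import parse_filter_spec

    # Store a blob whose content is a sparse-checkout pattern list.
    patterns_blob = Blob.from_string(b"*.txt\n")
    store = MemoryObjectStore()
    store.add_object(patterns_blob)

    # Passing the object store lets SparseOidFilter load the blob lazily.
    spec = "sparse:oid=" + patterns_blob.id.decode("ascii")
    filter_spec = parse_filter_spec(spec, object_store=store)

    print(filter_spec.should_include_path("readme.txt"))  # True per the new tests
    print(filter_spec.should_include_path("script.py"))   # False per the new tests

The pattern blob is only read the first time should_include_path needs it, which is why the object store stays optional at construction time.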
+ 86 - 18
dulwich/partial_clone.py

@@ -33,21 +33,22 @@ Supported filter specs:
 """
 
 __all__ = [
-    "FilterSpec",
-    "BlobNoneFilter",
     "BlobLimitFilter",
-    "TreeDepthFilter",
-    "SparseOidFilter",
+    "BlobNoneFilter",
     "CombineFilter",
-    "parse_filter_spec",
+    "FilterSpec",
+    "SparseOidFilter",
+    "TreeDepthFilter",
     "filter_pack_objects",
+    "parse_filter_spec",
 ]
 
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from .objects import ObjectID, ShaFile
+    from .object_store import BaseObjectStore
+    from .objects import ObjectID
 
 
 class FilterSpec(ABC):
@@ -103,7 +104,7 @@ class BlobNoneFilter(FilterSpec):
         return "blob:none"
 
     def __repr__(self) -> str:
-        return f"BlobNoneFilter()"
+        return "BlobNoneFilter()"
 
 
 class BlobLimitFilter(FilterSpec):
@@ -169,15 +170,69 @@ class TreeDepthFilter(FilterSpec):
 
 
 class SparseOidFilter(FilterSpec):
-    """Filter that uses a sparse specification from an object."""
+    """Filter that uses a sparse specification from an object.
 
-    def __init__(self, oid: "ObjectID") -> None:
+    This filter reads sparse-checkout patterns from a blob object and uses them
+    to determine which paths should be included in the partial clone.
+    """
+
+    def __init__(
+        self, oid: "ObjectID", object_store: "BaseObjectStore | None" = None
+    ) -> None:
         """Initialize sparse OID filter.
 
         Args:
-            oid: Object ID of the sparse specification
+            oid: Object ID of the sparse specification blob
+            object_store: Optional object store to load the sparse patterns from
         """
         self.oid = oid
+        self._patterns: list[tuple[str, bool, bool, bool]] | None = None
+        self._object_store = object_store
+
+    def _load_patterns(self) -> None:
+        """Load and parse sparse patterns from the blob."""
+        if self._patterns is not None:
+            return
+
+        if self._object_store is None:
+            raise ValueError("Cannot load sparse patterns without an object store")
+
+        from .objects import Blob
+        from .sparse_patterns import parse_sparse_patterns
+
+        try:
+            obj = self._object_store[self.oid]
+        except KeyError:
+            raise ValueError(
+                f"Sparse specification blob {self.oid.hex() if isinstance(self.oid, bytes) else self.oid} not found"
+            )
+
+        if not isinstance(obj, Blob):
+            raise ValueError(
+                f"Sparse specification {self.oid.hex() if isinstance(self.oid, bytes) else self.oid} is not a blob"
+            )
+
+        # Parse the blob content as sparse patterns
+        lines = obj.data.decode("utf-8").splitlines()
+        self._patterns = parse_sparse_patterns(lines)
+
+    def should_include_path(self, path: str) -> bool:
+        """Determine if a path should be included based on sparse patterns.
+
+        Args:
+            path: Path to check (e.g., 'src/file.py')
+
+        Returns:
+            True if the path matches the sparse patterns, False otherwise
+        """
+        self._load_patterns()
+        from .sparse_patterns import match_sparse_patterns
+
+        # Determine if path is a directory based on whether it ends with '/'
+        path_is_dir = path.endswith("/")
+        path_str = path.rstrip("/")
+
+        return match_sparse_patterns(path_str, self._patterns, path_is_dir=path_is_dir)
 
     def should_include_blob(self, blob_size: int) -> bool:
         """Include all blobs (sparse filtering is path-based, not size-based)."""
@@ -251,11 +306,14 @@ def _parse_size(size_str: str) -> int:
             raise ValueError(f"Invalid size specification: {size_str}")
 
 
-def parse_filter_spec(spec: str | bytes) -> FilterSpec:
+def parse_filter_spec(
+    spec: str | bytes, object_store: "BaseObjectStore | None" = None
+) -> FilterSpec:
     """Parse a filter specification string.
 
     Args:
         spec: Filter specification (e.g., 'blob:none', 'blob:limit=1m')
+        object_store: Optional object store for loading sparse specifications
 
     Returns:
         Parsed FilterSpec object
@@ -291,7 +349,9 @@ def parse_filter_spec(spec: str | bytes) -> FilterSpec:
         try:
             limit = _parse_size(limit_str)
             if limit < 0:
-                raise ValueError(f"blob:limit size must be non-negative, got {limit_str}")
+                raise ValueError(
+                    f"blob:limit size must be non-negative, got {limit_str}"
+                )
             return BlobLimitFilter(limit)
         except ValueError as e:
             raise ValueError(f"Invalid blob:limit specification: {e}")
@@ -309,7 +369,9 @@ def parse_filter_spec(spec: str | bytes) -> FilterSpec:
     elif spec.startswith("sparse:oid="):
         oid_str = spec[11:]  # len('sparse:oid=') == 11
         if not oid_str:
-            raise ValueError("sparse:oid requires an object ID (e.g., sparse:oid=abc123...)")
+            raise ValueError(
+                "sparse:oid requires an object ID (e.g., sparse:oid=abc123...)"
+            )
         # Validate OID format (should be 40 hex chars for SHA-1 or 64 for SHA-256)
         if len(oid_str) not in (40, 64):
             raise ValueError(
@@ -320,19 +382,25 @@ def parse_filter_spec(spec: str | bytes) -> FilterSpec:
             oid = oid_str.encode("ascii")
             int(oid_str, 16)  # Validate it's valid hex
         except (ValueError, UnicodeEncodeError):
-            raise ValueError(f"sparse:oid must be a hexadecimal object ID, got: {oid_str}")
-        return SparseOidFilter(oid)
+            raise ValueError(
+                f"sparse:oid must be a hexadecimal object ID, got: {oid_str}"
+            )
+        return SparseOidFilter(oid, object_store=object_store)
     elif spec.startswith("combine:"):
         filter_str = spec[8:]  # len('combine:') == 8
         if not filter_str:
-            raise ValueError("combine filter requires at least one filter (e.g., combine:blob:none+tree:0)")
+            raise ValueError(
+                "combine filter requires at least one filter (e.g., combine:blob:none+tree:0)"
+            )
         filter_specs = filter_str.split("+")
         if len(filter_specs) < 2:
             raise ValueError(
                 "combine filter requires at least two filters separated by '+'"
             )
         try:
-            filters = [parse_filter_spec(f) for f in filter_specs]
+            filters = [
+                parse_filter_spec(f, object_store=object_store) for f in filter_specs
+            ]
         except ValueError as e:
             raise ValueError(f"Invalid filter in combine specification: {e}")
         return CombineFilter(filters)
@@ -367,7 +435,7 @@ def filter_pack_objects(
         This function currently supports blob size filtering. Tree depth filtering
         requires additional path/depth tracking which is not yet implemented.
     """
-    from .objects import Blob, Tree, Commit, Tag
+    from .objects import Blob, Commit, Tag, Tree
 
     filtered_ids = []
 

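For reference, a hedged sketch of the spec strings parse_filter_spec accepts, taken from this module's docstrings and error messages (the exact size suffixes handled by _parse_size are not visible in these hunks, so 'blob:limit=1m' is assumed from the docstring example):

    from dulwich.partial_clone import parse_filter_spec

    parse_filter_spec("blob:none")                 # exclude all blobs
    parse_filter_spec("blob:limit=1m")             # cap blob size (docstring example)
    parse_filter_spec("combine:blob:none+tree:0")  # several filters joined with '+'

    # sparse:oid requires a 40- or 64-character hex object ID; anything else
    # raises ValueError with the messages added in this hunk.
    try:
        parse_filter_spec("sparse:oid=xyz")
    except ValueError as exc:
        print(exc)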
+ 2 - 2
dulwich/server.py

@@ -100,7 +100,7 @@ from .errors import (
 from .object_store import MissingObjectFinder, PackBasedObjectStore, find_shallow
 from .objects import Commit, ObjectID, Tree, valid_hexsha
 from .pack import ObjectContainer, write_pack_from_container
-from .partial_clone import FilterSpec, filter_pack_objects, parse_filter_spec
+from .partial_clone import filter_pack_objects, parse_filter_spec
 from .protocol import (
     CAPABILITIES_REF,
     CAPABILITY_AGENT,
@@ -491,7 +491,7 @@ class UploadPackHandler(PackHandler):
         filter_spec_bytes = find_capability(caps, CAPABILITY_FILTER)
         if filter_spec_bytes:
             try:
-                self.filter_spec = parse_filter_spec(filter_spec_bytes)
+                self.filter_spec = parse_filter_spec(filter_spec_bytes, object_store=self.repo.object_store)
             except ValueError as e:
                 raise GitProtocolError(f"Invalid filter specification: {e}")
 

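The server change is small but is what makes sparse:oid usable end to end: the filter capability value received from the client (as bytes) is now parsed with the repository's object store, so the pattern blob can be resolved when the pack is built. A minimal sketch, assuming dulwich's MemoryRepo and the str | bytes signature shown above:

    from dulwich.objects import Blob
    from dulwich.partial_clone import parse_filter_spec
    from dulwich.repo import MemoryRepo

    repo = MemoryRepo()
    patterns_blob = Blob.from_string(b"docs/\n")
    repo.object_store.add_object(patterns_blob)

    # Roughly what UploadPackHandler now does with the 'filter' capability value.
    filter_spec = parse_filter_spec(
        b"sparse:oid=" + patterns_blob.id, object_store=repo.object_store
    )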
+ 115 - 17
tests/test_partial_clone.py

@@ -327,6 +327,105 @@ class SparseOidFilterTests(TestCase):
         self.assertIn("SparseOidFilter", repr(filter_spec))
         self.assertIn("1234567890abcdef1234567890abcdef12345678", repr(filter_spec))
 
+    def test_load_patterns_from_blob(self):
+        """Test loading sparse patterns from a blob object."""
+        from dulwich.object_store import MemoryObjectStore
+        from dulwich.objects import Blob
+
+        # Create a sparse patterns blob
+        patterns = b"*.txt\n!*.log\n/src/\n"
+        blob = Blob.from_string(patterns)
+
+        object_store = MemoryObjectStore()
+        object_store.add_object(blob)
+
+        filter_spec = SparseOidFilter(blob.id, object_store=object_store)
+        filter_spec._load_patterns()
+
+        # Verify patterns were loaded
+        self.assertIsNotNone(filter_spec._patterns)
+        self.assertEqual(3, len(filter_spec._patterns))
+
+    def test_load_patterns_missing_blob(self):
+        """Test error when sparse blob is not found."""
+        from dulwich.object_store import MemoryObjectStore
+
+        oid = b"1234567890abcdef1234567890abcdef12345678"
+        object_store = MemoryObjectStore()
+
+        filter_spec = SparseOidFilter(oid, object_store=object_store)
+
+        with self.assertRaises(ValueError) as cm:
+            filter_spec._load_patterns()
+        self.assertIn("not found", str(cm.exception))
+
+    def test_load_patterns_not_a_blob(self):
+        """Test error when sparse OID points to non-blob object."""
+        from dulwich.object_store import MemoryObjectStore
+        from dulwich.objects import Tree
+
+        tree = Tree()
+        object_store = MemoryObjectStore()
+        object_store.add_object(tree)
+
+        filter_spec = SparseOidFilter(tree.id, object_store=object_store)
+
+        with self.assertRaises(ValueError) as cm:
+            filter_spec._load_patterns()
+        self.assertIn("not a blob", str(cm.exception))
+
+    def test_load_patterns_without_object_store(self):
+        """Test error when trying to load patterns without object store."""
+        oid = b"1234567890abcdef1234567890abcdef12345678"
+        filter_spec = SparseOidFilter(oid)
+
+        with self.assertRaises(ValueError) as cm:
+            filter_spec._load_patterns()
+        self.assertIn("without an object store", str(cm.exception))
+
+    def test_should_include_path_matching(self):
+        """Test path matching with sparse patterns."""
+        from dulwich.object_store import MemoryObjectStore
+        from dulwich.objects import Blob
+
+        # Create a sparse patterns blob: include *.txt files
+        patterns = b"*.txt\n"
+        blob = Blob.from_string(patterns)
+
+        object_store = MemoryObjectStore()
+        object_store.add_object(blob)
+
+        filter_spec = SparseOidFilter(blob.id, object_store=object_store)
+
+        # .txt files should be included
+        self.assertTrue(filter_spec.should_include_path("readme.txt"))
+        self.assertTrue(filter_spec.should_include_path("docs/file.txt"))
+
+        # Other files should not be included
+        self.assertFalse(filter_spec.should_include_path("readme.md"))
+        self.assertFalse(filter_spec.should_include_path("script.py"))
+
+    def test_should_include_path_negation(self):
+        """Test path matching with negation patterns."""
+        from dulwich.object_store import MemoryObjectStore
+        from dulwich.objects import Blob
+
+        # Include all .txt files except logs
+        patterns = b"*.txt\n!*.log\n"
+        blob = Blob.from_string(patterns)
+
+        object_store = MemoryObjectStore()
+        object_store.add_object(blob)
+
+        filter_spec = SparseOidFilter(blob.id, object_store=object_store)
+
+        # .txt files should be included
+        self.assertTrue(filter_spec.should_include_path("readme.txt"))
+
+        # But .log files should be excluded (even though they end in .txt pattern)
+        # Note: This depends on pattern order and sparse_patterns implementation
+        self.assertFalse(filter_spec.should_include_path("debug.log"))
+
 
 class CombineFilterTests(TestCase):
     """Test CombineFilter class."""
@@ -463,10 +562,12 @@ class FilterPackObjectsTests(TestCase):
         ]
 
         # Combine blob:limit with another filter
-        filter_spec = CombineFilter([
-            BlobLimitFilter(100),
-            BlobNoneFilter(),  # This will exclude ALL blobs
-        ])
+        filter_spec = CombineFilter(
+            [
+                BlobLimitFilter(100),
+                BlobNoneFilter(),  # This will exclude ALL blobs
+            ]
+        )
 
         filtered = filter_pack_objects(self.store, object_ids, filter_spec)
 
@@ -488,6 +589,7 @@ class PartialCloneIntegrationTests(TestCase):
     def _cleanup(self):
         """Clean up test repository."""
         import shutil
+
         if os.path.exists(self.repo_dir):
             shutil.rmtree(self.repo_dir)
 
@@ -516,9 +618,7 @@ class PartialCloneIntegrationTests(TestCase):
 
         # Apply blob:none filter
         filter_spec = BlobNoneFilter()
-        filtered = filter_pack_objects(
-            self.repo.object_store, object_ids, filter_spec
-        )
+        filtered = filter_pack_objects(self.repo.object_store, object_ids, filter_spec)
 
         # Verify blobs are excluded
         self.assertNotIn(blob1.id, filtered)
@@ -562,9 +662,7 @@ class PartialCloneIntegrationTests(TestCase):
         ]
 
         filter_spec = BlobLimitFilter(100)
-        filtered = filter_pack_objects(
-            self.repo.object_store, object_ids, filter_spec
-        )
+        filtered = filter_pack_objects(self.repo.object_store, object_ids, filter_spec)
 
         # Small and medium should be included
         self.assertIn(small_blob.id, filtered)
@@ -594,15 +692,15 @@ class PartialCloneIntegrationTests(TestCase):
 
         # Combine: limit to 500 bytes, but also apply blob:none
         # This should exclude ALL blobs (blob:none overrides limit)
-        filter_spec = CombineFilter([
-            BlobLimitFilter(500),
-            BlobNoneFilter(),
-        ])
+        filter_spec = CombineFilter(
+            [
+                BlobLimitFilter(500),
+                BlobNoneFilter(),
+            ]
+        )
 
         object_ids = [blob1.id, blob2.id, tree.id, commit.id]
-        filtered = filter_pack_objects(
-            self.repo.object_store, object_ids, filter_spec
-        )
+        filtered = filter_pack_objects(self.repo.object_store, object_ids, filter_spec)
 
         # All blobs excluded
         self.assertNotIn(blob1.id, filtered)
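
Since parse_filter_spec now threads the object store through combine: specs as well, sparse:oid can also appear inside a combined filter; a small sketch under the same assumptions as above:

    from dulwich.object_store import MemoryObjectStore
    from dulwich.objects import Blob
    from dulwich.partial_clone import CombineFilter, parse_filter_spec

    store = MemoryObjectStore()
    patterns_blob = Blob.from_string(b"src/\n")
    store.add_object(patterns_blob)

    spec = "combine:blob:none+sparse:oid=" + patterns_blob.id.decode("ascii")
    combined = parse_filter_spec(spec, object_store=store)
    assert isinstance(combined, CombineFilter)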