Jelmer Vernooij 2 недель назад
Родитель
Сommit
f9a6e0474f
3 измененных файлов с 11 добавлено и 18 удалено
  1. 6 14
      dulwich/object_filters.py
  2. 4 3
      dulwich/server.py
  3. 1 1
      tests/test_object_filters.py

+ 6 - 14
dulwich/object_filters.py

@@ -51,6 +51,8 @@ __all__ = [
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING
 
+from .objects import S_ISGITLINK, Blob, Commit, ObjectID, Tag, Tree, valid_hexsha
+
 if TYPE_CHECKING:
     from collections.abc import Callable
 
@@ -207,7 +209,6 @@ class SparseOidFilter(FilterSpec):
         if self._object_store is None:
             raise ValueError("Cannot load sparse patterns without an object store")
 
-        from .objects import Blob
         from .sparse_patterns import parse_sparse_patterns
 
         try:
@@ -242,6 +243,7 @@ class SparseOidFilter(FilterSpec):
         path_is_dir = path.endswith("/")
         path_str = path.rstrip("/")
 
+        assert self._patterns is not None  # _load_patterns ensures this
         return match_sparse_patterns(path_str, self._patterns, path_is_dir=path_is_dir)
 
     def should_include_blob(self, blob_size: int) -> bool:
@@ -385,18 +387,12 @@ def parse_filter_spec(
                 "sparse:oid requires an object ID (e.g., sparse:oid=abc123...)"
             )
         # Validate OID format (should be 40 hex chars for SHA-1 or 64 for SHA-256)
-        if len(oid_str) not in (40, 64):
+        if not valid_hexsha(oid_str):
             raise ValueError(
                 f"sparse:oid requires a valid object ID (40 or 64 hex chars), got {len(oid_str)} chars"
             )
-        try:
-            # Convert to bytes and validate hex
-            oid = oid_str.encode("ascii")
-            int(oid_str, 16)  # Validate it's valid hex
-        except (ValueError, UnicodeEncodeError):
-            raise ValueError(
-                f"sparse:oid must be a hexadecimal object ID, got: {oid_str}"
-            )
+
+        oid: ObjectID = ObjectID(oid_str.encode("ascii"))
         return SparseOidFilter(oid, object_store=object_store)
     elif spec.startswith("combine:"):
         filter_str = spec[8:]  # len('combine:') == 8
@@ -447,8 +443,6 @@ def filter_pack_objects(
         This function currently supports blob size filtering. Tree depth filtering
         requires additional path/depth tracking which is not yet implemented.
     """
-    from .objects import Blob, Commit, Tag, Tree
-
     filtered_ids = []
 
     for oid in object_ids:
@@ -502,8 +496,6 @@ def filter_pack_objects_with_paths(
     """
     import stat
 
-    from .objects import S_ISGITLINK, Blob, Commit, Tag, Tree
-
     included_objects: set[ObjectID] = set()
     # Track (oid, path, depth) tuples to process
     to_process: list[tuple[ObjectID, str, int]] = []

+ 4 - 3
dulwich/server.py

@@ -99,6 +99,7 @@ from .errors import (
 )
 from .object_filters import (
     CombineFilter,
+    FilterSpec,
     SparseOidFilter,
     TreeDepthFilter,
     filter_pack_objects,
@@ -457,7 +458,7 @@ class UploadPackHandler(PackHandler):
         # data (such as side-band, see the progress method here).
         self._processing_have_lines = False
         # Filter specification for partial clone support
-        self.filter_spec = None
+        self.filter_spec: FilterSpec | None = None
 
     def capabilities(self) -> list[bytes]:
         """Return the list of capabilities supported by upload-pack.
@@ -632,7 +633,7 @@ class UploadPackHandler(PackHandler):
 
             # Use path-aware filtering for tree depth and sparse:oid filters
             # Check if filter requires path tracking
-            def needs_path_tracking(filter_spec):
+            def needs_path_tracking(filter_spec: FilterSpec) -> bool:
                 if isinstance(filter_spec, (TreeDepthFilter, SparseOidFilter)):
                     return True
                 if isinstance(filter_spec, CombineFilter):
@@ -827,7 +828,7 @@ class _ProtocolGraphWalker:
 
     def __init__(
         self,
-        handler: PackHandler,
+        handler: "UploadPackHandler",
         object_store: ObjectContainer,
         get_peeled: Callable[[bytes], ObjectID | None],
         get_symrefs: Callable[[], dict[Ref, Ref]],

+ 1 - 1
tests/test_object_filters.py

@@ -164,7 +164,7 @@ class ParseFilterSpecTests(TestCase):
         """Test that non-hex OID raises ValueError."""
         with self.assertRaises(ValueError) as cm:
             parse_filter_spec("sparse:oid=" + "x" * 40)
-        self.assertIn("hexadecimal", str(cm.exception))
+        self.assertIn("valid object ID", str(cm.exception))
 
     def test_parse_combine_single_filter(self):
         """Test that combine with single filter raises ValueError."""