
sparse_patterns: Add more typing

Jelmer Vernooij, 1 month ago
Commit db17b81b6f
1 changed file with 62 additions and 17 deletions

+ 62 - 17
dulwich/sparse_patterns.py

@@ -23,8 +23,11 @@
 
 import os
 from fnmatch import fnmatch
+from typing import Any, Union, cast
 
 from .file import ensure_dir_exists
+from .index import IndexEntry
+from .repo import Repo
 
 
 class SparseCheckoutConflictError(Exception):
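
The change threads a Union[str, Repo] type through every public entry point: each function now accepts either an open repository object or a filesystem path and coerces the latter via porcelain.open_repo. A minimal sketch of that coercion, factored into a standalone helper for clarity (the name _coerce_repo is hypothetical; the commit inlines this branch in each function instead):

from typing import Union

from dulwich.porcelain import open_repo
from dulwich.repo import Repo


def _coerce_repo(repo: Union[str, Repo]) -> Repo:
    # Accept either an already-open Repo or a path string,
    # mirroring the isinstance(repo, str) branch added below.
    if isinstance(repo, str):
        return open_repo(repo)
    return repo
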
@@ -35,7 +38,9 @@ class BlobNotFoundError(Exception):
     """Raised when a requested blob is not found in the repository's object store."""
 
 
-def determine_included_paths(repo, lines, cone):
+def determine_included_paths(
+    repo: Union[str, Repo], lines: list[str], cone: bool
+) -> set[str]:
     """Determine which paths in the index should be included based on either
     a full-pattern match or a cone-mode approach.
 
@@ -53,7 +58,7 @@ def determine_included_paths(repo, lines, cone):
         return compute_included_paths_full(repo, lines)
 
 
-def compute_included_paths_full(repo, lines):
+def compute_included_paths_full(repo: Union[str, Repo], lines: list[str]) -> set[str]:
     """Use .gitignore-style parsing and matching to determine included paths.
 
     Each file path in the index is tested against the parsed sparse patterns.
@@ -67,7 +72,13 @@ def compute_included_paths_full(repo, lines):
       A set of included path strings.
     """
     parsed = parse_sparse_patterns(lines)
-    index = repo.open_index()
+    if isinstance(repo, str):
+        from .porcelain import open_repo
+
+        repo_obj = open_repo(repo)
+    else:
+        repo_obj = repo
+    index = repo_obj.open_index()
     included = set()
     for path_bytes, entry in index.items():
         path_str = path_bytes.decode("utf-8")
@@ -77,7 +88,7 @@ def compute_included_paths_full(repo, lines):
     return included
 
 
-def compute_included_paths_cone(repo, lines):
+def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set[str]:
     """Implement a simplified 'cone' approach for sparse-checkout.
 
     By default, this can include top-level files, exclude all subdirectories,
@@ -108,7 +119,13 @@ def compute_included_paths_cone(repo, lines):
             if d:
                 reinclude_dirs.add(d)
 
-    index = repo.open_index()
+    if isinstance(repo, str):
+        from .porcelain import open_repo
+
+        repo_obj = open_repo(repo)
+    else:
+        repo_obj = repo
+    index = repo_obj.open_index()
     included = set()
 
     for path_bytes, entry in index.items():
@@ -134,7 +151,9 @@ def compute_included_paths_cone(repo, lines):
     return included
 
 
-def apply_included_paths(repo, included_paths, force=False):
+def apply_included_paths(
+    repo: Union[str, Repo], included_paths: set[str], force: bool = False
+) -> None:
     """Apply the sparse-checkout inclusion set to the index and working tree.
 
     This function updates skip-worktree bits in the index based on whether each
@@ -150,10 +169,18 @@ def apply_included_paths(repo, included_paths, force=False):
     Returns:
       None
     """
-    index = repo.open_index()
-    normalizer = repo.get_blob_normalizer()
+    if isinstance(repo, str):
+        from .porcelain import open_repo
+
+        repo_obj = open_repo(repo)
+    else:
+        repo_obj = repo
+    index = repo_obj.open_index()
+    if not hasattr(repo_obj, "get_blob_normalizer"):
+        raise ValueError("Repository must support get_blob_normalizer")
+    normalizer = repo_obj.get_blob_normalizer()
 
-    def local_modifications_exist(full_path, index_entry):
+    def local_modifications_exist(full_path: str, index_entry: IndexEntry) -> bool:
         if not os.path.exists(full_path):
             return False
         try:
@@ -162,14 +189,21 @@ def apply_included_paths(repo, included_paths, force=False):
         except OSError:
             return True
         try:
-            blob = repo.object_store[index_entry.sha]
+            blob_obj = repo_obj.object_store[index_entry.sha]
         except KeyError:
             return True
         norm_data = normalizer.checkin_normalize(disk_data, full_path)
-        return norm_data != blob.data
+        from .objects import Blob
+
+        if not isinstance(blob_obj, Blob):
+            return True
+        return norm_data != blob_obj.data
 
     # 1) Update skip-worktree bits
+
     for path_bytes, entry in list(index.items()):
+        if not isinstance(entry, IndexEntry):
+            continue  # Skip conflicted entries
         path_str = path_bytes.decode("utf-8")
         if path_str in included_paths:
             entry.set_skip_worktree(False)
@@ -180,7 +214,11 @@ def apply_included_paths(repo, included_paths, force=False):
 
     # 2) Reflect changes in the working tree
     for path_bytes, entry in list(index.items()):
-        full_path = os.path.join(repo.path, path_bytes.decode("utf-8"))
+        if not isinstance(entry, IndexEntry):
+            continue  # Skip conflicted entries
+        if not hasattr(repo_obj, "path"):
+            raise ValueError("Repository must have a path attribute")
+        full_path = os.path.join(cast(Any, repo_obj).path, path_bytes.decode("utf-8"))
 
         if entry.skip_worktree:
             # Excluded => remove if safe
@@ -200,17 +238,20 @@ def apply_included_paths(repo, included_paths, force=False):
             # Included => materialize if missing
             if not os.path.exists(full_path):
                 try:
-                    blob = repo.object_store[entry.sha]
+                    blob = repo_obj.object_store[entry.sha]
                 except KeyError:
                     raise BlobNotFoundError(
-                        f"Blob {entry.sha} not found for {path_bytes}."
+                        f"Blob {entry.sha.hex()} not found for {path_bytes.decode('utf-8')}."
                     )
                 ensure_dir_exists(os.path.dirname(full_path))
+                from .objects import Blob
+
                 with open(full_path, "wb") as f:
-                    f.write(blob.data)
+                    if isinstance(blob, Blob):
+                        f.write(blob.data)
 
 
-def parse_sparse_patterns(lines):
+def parse_sparse_patterns(lines: list[str]) -> list[tuple[str, bool, bool, bool]]:
     """Parse pattern lines from a sparse-checkout file (.git/info/sparse-checkout).
 
     This simplified parser:
@@ -259,7 +300,11 @@ def parse_sparse_patterns(lines):
     return results
 
 
-def match_gitignore_patterns(path_str, parsed_patterns, path_is_dir=False):
+def match_gitignore_patterns(
+    path_str: str,
+    parsed_patterns: list[tuple[str, bool, bool, bool]],
+    path_is_dir: bool = False,
+) -> bool:
     """Check whether a path is included based on .gitignore-style patterns.
 
     This is a simplified approach that:
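
Taken together, the newly typed entry points can be exercised like this; a minimal sketch assuming a hypothetical repository at ./myrepo and made-up pattern lines:

from dulwich.repo import Repo
from dulwich.sparse_patterns import apply_included_paths, determine_included_paths

# Hypothetical path and gitignore-style pattern lines, for illustration only.
repo = Repo("./myrepo")
lines = ["docs/", "*.md"]

included = determine_included_paths(repo, lines, cone=False)  # full-pattern mode
apply_included_paths(repo, included, force=False)

Because the parameters are now annotated Union[str, Repo], passing the path string "./myrepo" directly instead of the Repo object type-checks as well.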