Selaa lähdekoodia

sparse_patterns: Add more typing

Jelmer Vernooij 2 kuukautta sitten
vanhempi
commit
db17b81b6f
1 muutettua tiedostoa jossa 62 lisäystä ja 17 poistoa
  1. dulwich/sparse_patterns.py (+62 −17)
@@ -23,8 +23,11 @@

 import os
 from fnmatch import fnmatch
+from typing import Any, Union, cast

 from .file import ensure_dir_exists
+from .index import IndexEntry
+from .repo import Repo


 class SparseCheckoutConflictError(Exception):
@@ -35,7 +38,9 @@ class BlobNotFoundError(Exception):
     """Raised when a requested blob is not found in the repository's object store."""


-def determine_included_paths(repo, lines, cone):
+def determine_included_paths(
+    repo: Union[str, Repo], lines: list[str], cone: bool
+) -> set[str]:
     """Determine which paths in the index should be included based on either
     a full-pattern match or a cone-mode approach.

@@ -53,7 +58,7 @@ def determine_included_paths(repo, lines, cone):
         return compute_included_paths_full(repo, lines)


-def compute_included_paths_full(repo, lines):
+def compute_included_paths_full(repo: Union[str, Repo], lines: list[str]) -> set[str]:
     """Use .gitignore-style parsing and matching to determine included paths.

     Each file path in the index is tested against the parsed sparse patterns.
@@ -67,7 +72,13 @@ def compute_included_paths_full(repo, lines):
       A set of included path strings.
     """
     parsed = parse_sparse_patterns(lines)
-    index = repo.open_index()
+    if isinstance(repo, str):
+        from .porcelain import open_repo
+
+        repo_obj = open_repo(repo)
+    else:
+        repo_obj = repo
+    index = repo_obj.open_index()
     included = set()
     for path_bytes, entry in index.items():
         path_str = path_bytes.decode("utf-8")
@@ -77,7 +88,7 @@ def compute_included_paths_full(repo, lines):
     return included


-def compute_included_paths_cone(repo, lines):
+def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set[str]:
     """Implement a simplified 'cone' approach for sparse-checkout.

     By default, this can include top-level files, exclude all subdirectories,
@@ -108,7 +119,13 @@ def compute_included_paths_cone(repo, lines):
             if d:
                 reinclude_dirs.add(d)

-    index = repo.open_index()
+    if isinstance(repo, str):
+        from .porcelain import open_repo
+
+        repo_obj = open_repo(repo)
+    else:
+        repo_obj = repo
+    index = repo_obj.open_index()
     included = set()

     for path_bytes, entry in index.items():
@@ -134,7 +151,9 @@ def compute_included_paths_cone(repo, lines):
     return included


-def apply_included_paths(repo, included_paths, force=False):
+def apply_included_paths(
+    repo: Union[str, Repo], included_paths: set[str], force: bool = False
+) -> None:
     """Apply the sparse-checkout inclusion set to the index and working tree.

     This function updates skip-worktree bits in the index based on whether each
@@ -150,10 +169,18 @@ def apply_included_paths(repo, included_paths, force=False):
     Returns:
       None
     """
-    index = repo.open_index()
-    normalizer = repo.get_blob_normalizer()
+    if isinstance(repo, str):
+        from .porcelain import open_repo
+
+        repo_obj = open_repo(repo)
+    else:
+        repo_obj = repo
+    index = repo_obj.open_index()
+    if not hasattr(repo_obj, "get_blob_normalizer"):
+        raise ValueError("Repository must support get_blob_normalizer")
+    normalizer = repo_obj.get_blob_normalizer()

-    def local_modifications_exist(full_path, index_entry):
+    def local_modifications_exist(full_path: str, index_entry: IndexEntry) -> bool:
         if not os.path.exists(full_path):
             return False
         try:
@@ -162,14 +189,21 @@ def apply_included_paths(repo, included_paths, force=False):
         except OSError:
             return True
         try:
-            blob = repo.object_store[index_entry.sha]
+            blob_obj = repo_obj.object_store[index_entry.sha]
         except KeyError:
             return True
         norm_data = normalizer.checkin_normalize(disk_data, full_path)
-        return norm_data != blob.data
+        from .objects import Blob
+
+        if not isinstance(blob_obj, Blob):
+            return True
+        return norm_data != blob_obj.data

     # 1) Update skip-worktree bits
+
     for path_bytes, entry in list(index.items()):
+        if not isinstance(entry, IndexEntry):
+            continue  # Skip conflicted entries
         path_str = path_bytes.decode("utf-8")
         if path_str in included_paths:
             entry.set_skip_worktree(False)
@@ -180,7 +214,11 @@ def apply_included_paths(repo, included_paths, force=False):

     # 2) Reflect changes in the working tree
     for path_bytes, entry in list(index.items()):
-        full_path = os.path.join(repo.path, path_bytes.decode("utf-8"))
+        if not isinstance(entry, IndexEntry):
+            continue  # Skip conflicted entries
+        if not hasattr(repo_obj, "path"):
+            raise ValueError("Repository must have a path attribute")
+        full_path = os.path.join(cast(Any, repo_obj).path, path_bytes.decode("utf-8"))

         if entry.skip_worktree:
             # Excluded => remove if safe
@@ -200,17 +238,20 @@ def apply_included_paths(repo, included_paths, force=False):
             # Included => materialize if missing
             if not os.path.exists(full_path):
                 try:
-                    blob = repo.object_store[entry.sha]
+                    blob = repo_obj.object_store[entry.sha]
                 except KeyError:
                     raise BlobNotFoundError(
-                        f"Blob {entry.sha} not found for {path_bytes}."
+                        f"Blob {entry.sha.hex()} not found for {path_bytes.decode('utf-8')}."
                     )
                 ensure_dir_exists(os.path.dirname(full_path))
+                from .objects import Blob
+
                 with open(full_path, "wb") as f:
-                    f.write(blob.data)
+                    if isinstance(blob, Blob):
+                        f.write(blob.data)


-def parse_sparse_patterns(lines):
+def parse_sparse_patterns(lines: list[str]) -> list[tuple[str, bool, bool, bool]]:
     """Parse pattern lines from a sparse-checkout file (.git/info/sparse-checkout).

     This simplified parser:
@@ -259,7 +300,11 @@ def parse_sparse_patterns(lines):
     return results


-def match_gitignore_patterns(path_str, parsed_patterns, path_is_dir=False):
+def match_gitignore_patterns(
+    path_str: str,
+    parsed_patterns: list[tuple[str, bool, bool, bool]],
+    path_is_dir: bool = False,
+) -> bool:
     """Check whether a path is included based on .gitignore-style patterns.

     This is a simplified approach that: