Browse Source

various minor type improvements

Jelmer Vernooij 1 tháng trước
mục cha
commit
516547bda3
3 tập tin đã thay đổi với 64 bổ sung và 64 xóa
  1. 37 9
      dulwich/porcelain.py
  2. 16 46
      dulwich/sparse_patterns.py
  3. 11 9
      tests/test_sparse_patterns.py

+ 37 - 9
dulwich/porcelain.py

@@ -1,4 +1,4 @@
-# porcelain.py -- Porcelain-like layer on top of Dulwich
+# porcelain.py -- Porcelain-like layer on top of Dulwich
 # Copyright (C) 2013 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
@@ -86,11 +86,11 @@ import sys
 import time
 from collections import namedtuple
 from collections.abc import Iterator
-from contextlib import closing, contextmanager
+from contextlib import AbstractContextManager, closing, contextmanager
 from dataclasses import dataclass
 from io import BytesIO, RawIOBase
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, TypeVar, Union, overload
 
 from . import replace_me
 from .archive import tar_stream
@@ -174,6 +174,9 @@ from .sparse_patterns import (
 # Module level tuple definition for status output
 GitStatus = namedtuple("GitStatus", "staged unstaged untracked")
 
+# TypeVar for preserving BaseRepo subclass types
+T = TypeVar("T", bound="BaseRepo")
+
 
 @dataclass
 class CountObjectsResult:
@@ -310,10 +313,22 @@ def get_user_timezones():
     return author_timezone, commit_timezone
 
 
-def open_repo(path_or_repo: Union[str, os.PathLike, BaseRepo]):
+@overload
+def open_repo(path_or_repo: T) -> AbstractContextManager[T]: ...
+
+
+@overload
+def open_repo(
+    path_or_repo: Union[str, os.PathLike],
+) -> AbstractContextManager[Repo]: ...
+
+
+def open_repo(
+    path_or_repo: Union[str, os.PathLike, T],
+) -> AbstractContextManager[Union[T, Repo]]:
     """Open an argument that can be a repository or a path for a repository."""
     if isinstance(path_or_repo, BaseRepo):
-        return path_or_repo
+        return _noop_context_manager(path_or_repo)
     return Repo(path_or_repo)
 
 
@@ -323,7 +338,19 @@ def _noop_context_manager(obj):
     yield obj
 
 
-def open_repo_closing(path_or_repo: Union[str, os.PathLike, BaseRepo]):
+@overload
+def open_repo_closing(path_or_repo: T) -> AbstractContextManager[T]: ...
+
+
+@overload
+def open_repo_closing(
+    path_or_repo: Union[str, os.PathLike],
+) -> AbstractContextManager[Repo]: ...
+
+
+def open_repo_closing(
+    path_or_repo: Union[str, os.PathLike, T],
+) -> AbstractContextManager[Union[T, Repo]]:
     """Open an argument that can be a repository or a path for a repository.
     returns a context manager that will close the repo on exit if the argument
     is a path, else does nothing if the argument is a repo.
@@ -714,7 +741,7 @@ def clone(
     return repo
 
 
-def add(repo: Union[str, os.PathLike, BaseRepo] = ".", paths=None):
+def add(repo: Union[str, os.PathLike, Repo] = ".", paths=None):
     """Add files to the staging area.
 
     Args:
@@ -949,7 +976,7 @@ rm = remove
 
 
 def mv(
-    repo: Union[str, os.PathLike, BaseRepo],
+    repo: Union[str, os.PathLike, Repo],
     source: Union[str, bytes, os.PathLike],
     destination: Union[str, bytes, os.PathLike],
     force: bool = False,
@@ -3483,7 +3510,8 @@ def sparse_checkout(
             repo_obj.set_sparse_checkout_patterns(patterns)
 
         # --- 2) Determine the set of included paths ---
-        included_paths = determine_included_paths(repo_obj, lines, cone)
+        index = repo_obj.open_index()
+        included_paths = determine_included_paths(index, lines, cone)
 
         # --- 3) Apply those results to the index & working tree ---
         try:

+ 16 - 46
dulwich/sparse_patterns.py

@@ -23,10 +23,10 @@
 
 import os
 from fnmatch import fnmatch
-from typing import Any, Union, cast
 
 from .file import ensure_dir_exists
-from .index import IndexEntry
+from .index import Index, IndexEntry
+from .objects import Blob
 from .repo import Repo
 
 
@@ -38,14 +38,12 @@ class BlobNotFoundError(Exception):
     """Raised when a requested blob is not found in the repository's object store."""
 
 
-def determine_included_paths(
-    repo: Union[str, Repo], lines: list[str], cone: bool
-) -> set[str]:
+def determine_included_paths(index: Index, lines: list[str], cone: bool) -> set[str]:
     """Determine which paths in the index should be included based on either
     a full-pattern match or a cone-mode approach.
 
     Args:
-      repo: A path to the repository or a Repo object.
+      index: An Index object containing the repository's index.
       lines: A list of pattern lines (strings) from sparse-checkout config.
       cone: A bool indicating cone mode.
 
@@ -53,32 +51,25 @@ def determine_included_paths(
       A set of included path strings.
     """
     if cone:
-        return compute_included_paths_cone(repo, lines)
+        return compute_included_paths_cone(index, lines)
     else:
-        return compute_included_paths_full(repo, lines)
+        return compute_included_paths_full(index, lines)
 
 
-def compute_included_paths_full(repo: Union[str, Repo], lines: list[str]) -> set[str]:
+def compute_included_paths_full(index: Index, lines: list[str]) -> set[str]:
     """Use .gitignore-style parsing and matching to determine included paths.
 
     Each file path in the index is tested against the parsed sparse patterns.
     If it matches the final (most recently applied) positive pattern, it is included.
 
     Args:
-      repo: A path to the repository or a Repo object.
+      index: An Index object containing the repository's index.
       lines: A list of pattern lines (strings) from sparse-checkout config.
 
     Returns:
       A set of included path strings.
     """
     parsed = parse_sparse_patterns(lines)
-    if isinstance(repo, str):
-        from .porcelain import open_repo
-
-        repo_obj = open_repo(repo)
-    else:
-        repo_obj = repo
-    index = repo_obj.open_index()
     included = set()
     for path_bytes, entry in index.items():
         path_str = path_bytes.decode("utf-8")
@@ -88,7 +79,7 @@ def compute_included_paths_full(repo: Union[str, Repo], lines: list[str]) -> set
     return included
 
 
-def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set[str]:
+def compute_included_paths_cone(index: Index, lines: list[str]) -> set[str]:
     """Implement a simplified 'cone' approach for sparse-checkout.
 
     By default, this can include top-level files, exclude all subdirectories,
@@ -97,7 +88,7 @@ def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set
     of the recursive cone mode.
 
     Args:
-      repo: A path to the repository or a Repo object.
+      index: An Index object containing the repository's index.
       lines: A list of pattern lines (strings), typically including entries like
         "/*", "!/*/", or "/mydir/".
 
@@ -119,13 +110,6 @@ def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set
             if d:
                 reinclude_dirs.add(d)
 
-    if isinstance(repo, str):
-        from .porcelain import open_repo
-
-        repo_obj = open_repo(repo)
-    else:
-        repo_obj = repo
-    index = repo_obj.open_index()
     included = set()
 
     for path_bytes, entry in index.items():
@@ -152,7 +136,7 @@ def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set
 
 
 def apply_included_paths(
-    repo: Union[str, Repo], included_paths: set[str], force: bool = False
+    repo: Repo, included_paths: set[str], force: bool = False
 ) -> None:
     """Apply the sparse-checkout inclusion set to the index and working tree.
 
@@ -169,16 +153,8 @@ def apply_included_paths(
     Returns:
       None
     """
-    if isinstance(repo, str):
-        from .porcelain import open_repo
-
-        repo_obj = open_repo(repo)
-    else:
-        repo_obj = repo
-    index = repo_obj.open_index()
-    if not hasattr(repo_obj, "get_blob_normalizer"):
-        raise ValueError("Repository must support get_blob_normalizer")
-    normalizer = repo_obj.get_blob_normalizer()
+    index = repo.open_index()
+    normalizer = repo.get_blob_normalizer()
 
     def local_modifications_exist(full_path: str, index_entry: IndexEntry) -> bool:
         if not os.path.exists(full_path):
@@ -186,12 +162,10 @@ def apply_included_paths(
         with open(full_path, "rb") as f:
             disk_data = f.read()
         try:
-            blob_obj = repo_obj.object_store[index_entry.sha]
+            blob_obj = repo.object_store[index_entry.sha]
         except KeyError:
             return True
         norm_data = normalizer.checkin_normalize(disk_data, full_path)
-        from .objects import Blob
-
         if not isinstance(blob_obj, Blob):
             return True
         return norm_data != blob_obj.data
@@ -213,9 +187,7 @@ def apply_included_paths(
     for path_bytes, entry in list(index.items()):
         if not isinstance(entry, IndexEntry):
             continue  # Skip conflicted entries
-        if not hasattr(repo_obj, "path"):
-            raise ValueError("Repository must have a path attribute")
-        full_path = os.path.join(cast(Any, repo_obj).path, path_bytes.decode("utf-8"))
+        full_path = os.path.join(repo.path, path_bytes.decode("utf-8"))
 
         if entry.skip_worktree:
             # Excluded => remove if safe
@@ -238,14 +210,12 @@ def apply_included_paths(
             # Included => materialize if missing
             if not os.path.exists(full_path):
                 try:
-                    blob = repo_obj.object_store[entry.sha]
+                    blob = repo.object_store[entry.sha]
                 except KeyError:
                     raise BlobNotFoundError(
                         f"Blob {entry.sha.hex()} not found for {path_bytes.decode('utf-8')}."
                     )
                 ensure_dir_exists(os.path.dirname(full_path))
-                from .objects import Blob
-
                 # Apply checkout normalization if normalizer is available
                 if normalizer and isinstance(blob, Blob):
                     blob = normalizer.checkout_normalize(blob, path_bytes)

+ 11 - 9
tests/test_sparse_patterns.py

@@ -209,7 +209,7 @@ class ComputeIncludedPathsFullTests(TestCase):
             "!bar.*",  # exclude bar.md
             "docs/",  # include docs dir
         ]
-        included = compute_included_paths_full(self.repo, lines)
+        included = compute_included_paths_full(self.repo.open_index(), lines)
         self.assertEqual(included, {"foo.py", "docs/readme"})
 
     def test_full_with_utf8_paths(self):
@@ -219,7 +219,7 @@ class ComputeIncludedPathsFullTests(TestCase):
 
         # Include all text files
         lines = ["*.txt"]
-        included = compute_included_paths_full(self.repo, lines)
+        included = compute_included_paths_full(self.repo.open_index(), lines)
         self.assertEqual(included, {"unicode/文件.txt"})
 
 
@@ -256,7 +256,7 @@ class ComputeIncludedPathsConeTests(TestCase):
             "!/*/",
             "/docs/",
         ]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # top-level => includes 'topfile'
         # subdirs => excluded, except docs/
         self.assertEqual(included, {"topfile", "docs/readme.md"})
@@ -272,7 +272,7 @@ class ComputeIncludedPathsConeTests(TestCase):
             "!/*/",
             "/",  # This empty pattern should be skipped
         ]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Only topfile should be included since the empty pattern is skipped
         self.assertEqual(included, {"topfile"})
 
@@ -286,7 +286,7 @@ class ComputeIncludedPathsConeTests(TestCase):
             "/*",  # top-level
             "/docs/",  # re-include docs?
         ]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Because exclude_subdirs was never set, everything is included:
         self.assertEqual(
             included,
@@ -301,7 +301,7 @@ class ComputeIncludedPathsConeTests(TestCase):
 
         # Only specify reinclude_dirs, need to explicitly exclude subdirs
         lines = ["!/*/", "/docs/"]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Only docs/* should be included, not topfile or lib/*
         self.assertEqual(included, {"docs/readme.md"})
 
@@ -313,7 +313,7 @@ class ComputeIncludedPathsConeTests(TestCase):
 
         # Only exclude subdirs and reinclude docs
         lines = ["!/*/", "/docs/"]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Only docs/* should be included since we didn't include top level
         self.assertEqual(included, {"docs/readme.md"})
 
@@ -339,7 +339,8 @@ class DetermineIncludedPathsTests(TestCase):
         self._add_file_to_index("bar.md")
 
         lines = ["*.py", "!bar.*"]
-        included = determine_included_paths(self.repo, lines, cone=False)
+        index = self.repo.open_index()
+        included = determine_included_paths(index, lines, cone=False)
         self.assertEqual(included, {"foo.py"})
 
     def test_cone_mode(self):
@@ -347,7 +348,8 @@ class DetermineIncludedPathsTests(TestCase):
         self._add_file_to_index("subdir/anotherfile")
 
         lines = ["/*", "!/*/"]
-        included = determine_included_paths(self.repo, lines, cone=True)
+        index = self.repo.open_index()
+        included = determine_included_paths(index, lines, cone=True)
         self.assertEqual(included, {"topfile"})