various minor type improvements

Jelmer Vernooij 1 month ago
commit 516547bda3
3 changed files with 64 additions and 64 deletions
  1. dulwich/porcelain.py (+37 -9)
  2. dulwich/sparse_patterns.py (+16 -46)
  3. tests/test_sparse_patterns.py (+11 -9)

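The heart of this commit is the TypeVar/@overload pattern applied to open_repo and open_repo_closing below, which lets a type checker carry a concrete BaseRepo subclass through the call instead of widening it to BaseRepo. A minimal, self-contained sketch of that pattern (the stub classes here are illustrative stand-ins, not dulwich's real ones):

    from contextlib import AbstractContextManager, contextmanager
    from typing import TypeVar, Union, overload

    class BaseRepo: ...
    class Repo(BaseRepo): ...
    class MemoryRepo(BaseRepo): ...

    T = TypeVar("T", bound=BaseRepo)

    @contextmanager
    def _noop(obj):
        # Yield the object unchanged; closing is the caller's problem.
        yield obj

    @overload
    def open_repo(path_or_repo: T) -> AbstractContextManager[T]: ...
    @overload
    def open_repo(path_or_repo: str) -> AbstractContextManager[Repo]: ...
    def open_repo(path_or_repo: Union[str, T]) -> AbstractContextManager[Union[T, Repo]]:
        if isinstance(path_or_repo, BaseRepo):
            return _noop(path_or_repo)
        return _noop(Repo())  # stand-in for opening a repo at a path

    with open_repo(MemoryRepo()) as r:
        pass  # a checker infers r as MemoryRepo rather than BaseRepo
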
+ 37 - 9
dulwich/porcelain.py

@@ -1,4 +1,4 @@
-# porcelain.py -- Porcelain-like layer on top of Dulwich
+# e porcelain.py -- Porcelain-like layer on top of Dulwich
 # Copyright (C) 2013 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
@@ -86,11 +86,11 @@ import sys
 import time
 from collections import namedtuple
 from collections.abc import Iterator
-from contextlib import closing, contextmanager
+from contextlib import AbstractContextManager, closing, contextmanager
 from dataclasses import dataclass
 from io import BytesIO, RawIOBase
 from pathlib import Path
-from typing import Optional, Union
+from typing import Optional, TypeVar, Union, overload

 from . import replace_me
 from .archive import tar_stream
@@ -174,6 +174,9 @@ from .sparse_patterns import (
 # Module level tuple definition for status output
 GitStatus = namedtuple("GitStatus", "staged unstaged untracked")

+# TypeVar for preserving BaseRepo subclass types
+T = TypeVar("T", bound="BaseRepo")
+

 @dataclass
 class CountObjectsResult:
@@ -310,10 +313,22 @@ def get_user_timezones():
     return author_timezone, commit_timezone


-def open_repo(path_or_repo: Union[str, os.PathLike, BaseRepo]):
+@overload
+def open_repo(path_or_repo: T) -> AbstractContextManager[T]: ...
+
+
+@overload
+def open_repo(
+    path_or_repo: Union[str, os.PathLike],
+) -> AbstractContextManager[Repo]: ...
+
+
+def open_repo(
+    path_or_repo: Union[str, os.PathLike, T],
+) -> AbstractContextManager[Union[T, Repo]]:
     """Open an argument that can be a repository or a path for a repository."""
     if isinstance(path_or_repo, BaseRepo):
-        return path_or_repo
+        return _noop_context_manager(path_or_repo)
     return Repo(path_or_repo)


@@ -323,7 +338,19 @@ def _noop_context_manager(obj):
     yield obj


-def open_repo_closing(path_or_repo: Union[str, os.PathLike, BaseRepo]):
+@overload
+def open_repo_closing(path_or_repo: T) -> AbstractContextManager[T]: ...
+
+
+@overload
+def open_repo_closing(
+    path_or_repo: Union[str, os.PathLike],
+) -> AbstractContextManager[Repo]: ...
+
+
+def open_repo_closing(
+    path_or_repo: Union[str, os.PathLike, T],
+) -> AbstractContextManager[Union[T, Repo]]:
     """Open an argument that can be a repository or a path for a repository.
     returns a context manager that will close the repo on exit if the argument
     is a path, else does nothing if the argument is a repo.
@@ -714,7 +741,7 @@ def clone(
     return repo


-def add(repo: Union[str, os.PathLike, BaseRepo] = ".", paths=None):
+def add(repo: Union[str, os.PathLike, Repo] = ".", paths=None):
     """Add files to the staging area.

     Args:
@@ -949,7 +976,7 @@ rm = remove


 def mv(
-    repo: Union[str, os.PathLike, BaseRepo],
+    repo: Union[str, os.PathLike, Repo],
     source: Union[str, bytes, os.PathLike],
     destination: Union[str, bytes, os.PathLike],
     force: bool = False,
@@ -3483,7 +3510,8 @@ def sparse_checkout(
             repo_obj.set_sparse_checkout_patterns(patterns)

         # --- 2) Determine the set of included paths ---
-        included_paths = determine_included_paths(repo_obj, lines, cone)
+        index = repo_obj.open_index()
+        included_paths = determine_included_paths(index, lines, cone)

         # --- 3) Apply those results to the index & working tree ---
         try:

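One behavioral detail in the open_repo hunk above: a repo object passed in is now wrapped in _noop_context_manager instead of being returned directly, so both branches hand back a context manager. Assuming a repository exists at the hypothetical path below, call sites would read uniformly:

    from dulwich.porcelain import open_repo

    # "/tmp/example-repo" is a hypothetical path for illustration.
    with open_repo("/tmp/example-repo") as repo:
        print(repo.path)  # the repo is closed (or not) on exit as appropriate
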
+ 16 - 46
dulwich/sparse_patterns.py

@@ -23,10 +23,10 @@

 import os
 from fnmatch import fnmatch
-from typing import Any, Union, cast

 from .file import ensure_dir_exists
-from .index import IndexEntry
+from .index import Index, IndexEntry
+from .objects import Blob
 from .repo import Repo


@@ -38,14 +38,12 @@ class BlobNotFoundError(Exception):
     """Raised when a requested blob is not found in the repository's object store."""
     """Raised when a requested blob is not found in the repository's object store."""
 
 
 
 
-def determine_included_paths(
-    repo: Union[str, Repo], lines: list[str], cone: bool
-) -> set[str]:
+def determine_included_paths(index: Index, lines: list[str], cone: bool) -> set[str]:
     """Determine which paths in the index should be included based on either
     """Determine which paths in the index should be included based on either
     a full-pattern match or a cone-mode approach.
     a full-pattern match or a cone-mode approach.
 
 
     Args:
     Args:
-      repo: A path to the repository or a Repo object.
+      index: An Index object containing the repository's index.
       lines: A list of pattern lines (strings) from sparse-checkout config.
       lines: A list of pattern lines (strings) from sparse-checkout config.
       cone: A bool indicating cone mode.
       cone: A bool indicating cone mode.
 
 
@@ -53,32 +51,25 @@ def determine_included_paths(
       A set of included path strings.
     """
     if cone:
-        return compute_included_paths_cone(repo, lines)
+        return compute_included_paths_cone(index, lines)
     else:
-        return compute_included_paths_full(repo, lines)
+        return compute_included_paths_full(index, lines)


-def compute_included_paths_full(repo: Union[str, Repo], lines: list[str]) -> set[str]:
+def compute_included_paths_full(index: Index, lines: list[str]) -> set[str]:
     """Use .gitignore-style parsing and matching to determine included paths.

     Each file path in the index is tested against the parsed sparse patterns.
     If it matches the final (most recently applied) positive pattern, it is included.

     Args:
-      repo: A path to the repository or a Repo object.
+      index: An Index object containing the repository's index.
       lines: A list of pattern lines (strings) from sparse-checkout config.

     Returns:
       A set of included path strings.
     """
     parsed = parse_sparse_patterns(lines)
-    if isinstance(repo, str):
-        from .porcelain import open_repo
-
-        repo_obj = open_repo(repo)
-    else:
-        repo_obj = repo
-    index = repo_obj.open_index()
     included = set()
     for path_bytes, entry in index.items():
         path_str = path_bytes.decode("utf-8")
@@ -88,7 +79,7 @@ def compute_included_paths_full(repo: Union[str, Repo], lines: list[str]) -> set
     return included


-def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set[str]:
+def compute_included_paths_cone(index: Index, lines: list[str]) -> set[str]:
     """Implement a simplified 'cone' approach for sparse-checkout.

     By default, this can include top-level files, exclude all subdirectories,
@@ -97,7 +88,7 @@ def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set
     of the recursive cone mode.

     Args:
-      repo: A path to the repository or a Repo object.
+      index: An Index object containing the repository's index.
       lines: A list of pattern lines (strings), typically including entries like
         "/*", "!/*/", or "/mydir/".

@@ -119,13 +110,6 @@ def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set
             if d:
                 reinclude_dirs.add(d)

-    if isinstance(repo, str):
-        from .porcelain import open_repo
-
-        repo_obj = open_repo(repo)
-    else:
-        repo_obj = repo
-    index = repo_obj.open_index()
     included = set()

     for path_bytes, entry in index.items():
@@ -152,7 +136,7 @@ def compute_included_paths_cone(repo: Union[str, Repo], lines: list[str]) -> set


 def apply_included_paths(
-    repo: Union[str, Repo], included_paths: set[str], force: bool = False
+    repo: Repo, included_paths: set[str], force: bool = False
 ) -> None:
     """Apply the sparse-checkout inclusion set to the index and working tree.

@@ -169,16 +153,8 @@ def apply_included_paths(
     Returns:
       None
     """
-    if isinstance(repo, str):
-        from .porcelain import open_repo
-
-        repo_obj = open_repo(repo)
-    else:
-        repo_obj = repo
-    index = repo_obj.open_index()
-    if not hasattr(repo_obj, "get_blob_normalizer"):
-        raise ValueError("Repository must support get_blob_normalizer")
-    normalizer = repo_obj.get_blob_normalizer()
+    index = repo.open_index()
+    normalizer = repo.get_blob_normalizer()

     def local_modifications_exist(full_path: str, index_entry: IndexEntry) -> bool:
         if not os.path.exists(full_path):
@@ -186,12 +162,10 @@ def apply_included_paths(
         with open(full_path, "rb") as f:
             disk_data = f.read()
         try:
-            blob_obj = repo_obj.object_store[index_entry.sha]
+            blob_obj = repo.object_store[index_entry.sha]
         except KeyError:
             return True
         norm_data = normalizer.checkin_normalize(disk_data, full_path)
-        from .objects import Blob
-
         if not isinstance(blob_obj, Blob):
             return True
         return norm_data != blob_obj.data
@@ -213,9 +187,7 @@ def apply_included_paths(
     for path_bytes, entry in list(index.items()):
         if not isinstance(entry, IndexEntry):
             continue  # Skip conflicted entries
-        if not hasattr(repo_obj, "path"):
-            raise ValueError("Repository must have a path attribute")
-        full_path = os.path.join(cast(Any, repo_obj).path, path_bytes.decode("utf-8"))
+        full_path = os.path.join(repo.path, path_bytes.decode("utf-8"))

         if entry.skip_worktree:
             # Excluded => remove if safe
@@ -238,14 +210,12 @@ def apply_included_paths(
             # Included => materialize if missing
             if not os.path.exists(full_path):
                 try:
-                    blob = repo_obj.object_store[entry.sha]
+                    blob = repo.object_store[entry.sha]
                 except KeyError:
                     raise BlobNotFoundError(
                         f"Blob {entry.sha.hex()} not found for {path_bytes.decode('utf-8')}."
                     )
                 ensure_dir_exists(os.path.dirname(full_path))
-                from .objects import Blob
-
                 # Apply checkout normalization if normalizer is available
                 if normalizer and isinstance(blob, Blob):
                     blob = normalizer.checkout_normalize(blob, path_bytes)

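Taken together, the sparse_patterns changes move repository handling to the caller: the path-computation helpers now take an Index, and apply_included_paths takes a concrete Repo. A sketch of the resulting calling convention, assuming a repository at a hypothetical path:

    from dulwich.repo import Repo
    from dulwich.sparse_patterns import apply_included_paths, determine_included_paths

    # Hypothetical path for illustration.
    repo = Repo("/tmp/example-repo")
    lines = ["/*", "!/*/", "/docs/"]  # cone-mode style patterns

    index = repo.open_index()  # the caller opens the index now
    included = determine_included_paths(index, lines, cone=True)
    apply_included_paths(repo, included, force=False)
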
+ 11 - 9
tests/test_sparse_patterns.py

@@ -209,7 +209,7 @@ class ComputeIncludedPathsFullTests(TestCase):
             "!bar.*",  # exclude bar.md
             "!bar.*",  # exclude bar.md
             "docs/",  # include docs dir
             "docs/",  # include docs dir
         ]
         ]
-        included = compute_included_paths_full(self.repo, lines)
+        included = compute_included_paths_full(self.repo.open_index(), lines)
         self.assertEqual(included, {"foo.py", "docs/readme"})
         self.assertEqual(included, {"foo.py", "docs/readme"})
 
 
     def test_full_with_utf8_paths(self):
     def test_full_with_utf8_paths(self):
@@ -219,7 +219,7 @@ class ComputeIncludedPathsFullTests(TestCase):

         # Include all text files
         lines = ["*.txt"]
-        included = compute_included_paths_full(self.repo, lines)
+        included = compute_included_paths_full(self.repo.open_index(), lines)
         self.assertEqual(included, {"unicode/文件.txt"})


@@ -256,7 +256,7 @@ class ComputeIncludedPathsConeTests(TestCase):
             "!/*/",
             "!/*/",
             "/docs/",
             "/docs/",
         ]
         ]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # top-level => includes 'topfile'
         # top-level => includes 'topfile'
         # subdirs => excluded, except docs/
         # subdirs => excluded, except docs/
         self.assertEqual(included, {"topfile", "docs/readme.md"})
         self.assertEqual(included, {"topfile", "docs/readme.md"})
@@ -272,7 +272,7 @@ class ComputeIncludedPathsConeTests(TestCase):
             "!/*/",
             "!/*/",
             "/",  # This empty pattern should be skipped
             "/",  # This empty pattern should be skipped
         ]
         ]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Only topfile should be included since the empty pattern is skipped
         # Only topfile should be included since the empty pattern is skipped
         self.assertEqual(included, {"topfile"})
         self.assertEqual(included, {"topfile"})
 
 
@@ -286,7 +286,7 @@ class ComputeIncludedPathsConeTests(TestCase):
             "/*",  # top-level
             "/*",  # top-level
             "/docs/",  # re-include docs?
             "/docs/",  # re-include docs?
         ]
         ]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Because exclude_subdirs was never set, everything is included:
         # Because exclude_subdirs was never set, everything is included:
         self.assertEqual(
         self.assertEqual(
             included,
             included,
@@ -301,7 +301,7 @@ class ComputeIncludedPathsConeTests(TestCase):

         # Only specify reinclude_dirs, need to explicitly exclude subdirs
         lines = ["!/*/", "/docs/"]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Only docs/* should be included, not topfile or lib/*
         self.assertEqual(included, {"docs/readme.md"})

@@ -313,7 +313,7 @@ class ComputeIncludedPathsConeTests(TestCase):

         # Only exclude subdirs and reinclude docs
         lines = ["!/*/", "/docs/"]
-        included = compute_included_paths_cone(self.repo, lines)
+        included = compute_included_paths_cone(self.repo.open_index(), lines)
         # Only docs/* should be included since we didn't include top level
         self.assertEqual(included, {"docs/readme.md"})

@@ -339,7 +339,8 @@ class DetermineIncludedPathsTests(TestCase):
         self._add_file_to_index("bar.md")

         lines = ["*.py", "!bar.*"]
-        included = determine_included_paths(self.repo, lines, cone=False)
+        index = self.repo.open_index()
+        included = determine_included_paths(index, lines, cone=False)
         self.assertEqual(included, {"foo.py"})

     def test_cone_mode(self):
@@ -347,7 +348,8 @@ class DetermineIncludedPathsTests(TestCase):
         self._add_file_to_index("subdir/anotherfile")

         lines = ["/*", "!/*/"]
-        included = determine_included_paths(self.repo, lines, cone=True)
+        index = self.repo.open_index()
+        included = determine_included_paths(index, lines, cone=True)
         self.assertEqual(included, {"topfile"})