Przeglądaj źródła

More typing fixes (#1908)

Jelmer Vernooij 3 miesiące temu
rodzic
commit
0ea3e5fbf1
7 zmienionych plików z 174 dodań i 80 usunięć
  1. 9 9
      dulwich/__init__.py
  2. 4 4
      dulwich/diff_tree.py
  3. 0 1
      dulwich/index.py
  4. 68 2
      dulwich/object_store.py
  5. 44 16
      dulwich/objects.py
  6. 36 27
      dulwich/porcelain.py
  7. 13 21
      dulwich/repo.py

+ 9 - 9
dulwich/__init__.py

@@ -24,7 +24,7 @@
 """Python implementation of the Git file formats and protocols."""
 
 import sys
-from typing import Callable, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union
 
 if sys.version_info >= (3, 10):
     from typing import ParamSpec
@@ -37,16 +37,16 @@ __all__ = ["__version__", "replace_me"]
 
 P = ParamSpec("P")
 R = TypeVar("R")
-F = TypeVar("F", bound=Callable[..., object])
+F = TypeVar("F", bound=Callable[..., Any])
 
 try:
-    from dissolve import replace_me
+    from dissolve import replace_me as replace_me
 except ImportError:
     # if dissolve is not installed, then just provide a basic implementation
     # of its replace_me decorator
     def replace_me(
-        since: Optional[Union[str, tuple[int, ...]]] = None,
-        remove_in: Optional[Union[str, tuple[int, ...]]] = None,
+        since: Optional[Union[tuple[int, ...], str]] = None,
+        remove_in: Optional[Union[tuple[int, ...], str]] = None,
     ) -> Callable[[F], F]:
         """Decorator to mark functions as deprecated.
 
@@ -58,7 +58,7 @@ except ImportError:
             Decorator function
         """
 
-        def decorator(func: F) -> F:
+        def decorator(func: Callable[P, R]) -> Callable[P, R]:
             import functools
             import warnings
 
@@ -76,7 +76,7 @@ except ImportError:
                 m += " and will be removed in a future version"
 
             @functools.wraps(func)
-            def _wrapped_func(*args, **kwargs):  # type: ignore[no-untyped-def]
+            def _wrapped_func(*args: P.args, **kwargs: P.kwargs) -> R:
                 warnings.warn(
                     m,
                     DeprecationWarning,
@@ -84,6 +84,6 @@ except ImportError:
                 )
                 return func(*args, **kwargs)
 
-            return _wrapped_func  # type: ignore[return-value]
+            return _wrapped_func
 
-        return decorator
+        return decorator  # type: ignore[return-value]

+ 4 - 4
dulwich/diff_tree.py

@@ -116,10 +116,10 @@ def _merge_entries(
     while i1 < len1 and i2 < len2:
         entry1 = entries1[i1]
         entry2 = entries2[i2]
-        if entry1.path < entry2.path:  # type: ignore[operator]
+        if entry1.path < entry2.path:
             result.append((entry1, None))
             i1 += 1
-        elif entry1.path > entry2.path:  # type: ignore[operator]
+        elif entry1.path > entry2.path:
             result.append((None, entry2))
             i2 += 1
         else:
@@ -175,8 +175,8 @@ def walk_trees(
         if prune_identical and is_tree1 and is_tree2 and entry1 == entry2:
             continue
 
-        tree1 = (is_tree1 and entry1 and store[entry1.sha]) or None  # type: ignore[index]
-        tree2 = (is_tree2 and entry2 and store[entry2.sha]) or None  # type: ignore[index]
+        tree1 = (is_tree1 and entry1 and store[entry1.sha]) or None
+        tree2 = (is_tree2 and entry2 and store[entry2.sha]) or None
         path = (
             (entry1.path if entry1 else None)
             or (entry2.path if entry2 else None)

+ 0 - 1
dulwich/index.py

@@ -2776,7 +2776,6 @@ class locked_index:
     def __enter__(self) -> Index:
         """Enter context manager and lock index."""
         f = GitFile(self._path, "wb")
-        assert isinstance(f, _GitFile)  # GitFile in write mode always returns _GitFile
         self._file = f
         self._index = Index(self._path)
         return self._index

+ 68 - 2
dulwich/object_store.py

@@ -622,7 +622,73 @@ class BaseObjectStore:
         raise KeyError(sha)
 
 
-class PackBasedObjectStore(BaseObjectStore, PackedObjectContainer):
+class PackCapableObjectStore(BaseObjectStore, PackedObjectContainer):
+    """Object store that supports pack operations.
+
+    This is a base class for object stores that can handle pack files,
+    including both disk-based and memory-based stores.
+    """
+
+    def add_pack(self) -> tuple[BinaryIO, Callable[[], None], Callable[[], None]]:
+        """Add a new pack to this object store.
+
+        Returns: Tuple of (file, commit_func, abort_func)
+        """
+        raise NotImplementedError(self.add_pack)
+
+    def add_pack_data(
+        self,
+        count: int,
+        unpacked_objects: Iterator["UnpackedObject"],
+        progress: Optional[Callable[..., None]] = None,
+    ) -> Optional["Pack"]:
+        """Add pack data to this object store.
+
+        Args:
+          count: Number of objects
+          unpacked_objects: Iterator over unpacked objects
+          progress: Optional progress callback
+        """
+        raise NotImplementedError(self.add_pack_data)
+
+    def get_unpacked_object(
+        self, sha1: bytes, *, include_comp: bool = False
+    ) -> "UnpackedObject":
+        """Get a raw unresolved object.
+
+        Args:
+            sha1: SHA-1 hash of the object
+            include_comp: Whether to include compressed data
+
+        Returns:
+            UnpackedObject instance
+        """
+        from .pack import UnpackedObject
+
+        obj = self[sha1]
+        return UnpackedObject(obj.type_num, sha=sha1, decomp_chunks=obj.as_raw_chunks())
+
+    def iterobjects_subset(
+        self, shas: Iterable[bytes], *, allow_missing: bool = False
+    ) -> Iterator[ShaFile]:
+        """Iterate over a subset of objects.
+
+        Args:
+            shas: Iterable of object SHAs to retrieve
+            allow_missing: If True, skip missing objects
+
+        Returns:
+            Iterator of ShaFile objects
+        """
+        for sha in shas:
+            try:
+                yield self[sha]
+            except KeyError:
+                if not allow_missing:
+                    raise
+
+
+class PackBasedObjectStore(PackCapableObjectStore, PackedObjectContainer):
     """Object store that uses pack files for storage.
 
     This class provides a base implementation for object stores that use
@@ -1836,7 +1902,7 @@ class DiskObjectStore(PackBasedObjectStore):
                     os.remove(pack_path)
 
 
-class MemoryObjectStore(BaseObjectStore):
+class MemoryObjectStore(PackCapableObjectStore):
     """Object store that keeps all objects in memory."""
 
     def __init__(self) -> None:

+ 44 - 16
dulwich/objects.py

@@ -167,23 +167,50 @@ def hex_to_filename(path: PathT, hex: Union[str, bytes]) -> PathT:
     # os.path.join accepts bytes or unicode, but all args must be of the same
     # type. Make sure that hex which is expected to be bytes, is the same type
     # as path.
-    if type(path) is not type(hex) and isinstance(path, str):
-        hex = hex.decode("ascii")  # type: ignore
-    dir_name = hex[:2]
-    file_name = hex[2:]
-    # Check from object dir
-    return os.path.join(path, dir_name, file_name)  # type: ignore
+    if isinstance(path, str):
+        if isinstance(hex, bytes):
+            hex_str = hex.decode("ascii")
+        else:
+            hex_str = hex
+        dir_name = hex_str[:2]
+        file_name = hex_str[2:]
+        result = os.path.join(path, dir_name, file_name)
+        assert isinstance(result, str)
+        return result
+    else:
+        # path is bytes
+        if isinstance(hex, str):
+            hex_bytes = hex.encode("ascii")
+        else:
+            hex_bytes = hex
+        dir_name_b = hex_bytes[:2]
+        file_name_b = hex_bytes[2:]
+        result_b = os.path.join(path, dir_name_b, file_name_b)
+        assert isinstance(result_b, bytes)
+        return result_b
 
 
 def filename_to_hex(filename: Union[str, bytes]) -> str:
     """Takes an object filename and returns its corresponding hex sha."""
     # grab the last (up to) two path components
-    names = filename.rsplit(os.path.sep, 2)[-2:]  # type: ignore
     errmsg = f"Invalid object filename: {filename!r}"
-    assert len(names) == 2, errmsg
-    base, rest = names
-    assert len(base) == 2 and len(rest) == 38, errmsg
-    hex_bytes = (base + rest).encode("ascii")  # type: ignore
+    if isinstance(filename, str):
+        names = filename.rsplit(os.path.sep, 2)[-2:]
+        assert len(names) == 2, errmsg
+        base, rest = names
+        assert len(base) == 2 and len(rest) == 38, errmsg
+        hex_str = base + rest
+        hex_bytes = hex_str.encode("ascii")
+    else:
+        # filename is bytes
+        sep = (
+            os.path.sep.encode("ascii") if isinstance(os.path.sep, str) else os.path.sep
+        )
+        names_b = filename.rsplit(sep, 2)[-2:]
+        assert len(names_b) == 2, errmsg
+        base_b, rest_b = names_b
+        assert len(base_b) == 2 and len(rest_b) == 38, errmsg
+        hex_bytes = base_b + rest_b
     hex_to_sha(hex_bytes)
     return hex_bytes.decode("ascii")
 
@@ -761,7 +788,8 @@ class Blob(ShaFile):
         if not chunks:
             return []
         if len(chunks) == 1:
-            return chunks[0].splitlines(True)  # type: ignore[no-any-return]
+            result: list[bytes] = chunks[0].splitlines(True)
+            return result
         remaining = None
         ret = []
         for chunk in chunks:
@@ -1154,9 +1182,9 @@ class Tag(ShaFile):
 class TreeEntry(NamedTuple):
     """Named tuple encapsulating a single tree entry."""
 
-    path: Optional[bytes]
-    mode: Optional[int]
-    sha: Optional[bytes]
+    path: bytes
+    mode: int
+    sha: bytes
 
     def in_path(self, path: bytes) -> "TreeEntry":
         """Return a copy of this entry with the given path prepended."""
@@ -1432,7 +1460,7 @@ class Tree(ShaFile):
             last = entry
 
     def _serialize(self) -> list[bytes]:
-        return list(serialize_tree(self.iteritems()))  # type: ignore[arg-type]
+        return list(serialize_tree(self.iteritems()))
 
     def as_pretty_string(self) -> str:
         """Return a human-readable string representation of this tree.

+ 36 - 27
dulwich/porcelain.py

@@ -1051,6 +1051,8 @@ def clean(
         # Reverse file visit order, so that files and subdirectories are
         # removed before containing directory
         for ap, is_dir in reversed(list(paths_in_wd)):
+            # target_dir and r.path are both str, so ap must be str
+            assert isinstance(ap, str)
             if is_dir:
                 # All subdirectories and files have been removed if untracked,
                 # so dir contains no tracked files iff it is empty.
@@ -1061,7 +1063,7 @@ def clean(
                 ip = path_to_tree_path(r.path, ap)
                 is_tracked = ip in index
 
-                rp = os.path.relpath(ap, r.path)  # type: ignore[arg-type]
+                rp = os.path.relpath(ap, r.path)
                 is_ignored = ignore_manager.is_ignored(rp)
 
                 if not is_tracked and not is_ignored:
@@ -2879,7 +2881,11 @@ def get_untracked_paths(
     if untracked_files == "no":
         return
 
-    with open_repo_closing(basepath) as r:
+    # Normalize paths to str
+    frompath_str = os.fsdecode(os.fspath(frompath))
+    basepath_str = os.fsdecode(os.fspath(basepath))
+
+    with open_repo_closing(basepath_str) as r:
         ignore_manager = IgnoreFilterManager.from_repo(r)
 
     ignored_dirs = []
@@ -2907,13 +2913,13 @@ def get_untracked_paths(
     def prune_dirnames(dirpath: str, dirnames: list[str]) -> list[str]:
         for i in range(len(dirnames) - 1, -1, -1):
             path = os.path.join(dirpath, dirnames[i])
-            ip = os.path.join(os.path.relpath(path, basepath), "")  # type: ignore[arg-type]
+            ip = os.path.join(os.path.relpath(path, basepath_str), "")
 
             # Check if directory is ignored
             if ignore_manager.is_ignored(ip) is True:
                 if not exclude_ignored:
                     ignored_dirs.append(
-                        os.path.join(os.path.relpath(path, frompath), "")  # type: ignore[arg-type]
+                        os.path.join(os.path.relpath(path, frompath_str), "")
                     )
                 del dirnames[i]
                 continue
@@ -2921,7 +2927,7 @@ def get_untracked_paths(
             # For "normal" mode, check if the directory is entirely untracked
             if untracked_files == "normal":
                 # Convert directory path to tree path for index lookup
-                dir_tree_path = path_to_tree_path(basepath, path)
+                dir_tree_path = path_to_tree_path(basepath_str, path)
 
                 # Check if any file in this directory is tracked
                 dir_prefix = dir_tree_path + b"/" if dir_tree_path else b""
@@ -2929,8 +2935,10 @@ def get_untracked_paths(
 
                 if not has_tracked_files:
                     # This directory is entirely untracked
-                    rel_path_base = os.path.relpath(path, basepath)  # type: ignore[arg-type]
-                    rel_path_from = os.path.join(os.path.relpath(path, frompath), "")  # type: ignore[arg-type]
+                    rel_path_base = os.path.relpath(path, basepath_str)
+                    rel_path_from = os.path.join(
+                        os.path.relpath(path, frompath_str), ""
+                    )
 
                     # If excluding ignored, check if directory contains any non-ignored files
                     if exclude_ignored:
@@ -2950,39 +2958,43 @@ def get_untracked_paths(
     # For "all" mode, use the original behavior
     if untracked_files == "all":
         for ap, is_dir in _walk_working_dir_paths(
-            frompath, basepath, prune_dirnames=prune_dirnames
+            frompath_str, basepath_str, prune_dirnames=prune_dirnames
         ):
+            # frompath_str and basepath_str are both str, so ap must be str
+            assert isinstance(ap, str)
             if not is_dir:
-                ip = path_to_tree_path(basepath, ap)
+                ip = path_to_tree_path(basepath_str, ap)
                 if ip not in index:
                     if not exclude_ignored or not ignore_manager.is_ignored(
-                        os.path.relpath(ap, basepath)  # type: ignore[arg-type]
+                        os.path.relpath(ap, basepath_str)
                     ):
-                        yield os.path.relpath(ap, frompath)  # type: ignore[arg-type]
+                        yield os.path.relpath(ap, frompath_str)
     else:  # "normal" mode
         # Walk directories, handling both files and directories
         for ap, is_dir in _walk_working_dir_paths(
-            frompath, basepath, prune_dirnames=prune_dirnames
+            frompath_str, basepath_str, prune_dirnames=prune_dirnames
         ):
+            # frompath_str and basepath_str are both str, so ap must be str
+            assert isinstance(ap, str)
             # This part won't be reached for pruned directories
             if is_dir:
                 # Check if this directory is entirely untracked
-                dir_tree_path = path_to_tree_path(basepath, ap)
+                dir_tree_path = path_to_tree_path(basepath_str, ap)
                 dir_prefix = dir_tree_path + b"/" if dir_tree_path else b""
                 has_tracked_files = any(name.startswith(dir_prefix) for name in index)
                 if not has_tracked_files:
                     if not exclude_ignored or not ignore_manager.is_ignored(
-                        os.path.relpath(ap, basepath)  # type: ignore[arg-type]
+                        os.path.relpath(ap, basepath_str)
                     ):
-                        yield os.path.join(os.path.relpath(ap, frompath), "")  # type: ignore[arg-type]
+                        yield os.path.join(os.path.relpath(ap, frompath_str), "")
             else:
                 # Check individual files in directories that contain tracked files
-                ip = path_to_tree_path(basepath, ap)
+                ip = path_to_tree_path(basepath_str, ap)
                 if ip not in index:
                     if not exclude_ignored or not ignore_manager.is_ignored(
-                        os.path.relpath(ap, basepath)  # type: ignore[arg-type]
+                        os.path.relpath(ap, basepath_str)
                     ):
-                        yield os.path.relpath(ap, frompath)  # type: ignore[arg-type]
+                        yield os.path.relpath(ap, frompath_str)
 
         # Yield any untracked directories found during pruning
         yield from untracked_dir_list
@@ -3939,26 +3951,23 @@ def check_ignore(
         ignore_manager = IgnoreFilterManager.from_repo(r)
         for original_path in paths:
             # Convert path to string for consistent handling
-            original_path_str = os.fspath(original_path)
+            original_path_fspath = os.fspath(original_path)
+            # Normalize to str
+            original_path_str = os.fsdecode(original_path_fspath)
 
             if not no_index and path_to_tree_path(r.path, original_path_str) in index:
                 continue
 
             # Preserve whether the original path had a trailing slash
-            if isinstance(original_path_str, bytes):
-                had_trailing_slash = original_path_str.endswith(
-                    (b"/", os.path.sep.encode())
-                )
-            else:
-                had_trailing_slash = original_path_str.endswith(("/", os.path.sep))
+            had_trailing_slash = original_path_str.endswith(("/", os.path.sep))
 
             if os.path.isabs(original_path_str):
-                path = os.path.relpath(original_path_str, r.path)  # type: ignore[arg-type]
+                path = os.path.relpath(original_path_str, r.path)
                 # Normalize Windows paths to use forward slashes
                 if os.path.sep != "/":
                     path = path.replace(os.path.sep, "/")
             else:
-                path = original_path_str  # type: ignore[assignment]
+                path = original_path_str
 
             # Restore trailing slash if it was in the original
             if had_trailing_slash and not path.endswith("/"):

+ 13 - 21
dulwich/repo.py

@@ -87,6 +87,7 @@ from .object_store import (
     MissingObjectFinder,
     ObjectStoreGraphWalker,
     PackBasedObjectStore,
+    PackCapableObjectStore,
     find_shallow,
     peel_sha,
 )
@@ -407,7 +408,9 @@ class BaseRepo:
         repository
     """
 
-    def __init__(self, object_store: PackBasedObjectStore, refs: RefsContainer) -> None:
+    def __init__(
+        self, object_store: "PackCapableObjectStore", refs: RefsContainer
+    ) -> None:
         """Open a repository.
 
         This shouldn't be called directly, but rather through one of the
@@ -608,7 +611,8 @@ class BaseRepo:
             if hasattr(graph_walker, "shallow"):
                 graph_walker.shallow.update(shallow - not_shallow)
                 new_shallow = graph_walker.shallow - current_shallow
-                unshallow = graph_walker.unshallow = not_shallow & current_shallow  # type: ignore[attr-defined]
+                unshallow = not_shallow & current_shallow
+                setattr(graph_walker, "unshallow", unshallow)
                 if hasattr(graph_walker, "update_shallow"):
                     graph_walker.update_shallow(new_shallow, unshallow)
         else:
@@ -622,24 +626,12 @@ class BaseRepo:
                 # Do not send a pack in shallow short-circuit path
                 return None
 
-            class DummyMissingObjectFinder:
-                """Dummy finder that returns no missing objects."""
-
-                def get_remote_has(self) -> None:
-                    """Get remote has (always returns None).
-
-                    Returns:
-                      None
-                    """
-                    return None
-
-                def __len__(self) -> int:
-                    return 0
-
-                def __iter__(self) -> Iterator[tuple[bytes, Optional[bytes]]]:
-                    yield from []
-
-            return DummyMissingObjectFinder()  # type: ignore
+            # Return an actual MissingObjectFinder with empty wants
+            return MissingObjectFinder(
+                self.object_store,
+                haves=[],
+                wants=[],
+            )
 
         # If the graph walker is set up with an implementation that can
         # ACK/NAK to the wire, it will write data to the client through
@@ -2228,7 +2220,7 @@ class MemoryRepo(BaseRepo):
 
         self._reflog: list[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
-        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore[arg-type]
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)
         self._named_files: dict[str, bytes] = {}
         self.bare = True
         self._config = ConfigFile()