Jelmer Vernooij, 1 month ago
Commit 5f0660904f
1 changed file with 127 additions and 66 deletions:
      dulwich/index.py

dulwich/index.py

@@ -25,17 +25,24 @@ import os
 import stat
 import struct
 import sys
-from collections.abc import Iterable, Iterator
+import types
+from collections.abc import Generator, Iterable, Iterator
 from dataclasses import dataclass
 from enum import Enum
 from typing import (
+    TYPE_CHECKING,
     Any,
     BinaryIO,
     Callable,
     Optional,
     Union,
+    cast,
 )
 
+if TYPE_CHECKING:
+    from .file import _GitFile
+    from .repo import BaseRepo
+
 from .file import GitFile
 from .object_store import iter_tree_contents
 from .objects import (
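
A note on the TYPE_CHECKING block introduced above: it lets annotations reference _GitFile and BaseRepo without importing them at runtime, which avoids circular imports between modules. A minimal sketch of the pattern (repo_workdir is an illustrative helper, not part of this change):

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by the type checker, never at runtime, so no import cycle.
    from .repo import BaseRepo


def repo_workdir(repo: "BaseRepo") -> str:
    # "BaseRepo" is a forward reference (a string); it does not need to be
    # importable when this module is loaded.
    return repo.path if isinstance(repo.path, str) else repo.path.decode()
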
@@ -194,7 +201,9 @@ def _decompress_path(
     return path, new_offset
 
 
-def _decompress_path_from_stream(f, previous_path: bytes) -> tuple[bytes, int]:
+def _decompress_path_from_stream(
+    f: BinaryIO, previous_path: bytes
+) -> tuple[bytes, int]:
     """Decompress a path from index version 4 compressed format, reading from stream.
 
     Args:
@@ -459,12 +468,12 @@ def pathsplit(path: bytes) -> tuple[bytes, bytes]:
         return (dirname, basename)
 
 
-def pathjoin(*args):
+def pathjoin(*args: bytes) -> bytes:
     """Join a /-delimited path."""
     return b"/".join([p for p in args if p])
 
 
-def read_cache_time(f):
+def read_cache_time(f: BinaryIO) -> tuple[int, int]:
     """Read a cache time.
 
     Args:
@@ -475,7 +484,7 @@ def read_cache_time(f):
     return struct.unpack(">LL", f.read(8))
 
 
-def write_cache_time(f, t) -> None:
+def write_cache_time(f: BinaryIO, t: Union[int, float, tuple[int, int]]) -> None:
     """Write a cache time.
 
     Args:
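
The Union annotation on write_cache_time reflects the three forms a cache time may arrive in: int seconds, float seconds, or a (seconds, nanoseconds) pair. A rough sketch of how such a value can be normalized before packing (pack_cache_time is illustrative, not dulwich's implementation):

import struct
from typing import BinaryIO, Union


def pack_cache_time(f: BinaryIO, t: Union[int, float, tuple[int, int]]) -> None:
    # Normalize all accepted forms to (seconds, nanoseconds) before packing.
    if isinstance(t, tuple):
        secs, nsecs = t
    elif isinstance(t, float):
        secs = int(t)
        nsecs = int((t - secs) * 1_000_000_000)
    else:
        secs, nsecs = t, 0
    f.write(struct.pack(">LL", secs, nsecs))
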
@@ -493,7 +502,7 @@ def write_cache_time(f, t) -> None:
 
 
 def read_cache_entry(
-    f, version: int, previous_path: bytes = b""
+    f: BinaryIO, version: int, previous_path: bytes = b""
 ) -> SerializedIndexEntry:
     """Read an entry from a cache file.
 
@@ -551,7 +560,7 @@ def read_cache_entry(
 
 
 def write_cache_entry(
-    f, entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
+    f: BinaryIO, entry: SerializedIndexEntry, version: int, previous_path: bytes = b""
 ) -> None:
     """Write an index entry to a file.
 
@@ -608,7 +617,7 @@ def write_cache_entry(
 class UnsupportedIndexFormat(Exception):
     """An unsupported index format was encountered."""
 
-    def __init__(self, version) -> None:
+    def __init__(self, version: int) -> None:
         self.index_format_version = version
 
 
@@ -682,7 +691,9 @@ def read_index_dict_with_version(
     return ret, version
 
 
-def read_index_dict(f) -> dict[bytes, Union[IndexEntry, ConflictedIndexEntry]]:
+def read_index_dict(
+    f: BinaryIO,
+) -> dict[bytes, Union[IndexEntry, ConflictedIndexEntry]]:
     """Read an index file and return it as a dictionary.
        Dict Key is tuple of path and stage number, as
             path alone is not unique
@@ -799,7 +810,7 @@ class Index:
     def __init__(
         self,
         filename: Union[bytes, str, os.PathLike],
-        read=True,
+        read: bool = True,
         skip_hash: bool = False,
         version: Optional[int] = None,
     ) -> None:
@@ -820,7 +831,7 @@ class Index:
             self.read()
 
     @property
-    def path(self):
+    def path(self) -> Union[bytes, str]:
         return self._filename
 
     def __repr__(self) -> str:
@@ -828,18 +839,22 @@ class Index:
 
     def write(self) -> None:
         """Write current contents of index to disk."""
+        from typing import BinaryIO, cast
+
         f = GitFile(self._filename, "wb")
         try:
             if self._skip_hash:
                 # When skipHash is enabled, write the index without computing SHA1
-                write_index_dict(f, self._byname, version=self._version)
+                write_index_dict(cast(BinaryIO, f), self._byname, version=self._version)
                 # Write 20 zero bytes instead of SHA1
                 f.write(b"\x00" * 20)
                 f.close()
             else:
-                f = SHA1Writer(f)
-                write_index_dict(f, self._byname, version=self._version)
-                f.close()
+                sha1_writer = SHA1Writer(cast(BinaryIO, f))
+                write_index_dict(
+                    cast(BinaryIO, sha1_writer), self._byname, version=self._version
+                )
+                sha1_writer.close()
         except:
             f.close()
             raise
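
The cast(BinaryIO, ...) calls above are purely for the type checker; at runtime cast returns its argument unchanged. A self-contained sketch of the same pattern, with HashingWriter as a toy stand-in for SHA1Writer:

import hashlib
from typing import BinaryIO, cast


class HashingWriter:
    # Forwards writes to the underlying file while accumulating a SHA-1,
    # loosely mirroring what a SHA1Writer-style wrapper does.
    def __init__(self, f: BinaryIO) -> None:
        self._f = f
        self._sha = hashlib.sha1()

    def write(self, data: bytes) -> int:
        self._sha.update(data)
        return self._f.write(data)

    def close(self) -> None:
        self._f.write(self._sha.digest())
        self._f.close()


with open("index.tmp", "wb") as raw:
    # cast() has no runtime effect; it only tells the checker to treat the
    # wrapper as a BinaryIO so functions annotated with BinaryIO accept it.
    writer = cast(BinaryIO, HashingWriter(raw))
    writer.write(b"DIRC")
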
@@ -850,15 +865,15 @@ class Index:
             return
         f = GitFile(self._filename, "rb")
         try:
-            f = SHA1Reader(f)
-            entries, version = read_index_dict_with_version(f)
+            sha1_reader = SHA1Reader(f)
+            entries, version = read_index_dict_with_version(cast(BinaryIO, sha1_reader))
             self._version = version
             self.update(entries)
             # Read any remaining data before the SHA
-            remaining = os.path.getsize(self._filename) - f.tell() - 20
+            remaining = os.path.getsize(self._filename) - sha1_reader.tell() - 20
             if remaining > 0:
-                f.read(remaining)
-            f.check_sha(allow_empty=True)
+                sha1_reader.read(remaining)
+            sha1_reader.check_sha(allow_empty=True)
         finally:
             f.close()
 
@@ -878,7 +893,7 @@ class Index:
         """Iterate over the paths and stages in this index."""
         return iter(self._byname)
 
-    def __contains__(self, key) -> bool:
+    def __contains__(self, key: bytes) -> bool:
         return key in self._byname
 
     def get_sha1(self, path: bytes) -> bytes:
@@ -936,12 +951,23 @@ class Index:
         for key, value in entries.items():
             self[key] = value
 
-    def paths(self):
+    def paths(self) -> Generator[bytes, None, None]:
         yield from self._byname.keys()
 
     def changes_from_tree(
-        self, object_store, tree: ObjectID, want_unchanged: bool = False
-    ):
+        self,
+        object_store: ObjectContainer,
+        tree: ObjectID,
+        want_unchanged: bool = False,
+    ) -> Generator[
+        tuple[
+            tuple[Optional[bytes], Optional[bytes]],
+            tuple[Optional[int], Optional[int]],
+            tuple[Optional[bytes], Optional[bytes]],
+        ],
+        None,
+        None,
+    ]:
         """Find the differences between the contents of this index and a tree.
 
         Args:
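
The new Generator[...] return annotation on changes_from_tree spells out the yielded 3-tuple of (paths, modes, shas) pairs. A small sketch of how that shape reads in isolation (TreeChange and fake_changes are illustrative names only):

from collections.abc import Generator
from typing import Optional

TreeChange = tuple[
    tuple[Optional[bytes], Optional[bytes]],  # (oldpath, newpath)
    tuple[Optional[int], Optional[int]],      # (oldmode, newmode)
    tuple[Optional[bytes], Optional[bytes]],  # (oldsha, newsha)
]


def fake_changes() -> Generator[TreeChange, None, None]:
    # Generator[YieldType, SendType, ReturnType]; only the yield type matters here.
    yield ((b"a.txt", b"a.txt"), (0o100644, 0o100755), (b"0" * 40, b"1" * 40))
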
@@ -952,9 +978,13 @@ class Index:
             newmode), (oldsha, newsha)
         """
 
-        def lookup_entry(path):
+        def lookup_entry(path: bytes) -> tuple[bytes, int]:
             entry = self[path]
-            return entry.sha, cleanup_mode(entry.mode)
+            if hasattr(entry, "sha") and hasattr(entry, "mode"):
+                return entry.sha, cleanup_mode(entry.mode)
+            else:
+                # Handle ConflictedIndexEntry case
+                return b"", 0
 
         yield from changes_from_tree(
             self.paths(),
@@ -964,7 +994,7 @@ class Index:
             want_unchanged=want_unchanged,
         )
 
-    def commit(self, object_store):
+    def commit(self, object_store: ObjectContainer) -> bytes:
         """Create a new tree from an index.
 
         Args:
@@ -988,13 +1018,13 @@ def commit_tree(
     """
     trees: dict[bytes, Any] = {b"": {}}
 
-    def add_tree(path):
+    def add_tree(path: bytes) -> dict[bytes, Any]:
         if path in trees:
             return trees[path]
         dirname, basename = pathsplit(path)
         t = add_tree(dirname)
         assert isinstance(basename, bytes)
-        newtree = {}
+        newtree: dict[bytes, Any] = {}
         t[basename] = newtree
         trees[path] = newtree
         return newtree
@@ -1004,7 +1034,7 @@ def commit_tree(
         tree = add_tree(tree_path)
         tree[basename] = (mode, sha)
 
-    def build_tree(path):
+    def build_tree(path: bytes) -> bytes:
         tree = Tree()
         for basename, entry in trees[path].items():
             if isinstance(entry, dict):
@@ -1036,7 +1066,7 @@ def changes_from_tree(
     lookup_entry: Callable[[bytes], tuple[bytes, int]],
     object_store: ObjectContainer,
     tree: Optional[bytes],
-    want_unchanged=False,
+    want_unchanged: bool = False,
 ) -> Iterable[
     tuple[
         tuple[Optional[bytes], Optional[bytes]],
@@ -1082,10 +1112,10 @@ def changes_from_tree(
 
 
 def index_entry_from_stat(
-    stat_val,
+    stat_val: os.stat_result,
     hex_sha: bytes,
     mode: Optional[int] = None,
-):
+) -> IndexEntry:
     """Create a new index entry from a stat value.
 
     Args:
@@ -1118,14 +1148,20 @@ if sys.platform == "win32":
     # https://github.com/jelmer/dulwich/issues/1005
 
     class WindowsSymlinkPermissionError(PermissionError):
-        def __init__(self, errno, msg, filename) -> None:
+        def __init__(self, errno: int, msg: str, filename: Optional[str]) -> None:
             super(PermissionError, self).__init__(
                 errno,
                 f"Unable to create symlink; do you have developer mode enabled? {msg}",
                 filename,
             )
 
-    def symlink(src, dst, target_is_directory=False, *, dir_fd=None):
+    def symlink(
+        src: Union[str, bytes],
+        dst: Union[str, bytes],
+        target_is_directory: bool = False,
+        *,
+        dir_fd: Optional[int] = None,
+    ) -> None:
         try:
             return os.symlink(
                 src, dst, target_is_directory=target_is_directory, dir_fd=dir_fd
@@ -1141,10 +1177,10 @@ def build_file_from_blob(
     mode: int,
     target_path: bytes,
     *,
-    honor_filemode=True,
-    tree_encoding="utf-8",
-    symlink_fn=None,
-):
+    honor_filemode: bool = True,
+    tree_encoding: str = "utf-8",
+    symlink_fn: Optional[Callable] = None,
+) -> os.stat_result:
     """Build a file or symlink on disk based on a Git object.
 
     Args:
@@ -1166,8 +1202,8 @@ def build_file_from_blob(
             os.unlink(target_path)
         if sys.platform == "win32":
             # os.readlink on Python3 on Windows requires a unicode string.
-            contents = contents.decode(tree_encoding)  # type: ignore
-            target_path = target_path.decode(tree_encoding)  # type: ignore
+            contents = contents.decode(tree_encoding)
+            target_path = target_path.decode(tree_encoding)
         (symlink_fn or symlink)(contents, target_path)
     else:
         if oldstat is not None and oldstat.st_size == len(contents):
@@ -1201,7 +1237,10 @@ def validate_path_element_ntfs(element: bytes) -> bool:
     return True
 
 
-def validate_path(path: bytes, element_validator=validate_path_element_default) -> bool:
+def validate_path(
+    path: bytes,
+    element_validator: Callable[[bytes], bool] = validate_path_element_default,
+) -> bool:
     """Default path validator that just checks for .git/."""
     parts = path.split(b"/")
     for p in parts:
@@ -1217,8 +1256,8 @@ def build_index_from_tree(
     object_store: ObjectContainer,
     tree_id: bytes,
     honor_filemode: bool = True,
-    validate_path_element=validate_path_element_default,
-    symlink_fn=None,
+    validate_path_element: Callable[[bytes], bool] = validate_path_element_default,
+    symlink_fn: Optional[Callable] = None,
 ) -> None:
     """Generate and materialize index from a tree.
 
@@ -1289,7 +1328,9 @@ def build_index_from_tree(
     index.write()
 
 
-def blob_from_path_and_mode(fs_path: bytes, mode: int, tree_encoding="utf-8"):
+def blob_from_path_and_mode(
+    fs_path: bytes, mode: int, tree_encoding: str = "utf-8"
+) -> Blob:
     """Create a blob from a path and a stat object.
 
     Args:
@@ -1311,7 +1352,9 @@ def blob_from_path_and_mode(fs_path: bytes, mode: int, tree_encoding="utf-8"):
     return blob
 
 
-def blob_from_path_and_stat(fs_path: bytes, st, tree_encoding="utf-8"):
+def blob_from_path_and_stat(
+    fs_path: bytes, st: os.stat_result, tree_encoding: str = "utf-8"
+) -> Blob:
     """Create a blob from a path and a stat object.
 
     Args:
@@ -1346,7 +1389,7 @@ def read_submodule_head(path: Union[str, bytes]) -> Optional[bytes]:
         return None
 
 
-def _has_directory_changed(tree_path: bytes, entry) -> bool:
+def _has_directory_changed(tree_path: bytes, entry: IndexEntry) -> bool:
     """Check if a directory has changed after getting an error.
 
     When handling an error trying to create a blob from a path, call this
@@ -1372,14 +1415,14 @@ def _has_directory_changed(tree_path: bytes, entry) -> bool:
 
 
 def update_working_tree(
-    repo,
-    old_tree_id,
-    new_tree_id,
-    honor_filemode=True,
-    validate_path_element=None,
-    symlink_fn=None,
-    force_remove_untracked=False,
-):
+    repo: "BaseRepo",
+    old_tree_id: Optional[bytes],
+    new_tree_id: bytes,
+    honor_filemode: bool = True,
+    validate_path_element: Optional[Callable[[bytes], bool]] = None,
+    symlink_fn: Optional[Callable] = None,
+    force_remove_untracked: bool = False,
+) -> None:
     """Update the working tree and index to match a new tree.
 
     This function handles:
@@ -1415,6 +1458,8 @@ def update_working_tree(
     handled_paths = set()
 
     # Get repo path as string for comparisons
+    if not hasattr(repo, "path"):
+        raise ValueError("Repository must have a path attribute")
     repo_path_str = repo.path if isinstance(repo.path, str) else repo.path.decode()
 
     # First, update/add all files in the new tree
@@ -1433,7 +1478,9 @@ def update_working_tree(
         full_path = os.path.join(repo_path_str, entry.path.decode())
 
         # Get the blob
-        blob = repo.object_store[entry.sha]
+        blob_obj = repo.object_store[entry.sha]
+        if not isinstance(blob_obj, Blob):
+            raise ValueError(f"Object {entry.sha!r} is not a blob")
 
         # Ensure parent directory exists
         parent_dir = os.path.dirname(full_path)
@@ -1442,7 +1489,7 @@ def update_working_tree(
 
         # Write the file
         st = build_file_from_blob(
-            blob,
+            blob_obj,
             entry.mode,
             full_path.encode(),
             honor_filemode=honor_filemode,
@@ -1523,8 +1570,10 @@ def update_working_tree(
 
 
 def get_unstaged_changes(
-    index: Index, root_path: Union[str, bytes], filter_blob_callback=None
-):
+    index: Index,
+    root_path: Union[str, bytes],
+    filter_blob_callback: Optional[Callable] = None,
+) -> Generator[bytes, None, None]:
     """Walk through an index and check for differences against working tree.
 
     Args:
@@ -1569,7 +1618,7 @@ def get_unstaged_changes(
 os_sep_bytes = os.sep.encode("ascii")
 
 
-def _tree_to_fs_path(root_path: bytes, tree_path: bytes):
+def _tree_to_fs_path(root_path: bytes, tree_path: bytes) -> bytes:
     """Convert a git tree path to a file system path.
 
     Args:
@@ -1605,7 +1654,7 @@ def _fs_to_tree_path(fs_path: Union[str, bytes]) -> bytes:
     return tree_path
 
 
-def index_entry_from_directory(st, path: bytes) -> Optional[IndexEntry]:
+def index_entry_from_directory(st: os.stat_result, path: bytes) -> Optional[IndexEntry]:
     if os.path.exists(os.path.join(path, b".git")):
         head = read_submodule_head(path)
         if head is None:
@@ -1666,7 +1715,10 @@ def iter_fresh_entries(
 
 
 def iter_fresh_objects(
-    paths: Iterable[bytes], root_path: bytes, include_deleted=False, object_store=None
+    paths: Iterable[bytes],
+    root_path: bytes,
+    include_deleted: bool = False,
+    object_store: Optional[ObjectContainer] = None,
 ) -> Iterator[tuple[bytes, Optional[bytes], Optional[int]]]:
     """Iterate over versions of objects on disk referenced by index.
 
@@ -1705,21 +1757,30 @@ class locked_index:
     Works as a context manager.
     """
 
+    _file: "_GitFile"
+
     def __init__(self, path: Union[bytes, str]) -> None:
         self._path = path
 
-    def __enter__(self):
+    def __enter__(self) -> Index:
         self._file = GitFile(self._path, "wb")
         self._index = Index(self._path)
         return self._index
 
-    def __exit__(self, exc_type, exc_value, traceback):
+    def __exit__(
+        self,
+        exc_type: Optional[type],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
         if exc_type is not None:
             self._file.abort()
             return
         try:
-            f = SHA1Writer(self._file)
-            write_index_dict(f, self._index._byname)
+            from typing import BinaryIO, cast
+
+            f = SHA1Writer(cast(BinaryIO, self._file))
+            write_index_dict(cast(BinaryIO, f), self._index._byname)
         except BaseException:
             self._file.abort()
         else: