Bläddra i källkod

diff_tree: Add typing

Jelmer Vernooij 1 månad sedan
förälder
incheckning
80996bd004
1 ändrade filer med 98 tillägg och 46 borttagningar
  1. 98 46
      dulwich/diff_tree.py

+ 98 - 46
dulwich/diff_tree.py

@@ -23,9 +23,10 @@
 
 import stat
 from collections import defaultdict, namedtuple
+from collections.abc import Iterator
 from io import BytesIO
 from itertools import chain
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar
 
 from .object_store import BaseObjectStore
 from .objects import S_ISGITLINK, ObjectID, ShaFile, Tree, TreeEntry
@@ -52,11 +53,11 @@ class TreeChange(namedtuple("TreeChange", ["type", "old", "new"])):
     """Named tuple a single change between two trees."""
 
     @classmethod
-    def add(cls, new):
+    def add(cls, new: TreeEntry) -> "TreeChange":
         return cls(CHANGE_ADD, _NULL_ENTRY, new)
 
     @classmethod
-    def delete(cls, old):
+    def delete(cls, old: TreeEntry) -> "TreeChange":
         return cls(CHANGE_DELETE, old, _NULL_ENTRY)
 
 
@@ -112,14 +113,19 @@ def _merge_entries(
     return result
 
 
-def _is_tree(entry):
+def _is_tree(entry: TreeEntry) -> bool:
     mode = entry.mode
     if mode is None:
         return False
     return stat.S_ISDIR(mode)
 
 
-def walk_trees(store, tree1_id, tree2_id, prune_identical=False):
+def walk_trees(
+    store: BaseObjectStore,
+    tree1_id: Optional[ObjectID],
+    tree2_id: Optional[ObjectID],
+    prune_identical: bool = False,
+) -> Iterator[tuple[TreeEntry, TreeEntry]]:
     """Recursively walk all the entries of two trees.
 
     Iteration is depth-first pre-order, as in e.g. os.walk.
@@ -152,25 +158,38 @@ def walk_trees(store, tree1_id, tree2_id, prune_identical=False):
         tree1 = (is_tree1 and store[entry1.sha]) or None
         tree2 = (is_tree2 and store[entry2.sha]) or None
         path = entry1.path or entry2.path
-        todo.extend(reversed(_merge_entries(path, tree1, tree2)))
+
+        # Ensure trees are Tree objects before merging
+        if tree1 is not None and not isinstance(tree1, Tree):
+            tree1 = None
+        if tree2 is not None and not isinstance(tree2, Tree):
+            tree2 = None
+
+        if tree1 is not None or tree2 is not None:
+            # Use empty trees for None values
+            if tree1 is None:
+                tree1 = Tree()
+            if tree2 is None:
+                tree2 = Tree()
+            todo.extend(reversed(_merge_entries(path, tree1, tree2)))
         yield entry1, entry2
 
 
-def _skip_tree(entry, include_trees):
+def _skip_tree(entry: TreeEntry, include_trees: bool) -> TreeEntry:
     if entry.mode is None or (not include_trees and stat.S_ISDIR(entry.mode)):
         return _NULL_ENTRY
     return entry
 
 
 def tree_changes(
-    store,
-    tree1_id,
-    tree2_id,
-    want_unchanged=False,
-    rename_detector=None,
-    include_trees=False,
-    change_type_same=False,
-):
+    store: BaseObjectStore,
+    tree1_id: Optional[ObjectID],
+    tree2_id: Optional[ObjectID],
+    want_unchanged: bool = False,
+    rename_detector: Optional["RenameDetector"] = None,
+    include_trees: bool = False,
+    change_type_same: bool = False,
+) -> Iterator[TreeChange]:
     """Find the differences between the contents of two trees.
 
     Args:
@@ -231,14 +250,18 @@ def tree_changes(
         yield TreeChange(change_type, entry1, entry2)
 
 
-def _all_eq(seq, key, value) -> bool:
+T = TypeVar("T")
+U = TypeVar("U")
+
+
+def _all_eq(seq: list[T], key: Callable[[T], U], value: U) -> bool:
     for e in seq:
         if key(e) != value:
             return False
     return True
 
 
-def _all_same(seq, key):
+def _all_same(seq: list[Any], key: Callable[[Any], Any]) -> bool:
     return _all_eq(seq[1:], key, key(seq[0]))
 
 
@@ -246,8 +269,8 @@ def tree_changes_for_merge(
     store: BaseObjectStore,
     parent_tree_ids: list[ObjectID],
     tree_id: ObjectID,
-    rename_detector=None,
-):
+    rename_detector: Optional["RenameDetector"] = None,
+) -> Iterator[list[Optional[TreeChange]]]:
     """Get the tree changes for a merge tree relative to all its parents.
 
     Args:
@@ -286,10 +309,10 @@ def tree_changes_for_merge(
                 path = change.new.path
             changes_by_path[path][i] = change
 
-    def old_sha(c):
+    def old_sha(c: TreeChange) -> Optional[ObjectID]:
         return c.old.sha
 
-    def change_type(c):
+    def change_type(c: TreeChange) -> str:
         return c.type
 
     # Yield only conflicting changes.
@@ -348,7 +371,7 @@ def _count_blocks(obj: ShaFile) -> dict[int, int]:
     return block_counts
 
 
-def _common_bytes(blocks1, blocks2):
+def _common_bytes(blocks1: dict[int, int], blocks2: dict[int, int]) -> int:
     """Count the number of common bytes in two block count dicts.
 
     Args:
@@ -370,7 +393,11 @@ def _common_bytes(blocks1, blocks2):
     return score
 
 
-def _similarity_score(obj1, obj2, block_cache=None):
+def _similarity_score(
+    obj1: ShaFile,
+    obj2: ShaFile,
+    block_cache: Optional[dict[ObjectID, dict[int, int]]] = None,
+) -> int:
     """Compute a similarity score for two objects.
 
     Args:
@@ -398,7 +425,7 @@ def _similarity_score(obj1, obj2, block_cache=None):
     return int(float(common_bytes) * _MAX_SCORE / max_size)
 
 
-def _tree_change_key(entry):
+def _tree_change_key(entry: TreeChange) -> tuple[bytes, bytes]:
     # Sort by old path then new path. If only one exists, use it for both keys.
     path1 = entry.old.path
     path2 = entry.new.path
@@ -419,11 +446,11 @@ class RenameDetector:
 
     def __init__(
         self,
-        store,
-        rename_threshold=RENAME_THRESHOLD,
-        max_files=MAX_FILES,
-        rewrite_threshold=REWRITE_THRESHOLD,
-        find_copies_harder=False,
+        store: BaseObjectStore,
+        rename_threshold: int = RENAME_THRESHOLD,
+        max_files: Optional[int] = MAX_FILES,
+        rewrite_threshold: Optional[int] = REWRITE_THRESHOLD,
+        find_copies_harder: bool = False,
     ) -> None:
         """Initialize the rename detector.
 
@@ -454,7 +481,7 @@ class RenameDetector:
         self._deletes = []
         self._changes = []
 
-    def _should_split(self, change):
+    def _should_split(self, change: TreeChange) -> bool:
         if (
             self._rewrite_threshold is None
             or change.type != CHANGE_MODIFY
@@ -465,7 +492,7 @@ class RenameDetector:
         new_obj = self._store[change.new.sha]
         return _similarity_score(old_obj, new_obj) < self._rewrite_threshold
 
-    def _add_change(self, change) -> None:
+    def _add_change(self, change: TreeChange) -> None:
         if change.type == CHANGE_ADD:
             self._adds.append(change)
         elif change.type == CHANGE_DELETE:
@@ -484,7 +511,9 @@ class RenameDetector:
         else:
             self._changes.append(change)
 
-    def _collect_changes(self, tree1_id, tree2_id) -> None:
+    def _collect_changes(
+        self, tree1_id: Optional[ObjectID], tree2_id: Optional[ObjectID]
+    ) -> None:
         want_unchanged = self._find_copies_harder or self._want_unchanged
         for change in tree_changes(
             self._store,
@@ -495,7 +524,7 @@ class RenameDetector:
         ):
             self._add_change(change)
 
-    def _prune(self, add_paths, delete_paths) -> None:
+    def _prune(self, add_paths: set[bytes], delete_paths: set[bytes]) -> None:
         self._adds = [a for a in self._adds if a.new.path not in add_paths]
         self._deletes = [d for d in self._deletes if d.old.path not in delete_paths]
 
@@ -532,10 +561,14 @@ class RenameDetector:
                     self._changes.append(TreeChange(CHANGE_COPY, old, new))
         self._prune(add_paths, delete_paths)
 
-    def _should_find_content_renames(self):
+    def _should_find_content_renames(self) -> bool:
+        if self._max_files is None:
+            return True
         return len(self._adds) * len(self._deletes) <= self._max_files**2
 
-    def _rename_type(self, check_paths, delete, add):
+    def _rename_type(
+        self, check_paths: bool, delete: TreeChange, add: TreeChange
+    ) -> str:
         if check_paths and delete.old.path == add.new.path:
             # If the paths match, this must be a split modify, so make sure it
             # comes out as a modify.
@@ -618,7 +651,7 @@ class RenameDetector:
         self._deletes = [a for a in self._deletes if a.new.path not in modifies]
         self._changes += modifies.values()
 
-    def _sorted_changes(self):
+    def _sorted_changes(self) -> list[TreeChange]:
         result = []
         result.extend(self._adds)
         result.extend(self._deletes)
@@ -632,8 +665,12 @@ class RenameDetector:
         self._deletes = [d for d in self._deletes if d.type != CHANGE_UNCHANGED]
 
     def changes_with_renames(
-        self, tree1_id, tree2_id, want_unchanged=False, include_trees=False
-    ):
+        self,
+        tree1_id: Optional[ObjectID],
+        tree2_id: Optional[ObjectID],
+        want_unchanged: bool = False,
+        include_trees: bool = False,
+    ) -> list[TreeChange]:
         """Iterate TreeChanges between two tree SHAs, with rename detection."""
         self._reset()
         self._want_unchanged = want_unchanged
@@ -651,12 +688,27 @@ class RenameDetector:
 _is_tree_py = _is_tree
 _merge_entries_py = _merge_entries
 _count_blocks_py = _count_blocks
-try:
-    # Try to import Rust versions
-    from dulwich._diff_tree import (  # type: ignore
-        _count_blocks,
-        _is_tree,
-        _merge_entries,
-    )
-except ImportError:
+
+if TYPE_CHECKING:
+    # For type checking, use the Python implementations
     pass
+else:
+    # At runtime, try to import Rust extensions
+    try:
+        # Try to import Rust versions
+        from dulwich._diff_tree import (
+            _count_blocks as _rust_count_blocks,
+        )
+        from dulwich._diff_tree import (
+            _is_tree as _rust_is_tree,
+        )
+        from dulwich._diff_tree import (
+            _merge_entries as _rust_merge_entries,
+        )
+
+        # Override with Rust versions
+        _count_blocks = _rust_count_blocks
+        _is_tree = _rust_is_tree
+        _merge_entries = _rust_merge_entries
+    except ImportError:
+        pass