Browse Source

walk: Add typing

Jelmer Vernooij 1 month ago
parent
commit
07078d719f
1 changed files with 72 additions and 29 deletions
  1. 72 29
      dulwich/walk.py

+ 72 - 29
dulwich/walk.py

@@ -23,8 +23,12 @@
 
 import collections
 import heapq
+from collections.abc import Iterator
 from itertools import chain
-from typing import Optional
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union, cast
+
+if TYPE_CHECKING:
+    from .object_store import BaseObjectStore
 
 from .diff_tree import (
     RENAME_CHANGE_TYPES,
@@ -48,14 +52,16 @@ _MAX_EXTRA_COMMITS = 5
 class WalkEntry:
     """Object encapsulating a single result from a walk."""
 
-    def __init__(self, walker, commit) -> None:
+    def __init__(self, walker: "Walker", commit: Commit) -> None:
         self.commit = commit
         self._store = walker.store
         self._get_parents = walker.get_parents
-        self._changes: dict[str, list[TreeChange]] = {}
+        self._changes: dict[Optional[bytes], list[TreeChange]] = {}
         self._rename_detector = walker.rename_detector
 
-    def changes(self, path_prefix=None):
+    def changes(
+        self, path_prefix: Optional[bytes] = None
+    ) -> Union[list[TreeChange], list[list[TreeChange]]]:
         """Get the tree changes for this entry.
 
         Args:
@@ -75,7 +81,7 @@ class WalkEntry:
                 parent = None
             elif len(self._get_parents(commit)) == 1:
                 changes_func = tree_changes
-                parent = self._store[self._get_parents(commit)[0]].tree
+                parent = cast(Commit, self._store[self._get_parents(commit)[0]]).tree
                 if path_prefix:
                     mode, subtree_sha = parent.lookup_path(
                         self._store.__getitem__,
@@ -83,13 +89,28 @@ class WalkEntry:
                     )
                     parent = self._store[subtree_sha]
             else:
-                changes_func = tree_changes_for_merge
-                parent = [self._store[p].tree for p in self._get_parents(commit)]
+                # For merge commits, we need to handle multiple parents differently
+                parent = [
+                    cast(Commit, self._store[p]).tree for p in self._get_parents(commit)
+                ]
+                # Use a lambda to adapt the signature
+                changes_func = cast(
+                    Any,
+                    lambda store,
+                    parent_trees,
+                    tree_id,
+                    rename_detector=None: tree_changes_for_merge(
+                        store, parent_trees, tree_id, rename_detector
+                    ),
+                )
                 if path_prefix:
                     parent_trees = [self._store[p] for p in parent]
                     parent = []
                     for p in parent_trees:
                         try:
+                            from .objects import Tree
+
+                            assert isinstance(p, Tree)
                             mode, st = p.lookup_path(
                                 self._store.__getitem__,
                                 path_prefix,
@@ -101,6 +122,9 @@ class WalkEntry:
             commit_tree_sha = commit.tree
             if path_prefix:
                 commit_tree = self._store[commit_tree_sha]
+                from .objects import Tree
+
+                assert isinstance(commit_tree, Tree)
                 mode, commit_tree_sha = commit_tree.lookup_path(
                     self._store.__getitem__,
                     path_prefix,
@@ -117,7 +141,7 @@ class WalkEntry:
         return self._changes[path_prefix]
 
     def __repr__(self) -> str:
-        return f"<WalkEntry commit={self.commit.id}, changes={self.changes()!r}>"
+        return f"<WalkEntry commit={self.commit.id.decode('ascii')}, changes={self.changes()!r}>"
 
 
 class _CommitTimeQueue:
@@ -133,14 +157,14 @@ class _CommitTimeQueue:
         self._seen: set[ObjectID] = set()
         self._done: set[ObjectID] = set()
         self._min_time = walker.since
-        self._last = None
+        self._last: Optional[Commit] = None
         self._extra_commits_left = _MAX_EXTRA_COMMITS
         self._is_finished = False
 
         for commit_id in chain(walker.include, walker.excluded):
             self._push(commit_id)
 
-    def _push(self, object_id: bytes) -> None:
+    def _push(self, object_id: ObjectID) -> None:
         try:
             obj = self._store[object_id]
         except KeyError as exc:
@@ -149,13 +173,15 @@ class _CommitTimeQueue:
             self._push(obj.object[1])
             return
         # TODO(jelmer): What to do about non-Commit and non-Tag objects?
+        if not isinstance(obj, Commit):
+            return
         commit = obj
         if commit.id not in self._pq_set and commit.id not in self._done:
             heapq.heappush(self._pq, (-commit.commit_time, commit))
             self._pq_set.add(commit.id)
             self._seen.add(commit.id)
 
-    def _exclude_parents(self, commit) -> None:
+    def _exclude_parents(self, commit: Commit) -> None:
         excluded = self._excluded
         seen = self._seen
         todo = [commit]
@@ -167,10 +193,10 @@ class _CommitTimeQueue:
                     # some caching (which DiskObjectStore currently does not).
                     # We could either add caching in this class or pass around
                     # parsed queue entry objects instead of commits.
-                    todo.append(self._store[parent])
+                    todo.append(cast(Commit, self._store[parent]))
                 excluded.add(parent)
 
-    def next(self):
+    def next(self) -> Optional[WalkEntry]:
         if self._is_finished:
             return None
         while self._pq:
@@ -233,7 +259,7 @@ class Walker:
 
     def __init__(
         self,
-        store,
+        store: "BaseObjectStore",
         include: list[bytes],
         exclude: Optional[list[bytes]] = None,
         order: str = "date",
@@ -244,8 +270,8 @@ class Walker:
         follow: bool = False,
         since: Optional[int] = None,
         until: Optional[int] = None,
-        get_parents=lambda commit: commit.parents,
-        queue_cls=_CommitTimeQueue,
+        get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+        queue_cls: type = _CommitTimeQueue,
     ) -> None:
         """Constructor.
 
@@ -300,7 +326,7 @@ class Walker:
         self._queue = queue_cls(self)
         self._out_queue: collections.deque[WalkEntry] = collections.deque()
 
-    def _path_matches(self, changed_path) -> bool:
+    def _path_matches(self, changed_path: Optional[bytes]) -> bool:
         if changed_path is None:
             return False
         if self.paths is None:
@@ -315,7 +341,7 @@ class Walker:
                 return True
         return False
 
-    def _change_matches(self, change) -> bool:
+    def _change_matches(self, change: TreeChange) -> bool:
         assert self.paths
         if not change:
             return False
@@ -331,7 +357,7 @@ class Walker:
             return True
         return False
 
-    def _should_return(self, entry) -> Optional[bool]:
+    def _should_return(self, entry: WalkEntry) -> Optional[bool]:
         """Determine if a walk entry should be returned..
 
         Args:
@@ -359,12 +385,24 @@ class Walker:
                     if self._change_matches(change):
                         return True
         else:
-            for change in entry.changes():
-                if self._change_matches(change):
-                    return True
+            changes = entry.changes()
+            # Handle both list[TreeChange] and list[list[TreeChange]]
+            if changes and isinstance(changes[0], list):
+                # It's list[list[TreeChange]], flatten it
+                for change_list in changes:
+                    for change in change_list:
+                        if self._change_matches(change):
+                            return True
+            else:
+                # It's list[TreeChange]
+                from .diff_tree import TreeChange
+
+                for change in changes:
+                    if isinstance(change, TreeChange) and self._change_matches(change):
+                        return True
         return None
 
-    def _next(self):
+    def _next(self) -> Optional[WalkEntry]:
         max_entries = self.max_entries
         while max_entries is None or self._num_entries < max_entries:
             entry = next(self._queue)
@@ -379,7 +417,9 @@ class Walker:
                     return entry
         return None
 
-    def _reorder(self, results):
+    def _reorder(
+        self, results: Iterator[WalkEntry]
+    ) -> Union[Iterator[WalkEntry], list[WalkEntry]]:
         """Possibly reorder a results iterator.
 
         Args:
@@ -394,11 +434,14 @@ class Walker:
             results = reversed(list(results))
         return results
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[WalkEntry]:
         return iter(self._reorder(iter(self._next, None)))
 
 
-def _topo_reorder(entries, get_parents=lambda commit: commit.parents):
+def _topo_reorder(
+    entries: Iterator[WalkEntry],
+    get_parents: Callable[[Commit], list[bytes]] = lambda commit: commit.parents,
+) -> Iterator[WalkEntry]:
     """Reorder an iterable of entries topologically.
 
     This works best assuming the entries are already in almost-topological
@@ -410,9 +453,9 @@ def _topo_reorder(entries, get_parents=lambda commit: commit.parents):
     Returns: iterator over WalkEntry objects from entries in FIFO order, except
         where a parent would be yielded before any of its children.
     """
-    todo = collections.deque()
-    pending = {}
-    num_children = collections.defaultdict(int)
+    todo: collections.deque[WalkEntry] = collections.deque()
+    pending: dict[bytes, WalkEntry] = {}
+    num_children: dict[bytes, int] = collections.defaultdict(int)
     for entry in entries:
         todo.append(entry)
         for p in get_parents(entry.commit):