Selaa lähdekoodia

Add type annotations to dulwich/greenthreads.py

Jelmer Vernooij 5 kuukautta sitten
vanhempi
commit
e95407207a
1 muutettua tiedostoa jossa 16 lisäystä ja 14 poistoa
  1. 16 14
      dulwich/greenthreads.py

+ 16 - 14
dulwich/greenthreads.py

@@ -23,12 +23,13 @@
 
 """Utility module for querying an ObjectStore with gevent."""
 
-from typing import Optional
+from typing import Callable, Optional
 
 import gevent
 from gevent import pool
 
 from .object_store import (
+    BaseObjectStore,
     MissingObjectFinder,
     _collect_ancestors,
     _collect_filetree_revs,
@@ -36,7 +37,7 @@ from .object_store import (
 from .objects import Commit, ObjectID, Tag
 
 
-def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
+def _split_commits_and_tags(obj_store: BaseObjectStore, lst: list[ObjectID], *, ignore_unknown: bool = False, pool: pool.Pool) -> tuple[set[ObjectID], set[ObjectID]]:
     """Split object id list into two list with commit SHA1s and tag SHA1s.
 
     Same implementation as object_store._split_commits_and_tags
@@ -45,7 +46,7 @@ def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
     commits = set()
     tags = set()
 
-    def find_commit_type(sha) -> None:
+    def find_commit_type(sha: ObjectID) -> None:
         try:
             o = obj_store[sha]
         except KeyError:
@@ -58,7 +59,7 @@ def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
                 tags.add(sha)
                 commits.add(o.object[1])
             else:
-                raise KeyError(f"Not a commit or a tag: {sha}")
+                raise KeyError(f"Not a commit or a tag: {sha!r}")
 
     jobs = [pool.spawn(find_commit_type, s) for s in lst]
     gevent.joinall(jobs)
@@ -74,18 +75,19 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
 
     def __init__(
         self,
-        object_store,
-        haves,
-        wants,
-        progress=None,
-        get_tagged=None,
-        concurrency=1,
-        get_parents=None,
+        object_store: BaseObjectStore,
+        haves: list[ObjectID],
+        wants: list[ObjectID],
+        progress: Optional[Callable[[str], None]] = None,
+        get_tagged: Optional[Callable[[], dict[ObjectID, ObjectID]]] = None,
+        concurrency: int = 1,
+        get_parents: Optional[Callable[[ObjectID], list[ObjectID]]] = None,
     ) -> None:
-        def collect_tree_sha(sha) -> None:
+        def collect_tree_sha(sha: ObjectID) -> None:
             self.sha_done.add(sha)
-            cmt = object_store[sha]
-            _collect_filetree_revs(object_store, cmt.tree, self.sha_done)
+            obj = object_store[sha]
+            if isinstance(obj, Commit):
+                _collect_filetree_revs(object_store, obj.tree, self.sha_done)
 
         self.object_store = object_store
         p = pool.Pool(size=concurrency)