Browse Source

Add type annotations to dulwich/tests/utils.py

Jelmer Vernooij 5 months ago
parent
commit
7cb26e8e0d
1 changed files with 36 additions and 28 deletions
  1. 36 28
      dulwich/tests/utils.py

+ 36 - 28
dulwich/tests/utils.py

@@ -21,6 +21,8 @@
 
 """Utility functions common to Dulwich tests."""
 
+# ruff: noqa: ANN401
+
 import datetime
 import os
 import shutil
@@ -28,10 +30,12 @@ import tempfile
 import time
 import types
 import warnings
+from typing import Any, BinaryIO, Callable, Optional, TypeVar, Union
 from unittest import SkipTest
 
 from dulwich.index import commit_tree
-from dulwich.objects import Commit, FixedSha, Tag, object_class
+from dulwich.object_store import BaseObjectStore
+from dulwich.objects import Commit, FixedSha, ShaFile, Tag, object_class
 from dulwich.pack import (
     DELTA_TYPES,
     OFS_DELTA,
@@ -47,8 +51,10 @@ from dulwich.repo import Repo
 # Plain files are very frequently used in tests, so let the mode be very short.
 F = 0o100644  # Shorthand mode for Files.
 
+T = TypeVar("T", bound=ShaFile)
+
 
-def open_repo(name, temp_dir=None):
+def open_repo(name: str, temp_dir: Optional[str] = None) -> Repo:
     """Open a copy of a repo in a temporary directory.
 
     Use this function for accessing repos in dulwich/tests/data/repos to avoid
@@ -72,14 +78,14 @@ def open_repo(name, temp_dir=None):
     return Repo(temp_repo_dir)
 
 
-def tear_down_repo(repo) -> None:
+def tear_down_repo(repo: Repo) -> None:
     """Tear down a test repository."""
     repo.close()
     temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
     shutil.rmtree(temp_dir)
 
 
-def make_object(cls, **attrs):
+def make_object(cls: type[T], **attrs: Any) -> T:
     """Make an object for testing and assign some members.
 
     This method creates a new subclass to allow arbitrary attribute
@@ -92,7 +98,7 @@ def make_object(cls, **attrs):
     Returns: A newly initialized object of type cls.
     """
 
-    class TestObject(cls):
+    class TestObject(cls):  # type: ignore[misc,valid-type]
         """Class that inherits from the given class, but without __slots__.
 
         Note that classes with __slots__ can't have arbitrary attributes
@@ -102,7 +108,7 @@ def make_object(cls, **attrs):
 
     TestObject.__name__ = "TestObject_" + cls.__name__
 
-    obj = TestObject()
+    obj = TestObject()  # type: ignore[abstract]
     for name, value in attrs.items():
         if name == "id":
             # id property is read-only, so we overwrite sha instead.
@@ -113,7 +119,7 @@ def make_object(cls, **attrs):
     return obj
 
 
-def make_commit(**attrs):
+def make_commit(**attrs: Any) -> Commit:
     """Make a Commit object with a default set of members.
 
     Args:
@@ -136,7 +142,7 @@ def make_commit(**attrs):
     return make_object(Commit, **all_attrs)
 
 
-def make_tag(target, **attrs):
+def make_tag(target: ShaFile, **attrs: Any) -> Tag:
     """Make a Tag object with a default set of values.
 
     Args:
@@ -159,16 +165,16 @@ def make_tag(target, **attrs):
     return make_object(Tag, **all_attrs)
 
 
-def functest_builder(method, func):
+def functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[[Any], None]:
     """Generate a test method that tests the given function."""
 
-    def do_test(self) -> None:
+    def do_test(self: Any) -> None:
         method(self, func)
 
     return do_test
 
 
-def ext_functest_builder(method, func):
+def ext_functest_builder(method: Callable[[Any, Any], None], func: Any) -> Callable[[Any], None]:
     """Generate a test method that tests the given extension function.
 
     This is intended to generate test methods that test both a pure-Python
@@ -190,7 +196,7 @@ def ext_functest_builder(method, func):
       func: The function implementation to pass to method.
     """
 
-    def do_test(self) -> None:
+    def do_test(self: Any) -> None:
         if not isinstance(func, types.BuiltinFunctionType):
             raise SkipTest(f"{func} extension not found")
         method(self, func)
@@ -198,7 +204,7 @@ def ext_functest_builder(method, func):
     return do_test
 
 
-def build_pack(f, objects_spec, store=None):
+def build_pack(f: BinaryIO, objects_spec: list[tuple[int, Any]], store: Optional[BaseObjectStore] = None) -> list[tuple[int, int, bytes, bytes, int]]:
     """Write test pack data from a concise spec.
 
     Args:
@@ -221,14 +227,14 @@ def build_pack(f, objects_spec, store=None):
     num_objects = len(objects_spec)
     write_pack_header(sf.write, num_objects)
 
-    full_objects = {}
-    offsets = {}
-    crc32s = {}
+    full_objects: dict[int, tuple[int, bytes, bytes]] = {}
+    offsets: dict[int, int] = {}
+    crc32s: dict[int, int] = {}
 
     while len(full_objects) < num_objects:
         for i, (type_num, data) in enumerate(objects_spec):
             if type_num not in DELTA_TYPES:
-                full_objects[i] = (type_num, data, obj_sha(type_num, [data]))
+                full_objects[i] = (type_num, data, obj_sha(type_num, [data]))  # type: ignore[no-untyped-call]
                 continue
             base, data = data
             if isinstance(base, int):
@@ -236,11 +242,12 @@ def build_pack(f, objects_spec, store=None):
                     continue
                 base_type_num, _, _ = full_objects[base]
             else:
+                assert store is not None
                 base_type_num, _ = store.get_raw(base)
             full_objects[i] = (
                 base_type_num,
                 data,
-                obj_sha(base_type_num, [data]),
+                obj_sha(base_type_num, [data]),  # type: ignore[no-untyped-call]
             )
 
     for i, (type_num, obj) in enumerate(objects_spec):
@@ -249,17 +256,18 @@ def build_pack(f, objects_spec, store=None):
             base_index, data = obj
             base = offset - offsets[base_index]
             _, base_data, _ = full_objects[base_index]
-            obj = (base, list(create_delta(base_data, data)))
+            obj = (base, list(create_delta(base_data, data)))  # type: ignore[no-untyped-call]
         elif type_num == REF_DELTA:
             base_ref, data = obj
             if isinstance(base_ref, int):
                 _, base_data, base = full_objects[base_ref]
             else:
+                assert store is not None
                 base_type_num, base_data = store.get_raw(base_ref)
-                base = obj_sha(base_type_num, base_data)
-            obj = (base, list(create_delta(base_data, data)))
+                base = obj_sha(base_type_num, base_data)  # type: ignore[no-untyped-call]
+            obj = (base, list(create_delta(base_data, data)))  # type: ignore[no-untyped-call]
 
-        crc32 = write_pack_object(sf.write, type_num, obj)
+        crc32 = write_pack_object(sf.write, type_num, obj)  # type: ignore[no-untyped-call]
         offsets[i] = offset
         crc32s[i] = crc32
 
@@ -269,12 +277,12 @@ def build_pack(f, objects_spec, store=None):
         assert len(sha) == 20
         expected.append((offsets[i], type_num, data, sha, crc32s[i]))
 
-    sf.write_sha()
+    sf.write_sha()  # type: ignore[no-untyped-call]
     f.seek(0)
     return expected
 
 
-def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
+def build_commit_graph(object_store: BaseObjectStore, commit_spec: list[list[int]], trees: Optional[dict[int, list[Union[tuple[bytes, ShaFile], tuple[bytes, ShaFile, int]]]]] = None, attrs: Optional[dict[int, dict[str, Any]]] = None) -> list[Commit]:
     """Build a commit graph from a concise specification.
 
     Sample usage:
@@ -311,7 +319,7 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
     if attrs is None:
         attrs = {}
     commit_time = 0
-    nums = {}
+    nums: dict[int, bytes] = {}
     commits = []
 
     for commit in commit_spec:
@@ -343,7 +351,7 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
 
         # By default, increment the time by a lot. Out-of-order commits should
         # be closer together than this because their main cause is clock skew.
-        commit_time = commit_attrs["commit_time"] + 100
+        commit_time = commit_attrs["commit_time"] + 100  # type: ignore[operator]
         nums[commit_num] = commit_obj.id
         object_store.add_object(commit_obj)
         commits.append(commit_obj)
@@ -351,12 +359,12 @@ def build_commit_graph(object_store, commit_spec, trees=None, attrs=None):
     return commits
 
 
-def setup_warning_catcher():
+def setup_warning_catcher() -> tuple[list[Warning], Callable[[], None]]:
     """Wrap warnings.showwarning with code that records warnings."""
     caught_warnings = []
     original_showwarning = warnings.showwarning
 
-    def custom_showwarning(*args, **kwargs) -> None:
+    def custom_showwarning(*args: Any, **kwargs: Any) -> None:
         caught_warnings.append(args[0])
 
     warnings.showwarning = custom_showwarning