Kaynağa Gözat

Add type annotations to fix mypy errors

Jelmer Vernooij 5 ay önce
ebeveyn
işleme
27abb6ee8a
4 değiştirilmiş dosya ile 25 ekleme ve 19 silme
  1. 8 8
      dulwich/__init__.py
  2. 10 4
      dulwich/dumb.py
  3. 6 6
      dulwich/web.py
  4. 1 1
      tests/test_dumb.py

+ 8 - 8
dulwich/__init__.py

@@ -24,7 +24,7 @@
 """Python implementation of the Git file formats and protocols."""
 """Python implementation of the Git file formats and protocols."""
 
 
 import sys
 import sys
-from typing import Optional, TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union
 
 
 if sys.version_info >= (3, 10):
 if sys.version_info >= (3, 10):
     from typing import ParamSpec
     from typing import ParamSpec
@@ -37,7 +37,7 @@ __all__ = ["__version__", "replace_me"]
 
 
 P = ParamSpec("P")
 P = ParamSpec("P")
 R = TypeVar("R")
 R = TypeVar("R")
-F = TypeVar("F")
+F = TypeVar("F", bound=Callable[..., object])
 
 
 try:
 try:
     from dissolve import replace_me
     from dissolve import replace_me
@@ -47,12 +47,12 @@ except ImportError:
     def replace_me(
     def replace_me(
         since: Optional[Union[str, tuple[int, ...]]] = None,
         since: Optional[Union[str, tuple[int, ...]]] = None,
         remove_in: Optional[Union[str, tuple[int, ...]]] = None,
         remove_in: Optional[Union[str, tuple[int, ...]]] = None,
-    ):
-        def decorator(func):
+    ) -> Callable[[F], F]:
+        def decorator(func: F) -> F:
             import functools
             import functools
             import warnings
             import warnings
 
 
-            m = f"{func.__name__} is deprecated"
+            m = f"{func.__name__} is deprecated"  # type: ignore[attr-defined]
             since_str = str(since) if since is not None else None
             since_str = str(since) if since is not None else None
             remove_in_str = str(remove_in) if remove_in is not None else None
             remove_in_str = str(remove_in) if remove_in is not None else None
 
 
@@ -66,14 +66,14 @@ except ImportError:
                 m += " and will be removed in a future version"
                 m += " and will be removed in a future version"
 
 
             @functools.wraps(func)
             @functools.wraps(func)
-            def _wrapped_func(*args, **kwargs):
+            def _wrapped_func(*args, **kwargs):  # type: ignore[no-untyped-def]
                 warnings.warn(
                 warnings.warn(
                     m,
                     m,
                     DeprecationWarning,
                     DeprecationWarning,
                     stacklevel=2,
                     stacklevel=2,
                 )
                 )
-                return func(*args, **kwargs)
+                return func(*args, **kwargs)  # type: ignore[operator]
 
 
-            return _wrapped_func
+            return _wrapped_func  # type: ignore[return-value]
 
 
         return decorator
         return decorator

+ 10 - 4
dulwich/dumb.py

@@ -44,7 +44,6 @@ from .objects import (
 )
 )
 from .pack import Pack, PackData, PackIndex, UnpackedObject, load_pack_index_file
 from .pack import Pack, PackData, PackIndex, UnpackedObject, load_pack_index_file
 from .refs import Ref, read_info_refs, split_peeled_refs
 from .refs import Ref, read_info_refs, split_peeled_refs
-from .repo import BaseRepo
 
 
 
 
 class DumbHTTPObjectStore(BaseObjectStore):
 class DumbHTTPObjectStore(BaseObjectStore):
@@ -343,7 +342,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
             shutil.rmtree(self._temp_pack_dir, ignore_errors=True)
             shutil.rmtree(self._temp_pack_dir, ignore_errors=True)
 
 
 
 
-class DumbRemoteHTTPRepo(BaseRepo):
+class DumbRemoteHTTPRepo:
     """Repository implementation for dumb HTTP remotes."""
     """Repository implementation for dumb HTTP remotes."""
 
 
     def __init__(self, base_url: str, http_request_func: Callable[[str, dict[str, str]], tuple[Any, Callable[..., bytes]]]) -> None:
     def __init__(self, base_url: str, http_request_func: Callable[[str, dict[str, str]], tuple[Any, Callable[..., bytes]]]) -> None:
@@ -412,7 +411,14 @@ class DumbRemoteHTTPRepo(BaseRepo):
         sha = self.get_refs().get(ref, None)
         sha = self.get_refs().get(ref, None)
         return sha if sha is not None else ZERO_SHA
         return sha if sha is not None else ZERO_SHA
 
 
-    def fetch_pack_data(self, graph_walker: object, determine_wants: Callable[[dict[Ref, ObjectID]], list[ObjectID]], progress: Optional[Callable[[str], None]] = None, *, get_tagged: Optional[bool] = None, depth: Optional[int] = None) -> Iterator[UnpackedObject]:
+    def fetch_pack_data(
+        self,
+        graph_walker: object,
+        determine_wants: Callable[[dict[Ref, ObjectID]], list[ObjectID]],
+        progress: Optional[Callable[[bytes], None]] = None,
+        get_tagged: Optional[bool] = None,
+        depth: Optional[int] = None,
+    ) -> Iterator[UnpackedObject]:
         """Fetch pack data from the remote.
         """Fetch pack data from the remote.
 
 
         This is the main method for fetching objects from a dumb HTTP remote.
         This is the main method for fetching objects from a dumb HTTP remote.
@@ -468,4 +474,4 @@ class DumbRemoteHTTPRepo(BaseRepo):
                     to_fetch.add(item_sha)
                     to_fetch.add(item_sha)
 
 
             if progress:
             if progress:
-                progress(f"Fetching objects: {len(seen)} done\n")
+                progress(f"Fetching objects: {len(seen)} done\n".encode())

+ 6 - 6
dulwich/web.py

@@ -28,7 +28,7 @@ import sys
 import time
 import time
 from collections.abc import Iterator
 from collections.abc import Iterator
 from io import BytesIO
 from io import BytesIO
-from typing import Callable, ClassVar, Optional
+from typing import BinaryIO, Callable, ClassVar, Optional, cast
 from urllib.parse import parse_qs
 from urllib.parse import parse_qs
 from wsgiref.simple_server import (
 from wsgiref.simple_server import (
     ServerHandler,
     ServerHandler,
@@ -127,10 +127,10 @@ def url_prefix(mat: re.Match[str]) -> str:
 
 
 def get_repo(backend: 'Backend', mat: re.Match[str]) -> BaseRepo:
 def get_repo(backend: 'Backend', mat: re.Match[str]) -> BaseRepo:
     """Get a Repo instance for the given backend and URL regex match."""
     """Get a Repo instance for the given backend and URL regex match."""
-    return backend.open_repository(url_prefix(mat))
+    return cast(BaseRepo, backend.open_repository(url_prefix(mat)))
 
 
 
 
-def send_file(req: 'HTTPGitRequest', f: BytesIO, content_type: str) -> Iterator[bytes]:
+def send_file(req: 'HTTPGitRequest', f: Optional[BinaryIO], content_type: str) -> Iterator[bytes]:
     """Send a file-like object to the request output.
     """Send a file-like object to the request output.
 
 
     Args:
     Args:
@@ -247,7 +247,7 @@ def get_info_packs(req: 'HTTPGitRequest', backend: 'Backend', mat: re.Match[str]
     return generate_objects_info_packs(get_repo(backend, mat))
     return generate_objects_info_packs(get_repo(backend, mat))
 
 
 
 
-def _chunk_iter(f: BytesIO) -> Iterator[bytes]:
+def _chunk_iter(f: BinaryIO) -> Iterator[bytes]:
     while True:
     while True:
         line = f.readline()
         line = f.readline()
         length = int(line.rstrip(), 16)
         length = int(line.rstrip(), 16)
@@ -260,7 +260,7 @@ def _chunk_iter(f: BytesIO) -> Iterator[bytes]:
 class ChunkReader:
 class ChunkReader:
     """Reader for chunked transfer encoding streams."""
     """Reader for chunked transfer encoding streams."""
 
 
-    def __init__(self, f: BytesIO) -> None:
+    def __init__(self, f: BinaryIO) -> None:
         self._iter = _chunk_iter(f)
         self._iter = _chunk_iter(f)
         self._buffer: list[bytes] = []
         self._buffer: list[bytes] = []
 
 
@@ -284,7 +284,7 @@ class _LengthLimitedFile:
     but not implemented in wsgiref as of 2.5.
     but not implemented in wsgiref as of 2.5.
     """
     """
 
 
-    def __init__(self, input: BytesIO, max_bytes: int) -> None:
+    def __init__(self, input: BinaryIO, max_bytes: int) -> None:
         self._input = input
         self._input = input
         self._bytes_avail = max_bytes
         self._bytes_avail = max_bytes
 
 

+ 1 - 1
tests/test_dumb.py

@@ -281,7 +281,7 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
             return [blob_sha]
             return [blob_sha]
 
 
         def progress(msg):
         def progress(msg):
-            assert isinstance(msg, str)
+            assert isinstance(msg, bytes)
 
 
         result = list(
         result = list(
             self.repo.fetch_pack_data(graph_walker, determine_wants, progress)
             self.repo.fetch_pack_data(graph_walker, determine_wants, progress)