Selaa lähdekoodia

Add more typing (#1747)

Jelmer Vernooij 5 kuukautta sitten
vanhempi
commit
a821dca472
5 muutettua tiedostoa jossa 87 lisäystä ja 54 poistoa
  1. 8 8
      dulwich/__init__.py
  2. 6 6
      dulwich/diff.py
  3. 37 20
      dulwich/dumb.py
  4. 35 19
      dulwich/web.py
  5. 1 1
      tests/test_dumb.py

+ 8 - 8
dulwich/__init__.py

@@ -24,7 +24,7 @@
 """Python implementation of the Git file formats and protocols."""
 
 import sys
-from typing import Optional, TypeVar, Union
+from typing import Callable, Optional, TypeVar, Union
 
 if sys.version_info >= (3, 10):
     from typing import ParamSpec
@@ -37,7 +37,7 @@ __all__ = ["__version__", "replace_me"]
 
 P = ParamSpec("P")
 R = TypeVar("R")
-F = TypeVar("F")
+F = TypeVar("F", bound=Callable[..., object])
 
 try:
     from dissolve import replace_me
@@ -47,12 +47,12 @@ except ImportError:
     def replace_me(
         since: Optional[Union[str, tuple[int, ...]]] = None,
         remove_in: Optional[Union[str, tuple[int, ...]]] = None,
-    ):
-        def decorator(func):
+    ) -> Callable[[F], F]:
+        def decorator(func: F) -> F:
             import functools
             import warnings
 
-            m = f"{func.__name__} is deprecated"
+            m = f"{func.__name__} is deprecated"  # type: ignore[attr-defined]
             since_str = str(since) if since is not None else None
             remove_in_str = str(remove_in) if remove_in is not None else None
 
@@ -66,14 +66,14 @@ except ImportError:
                 m += " and will be removed in a future version"
 
             @functools.wraps(func)
-            def _wrapped_func(*args, **kwargs):
+            def _wrapped_func(*args, **kwargs):  # type: ignore[no-untyped-def]
                 warnings.warn(
                     m,
                     DeprecationWarning,
                     stacklevel=2,
                 )
-                return func(*args, **kwargs)
+                return func(*args, **kwargs)  # type: ignore[operator]
 
-            return _wrapped_func
+            return _wrapped_func  # type: ignore[return-value]
 
         return decorator

+ 6 - 6
dulwich/diff.py

@@ -515,7 +515,7 @@ class ColorizedDiffStream:
     """
 
     @staticmethod
-    def is_available():
+    def is_available() -> bool:
         """Check if Rich is available for colorization.
 
         Returns:
@@ -528,7 +528,7 @@ class ColorizedDiffStream:
         except ImportError:
             return False
 
-    def __init__(self, output_stream):
+    def __init__(self, output_stream: BinaryIO) -> None:
         """Initialize the colorized stream wrapper.
 
         Args:
@@ -546,7 +546,7 @@ class ColorizedDiffStream:
         self.console = Console(file=self.text_wrapper, force_terminal=True)
         self.buffer = b""
 
-    def write(self, data):
+    def write(self, data: bytes) -> None:
         """Write data to the stream, applying colorization.
 
         Args:
@@ -560,7 +560,7 @@ class ColorizedDiffStream:
             line, self.buffer = self.buffer.split(b"\n", 1)
             self._colorize_and_write_line(line + b"\n")
 
-    def writelines(self, lines):
+    def writelines(self, lines: list[bytes]) -> None:
         """Write a list of lines to the stream.
 
         Args:
@@ -569,7 +569,7 @@ class ColorizedDiffStream:
         for line in lines:
             self.write(line)
 
-    def _colorize_and_write_line(self, line_bytes):
+    def _colorize_and_write_line(self, line_bytes: bytes) -> None:
         """Apply color formatting to a single line and write it.
 
         Args:
@@ -593,7 +593,7 @@ class ColorizedDiffStream:
             # Fallback to raw output if we can't decode/encode the text
             self.output_stream.write(line_bytes)
 
-    def flush(self):
+    def flush(self) -> None:
         """Flush any remaining buffered content and the underlying stream."""
         # Write any remaining buffer content
         if self.buffer:

+ 37 - 20
dulwich/dumb.py

@@ -26,7 +26,7 @@ import tempfile
 import zlib
 from collections.abc import Iterator
 from io import BytesIO
-from typing import Optional
+from typing import Any, Callable, Optional
 from urllib.parse import urljoin
 
 from .errors import NotGitRepository, ObjectFormatException
@@ -44,13 +44,18 @@ from .objects import (
 )
 from .pack import Pack, PackData, PackIndex, UnpackedObject, load_pack_index_file
 from .refs import Ref, read_info_refs, split_peeled_refs
-from .repo import BaseRepo
 
 
 class DumbHTTPObjectStore(BaseObjectStore):
     """Object store implementation that fetches objects over dumb HTTP."""
 
-    def __init__(self, base_url: str, http_request_func):
+    def __init__(
+        self,
+        base_url: str,
+        http_request_func: Callable[
+            [str, dict[str, str]], tuple[Any, Callable[..., bytes]]
+        ],
+    ) -> None:
         """Initialize a DumbHTTPObjectStore.
 
         Args:
@@ -62,9 +67,9 @@ class DumbHTTPObjectStore(BaseObjectStore):
         self._http_request = http_request_func
         self._packs: Optional[list[tuple[str, Optional[PackIndex]]]] = None
         self._cached_objects: dict[bytes, tuple[int, bytes]] = {}
-        self._temp_pack_dir = None
+        self._temp_pack_dir: Optional[str] = None
 
-    def _ensure_temp_pack_dir(self):
+    def _ensure_temp_pack_dir(self) -> None:
         """Ensure we have a temporary directory for storing pack files."""
         if self._temp_pack_dir is None:
             self._temp_pack_dir = tempfile.mkdtemp(prefix="dulwich-dumb-")
@@ -152,7 +157,7 @@ class DumbHTTPObjectStore(BaseObjectStore):
 
         return type_map[obj_type], content
 
-    def _load_packs(self):
+    def _load_packs(self) -> None:
         """Load the list of available packs from the remote."""
         if self._packs is not None:
             return
@@ -320,22 +325,26 @@ class DumbHTTPObjectStore(BaseObjectStore):
                     yield sha_to_hex(sha)
 
     @property
-    def packs(self):
+    def packs(self) -> list[Any]:
         """Iterable of pack objects.
 
         Note: Returns empty list as we don't have actual Pack objects.
         """
         return []
 
-    def add_object(self, obj) -> None:
+    def add_object(self, obj: ShaFile) -> None:
         """Add a single object to this object store."""
         raise NotImplementedError("Cannot add objects to dumb HTTP repository")
 
-    def add_objects(self, objects, progress=None) -> None:
+    def add_objects(
+        self,
+        objects: Iterator[ShaFile],
+        progress: Optional[Callable[[int], None]] = None,
+    ) -> None:
         """Add a set of objects to this object store."""
         raise NotImplementedError("Cannot add objects to dumb HTTP repository")
 
-    def __del__(self):
+    def __del__(self) -> None:
         """Clean up temporary directory on deletion."""
         if self._temp_pack_dir and os.path.exists(self._temp_pack_dir):
             import shutil
@@ -343,10 +352,16 @@ class DumbHTTPObjectStore(BaseObjectStore):
             shutil.rmtree(self._temp_pack_dir, ignore_errors=True)
 
 
-class DumbRemoteHTTPRepo(BaseRepo):
+class DumbRemoteHTTPRepo:
     """Repository implementation for dumb HTTP remotes."""
 
-    def __init__(self, base_url: str, http_request_func):
+    def __init__(
+        self,
+        base_url: str,
+        http_request_func: Callable[
+            [str, dict[str, str]], tuple[Any, Callable[..., bytes]]
+        ],
+    ) -> None:
         """Initialize a DumbRemoteHTTPRepo.
 
         Args:
@@ -357,12 +372,7 @@ class DumbRemoteHTTPRepo(BaseRepo):
         self._http_request = http_request_func
         self._refs: Optional[dict[Ref, ObjectID]] = None
         self._peeled: Optional[dict[Ref, ObjectID]] = None
-        self._object_store = DumbHTTPObjectStore(base_url, http_request_func)
-
-    @property
-    def object_store(self):
-        """ObjectStore for this repository."""
-        return self._object_store
+        self.object_store = DumbHTTPObjectStore(base_url, http_request_func)
 
     def _fetch_url(self, path: str) -> bytes:
         """Fetch content from a URL path relative to base_url."""
@@ -417,7 +427,14 @@ class DumbRemoteHTTPRepo(BaseRepo):
         sha = self.get_refs().get(ref, None)
         return sha if sha is not None else ZERO_SHA
 
-    def fetch_pack_data(self, graph_walker, determine_wants, progress=None, depth=None):
+    def fetch_pack_data(
+        self,
+        graph_walker: object,
+        determine_wants: Callable[[dict[Ref, ObjectID]], list[ObjectID]],
+        progress: Optional[Callable[[bytes], None]] = None,
+        get_tagged: Optional[bool] = None,
+        depth: Optional[int] = None,
+    ) -> Iterator[UnpackedObject]:
         """Fetch pack data from the remote.
 
         This is the main method for fetching objects from a dumb HTTP remote.
@@ -451,7 +468,7 @@ class DumbRemoteHTTPRepo(BaseRepo):
 
             # Fetch the object
             try:
-                type_num, content = self._object_store.get_raw(sha)
+                type_num, content = self.object_store.get_raw(sha)
             except KeyError:
                 # Object not found, skip it
                 continue

+ 35 - 19
dulwich/web.py

@@ -28,7 +28,7 @@ import sys
 import time
 from collections.abc import Iterator
 from io import BytesIO
-from typing import Callable, ClassVar, Optional
+from typing import BinaryIO, Callable, ClassVar, Optional, cast
 from urllib.parse import parse_qs
 from wsgiref.simple_server import (
     ServerHandler,
@@ -66,7 +66,7 @@ NO_CACHE_HEADERS = [
 ]
 
 
-def cache_forever_headers(now=None):
+def cache_forever_headers(now: Optional[float] = None) -> list[tuple[str, str]]:
     if now is None:
         now = time.time()
     return [
@@ -113,7 +113,7 @@ def date_time_string(timestamp: Optional[float] = None) -> str:
     )
 
 
-def url_prefix(mat) -> str:
+def url_prefix(mat: re.Match[str]) -> str:
     """Extract the URL prefix from a regex match.
 
     Args:
@@ -125,12 +125,14 @@ def url_prefix(mat) -> str:
     return "/" + mat.string[: mat.start()].strip("/")
 
 
-def get_repo(backend, mat) -> BaseRepo:
+def get_repo(backend: "Backend", mat: re.Match[str]) -> BaseRepo:
     """Get a Repo instance for the given backend and URL regex match."""
-    return backend.open_repository(url_prefix(mat))
+    return cast(BaseRepo, backend.open_repository(url_prefix(mat)))
 
 
-def send_file(req, f, content_type):
+def send_file(
+    req: "HTTPGitRequest", f: Optional[BinaryIO], content_type: str
+) -> Iterator[bytes]:
     """Send a file-like object to the request output.
 
     Args:
@@ -155,18 +157,22 @@ def send_file(req, f, content_type):
         f.close()
 
 
-def _url_to_path(url):
+def _url_to_path(url: str) -> str:
     return url.replace("/", os.path.sep)
 
 
-def get_text_file(req, backend, mat):
+def get_text_file(
+    req: "HTTPGitRequest", backend: "Backend", mat: re.Match[str]
+) -> Iterator[bytes]:
     req.nocache()
     path = _url_to_path(mat.group())
     logger.info("Sending plain text file %s", path)
     return send_file(req, get_repo(backend, mat).get_named_file(path), "text/plain")
 
 
-def get_loose_object(req, backend, mat):
+def get_loose_object(
+    req: "HTTPGitRequest", backend: "Backend", mat: re.Match[str]
+) -> Iterator[bytes]:
     sha = (mat.group(1) + mat.group(2)).encode("ascii")
     logger.info("Sending loose object %s", sha)
     object_store = get_repo(backend, mat).object_store
@@ -183,7 +189,9 @@ def get_loose_object(req, backend, mat):
     yield data
 
 
-def get_pack_file(req, backend, mat):
+def get_pack_file(
+    req: "HTTPGitRequest", backend: "Backend", mat: re.Match[str]
+) -> Iterator[bytes]:
     req.cache_forever()
     path = _url_to_path(mat.group())
     logger.info("Sending pack file %s", path)
@@ -194,7 +202,9 @@ def get_pack_file(req, backend, mat):
     )
 
 
-def get_idx_file(req, backend, mat):
+def get_idx_file(
+    req: "HTTPGitRequest", backend: "Backend", mat: re.Match[str]
+) -> Iterator[bytes]:
     req.cache_forever()
     path = _url_to_path(mat.group())
     logger.info("Sending pack file %s", path)
@@ -205,7 +215,9 @@ def get_idx_file(req, backend, mat):
     )
 
 
-def get_info_refs(req, backend, mat):
+def get_info_refs(
+    req: "HTTPGitRequest", backend: "Backend", mat: re.Match[str]
+) -> Iterator[bytes]:
     params = parse_qs(req.environ["QUERY_STRING"])
     service = params.get("service", [None])[0]
     try:
@@ -240,14 +252,16 @@ def get_info_refs(req, backend, mat):
         yield from generate_info_refs(repo)
 
 
-def get_info_packs(req, backend, mat):
+def get_info_packs(
+    req: "HTTPGitRequest", backend: "Backend", mat: re.Match[str]
+) -> Iterator[bytes]:
     req.nocache()
     req.respond(HTTP_OK, "text/plain")
     logger.info("Emulating dumb info/packs")
     return generate_objects_info_packs(get_repo(backend, mat))
 
 
-def _chunk_iter(f):
+def _chunk_iter(f: BinaryIO) -> Iterator[bytes]:
     while True:
         line = f.readline()
         length = int(line.rstrip(), 16)
@@ -260,11 +274,11 @@ def _chunk_iter(f):
 class ChunkReader:
     """Reader for chunked transfer encoding streams."""
 
-    def __init__(self, f) -> None:
+    def __init__(self, f: BinaryIO) -> None:
         self._iter = _chunk_iter(f)
         self._buffer: list[bytes] = []
 
-    def read(self, n):
+    def read(self, n: int) -> bytes:
         while sum(map(len, self._buffer)) < n:
             try:
                 self._buffer.append(next(self._iter))
@@ -284,11 +298,11 @@ class _LengthLimitedFile:
     but not implemented in wsgiref as of 2.5.
     """
 
-    def __init__(self, input, max_bytes) -> None:
+    def __init__(self, input: BinaryIO, max_bytes: int) -> None:
         self._input = input
         self._bytes_avail = max_bytes
 
-    def read(self, size=-1):
+    def read(self, size: int = -1) -> bytes:
         if self._bytes_avail <= 0:
             return b""
         if size == -1 or size > self._bytes_avail:
@@ -299,7 +313,9 @@ class _LengthLimitedFile:
     # TODO: support more methods as necessary
 
 
-def handle_service_request(req, backend, mat):
+def handle_service_request(
+    req: "HTTPGitRequest", backend: "Backend", mat: re.Match[str]
+) -> Iterator[bytes]:
     service = mat.group().lstrip("/")
     logger.info("Handling service request for %s", service)
     handler_cls = req.handlers.get(service.encode("ascii"), None)

+ 1 - 1
tests/test_dumb.py

@@ -268,7 +268,7 @@ fedcba9876543210fedcba9876543210fedcba98\trefs/tags/v1.0
         blob.data = b"Test content"
         blob_sha = blob.id
         # Add blob response
-        self.repo._object_store._cached_objects[blob_sha] = (
+        self.repo.object_store._cached_objects[blob_sha] = (
             Blob.type_num,
             blob.as_raw_string(),
         )