Sfoglia il codice sorgente

Add type annotations to multiple dulwich modules

Jelmer Vernooij 5 mesi fa
parent
commit
4783f9f4d4
5 ha cambiato i file con 54 aggiunte e 43 eliminazioni
  1. 4 4
      dulwich/annotate.py
  2. 11 11
      dulwich/file.py
  3. 17 15
      dulwich/hooks.py
  4. 7 5
      dulwich/lfs.py
  5. 15 8
      dulwich/objects.py

+ 4 - 4
dulwich/annotate.py

@@ -53,7 +53,7 @@ def update_lines(
     new_blob: "Blob",
 ) -> list[tuple[tuple["Commit", "TreeEntry"], bytes]]:
     """Update annotation lines with old blob lines."""
-    ret: list[tuple[tuple["Commit", "TreeEntry"], bytes]] = []
+    ret: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     new_lines = new_blob.splitlines()
     matcher = difflib.SequenceMatcher(
         a=[line for (h, line) in annotated_lines], b=new_lines
@@ -92,10 +92,10 @@ def annotate_lines(
     walker = Walker(
         store, include=[commit_id], paths=[path], order=order, follow=follow
     )
-    revs: list[tuple["Commit", "TreeEntry"]] = []
+    revs: list[tuple[Commit, TreeEntry]] = []
     for log_entry in walker:
         for tree_change in log_entry.changes():
-            changes: list["TreeChange"]
+            changes: list[TreeChange]
             if isinstance(tree_change, list):
                 changes = tree_change
             else:
@@ -106,7 +106,7 @@ def annotate_lines(
                     revs.append((log_entry.commit, change.new))
                     break
 
-    lines_annotated: list[tuple[tuple["Commit", "TreeEntry"], bytes]] = []
+    lines_annotated: list[tuple[tuple[Commit, TreeEntry], bytes]] = []
     for commit, entry in reversed(revs):
         lines_annotated = update_lines(lines_annotated, (commit, entry), cast("Blob", store[entry.sha]))
     return lines_annotated

+ 11 - 11
dulwich/file.py

@@ -24,10 +24,10 @@
 import os
 import sys
 import warnings
-from typing import ClassVar, Union
+from typing import Any, ClassVar, Union
 
 
-def ensure_dir_exists(dirname) -> None:
+def ensure_dir_exists(dirname: Union[str, bytes, os.PathLike]) -> None:
     """Ensure a directory exists, creating if necessary."""
     try:
         os.makedirs(dirname)
@@ -35,7 +35,7 @@ def ensure_dir_exists(dirname) -> None:
         pass
 
 
-def _fancy_rename(oldname, newname) -> None:
+def _fancy_rename(oldname: Union[str, bytes], newname: Union[str, bytes]) -> None:
     """Rename file with temporary backup file to rollback if rename fails."""
     if not os.path.exists(newname):
         os.rename(oldname, newname)
@@ -45,7 +45,7 @@ def _fancy_rename(oldname, newname) -> None:
     import tempfile
 
     # destination file exists
-    (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname, dir=".")
+    (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=str(oldname), dir=".")
     os.close(fd)
     os.remove(tmpfile)
     os.rename(newname, tmpfile)
@@ -58,8 +58,8 @@ def _fancy_rename(oldname, newname) -> None:
 
 
 def GitFile(
-    filename: Union[str, bytes, os.PathLike], mode="rb", bufsize=-1, mask=0o644
-):
+    filename: Union[str, bytes, os.PathLike], mode: str = "rb", bufsize: int = -1, mask: int = 0o644
+) -> Any:
     """Create a file object that obeys the git file locking protocol.
 
     Returns: a builtin file object or a _GitFile object
@@ -90,7 +90,7 @@ def GitFile(
 class FileLocked(Exception):
     """File is already locked."""
 
-    def __init__(self, filename, lockfilename) -> None:
+    def __init__(self, filename: Union[str, bytes, os.PathLike], lockfilename: Union[str, bytes]) -> None:
         self.filename = filename
         self.lockfilename = lockfilename
         super().__init__(filename, lockfilename)
@@ -132,7 +132,7 @@ class _GitFile:
     }
 
     def __init__(
-        self, filename: Union[str, bytes, os.PathLike], mode, bufsize, mask
+        self, filename: Union[str, bytes, os.PathLike], mode: str, bufsize: int, mask: int
     ) -> None:
         # Convert PathLike to str/bytes for our internal use
         self._filename: Union[str, bytes] = os.fspath(filename)
@@ -205,16 +205,16 @@ class _GitFile:
             warnings.warn(f"unclosed {self!r}", ResourceWarning, stacklevel=2)
             self.abort()
 
-    def __enter__(self):
+    def __enter__(self) -> "_GitFile":
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
         if exc_type is not None:
             self.abort()
         else:
             self.close()
 
-    def __getattr__(self, name):
+    def __getattr__(self, name: str) -> Any:
         """Proxy property calls to the underlying file."""
         if name in self.PROXY_PROPERTIES:
             return getattr(self._file, name)

+ 17 - 15
dulwich/hooks.py

@@ -23,6 +23,7 @@
 
 import os
 import subprocess
+from typing import Any, Callable, Optional
 
 from .errors import HookError
 
@@ -30,7 +31,7 @@ from .errors import HookError
 class Hook:
     """Generic hook object."""
 
-    def execute(self, *args):
+    def execute(self, *args: Any) -> Any:
         """Execute the hook with the given args.
 
         Args:
@@ -53,12 +54,12 @@ class ShellHook(Hook):
 
     def __init__(
         self,
-        name,
-        path,
-        numparam,
-        pre_exec_callback=None,
-        post_exec_callback=None,
-        cwd=None,
+        name: str,
+        path: str,
+        numparam: int,
+        pre_exec_callback: Optional[Callable[..., Any]] = None,
+        post_exec_callback: Optional[Callable[..., Any]] = None,
+        cwd: Optional[str] = None,
     ) -> None:
         """Setup shell hook definition.
 
@@ -85,7 +86,7 @@ class ShellHook(Hook):
 
         self.cwd = cwd
 
-    def execute(self, *args):
+    def execute(self, *args: Any) -> Any:
         """Execute the hook with given args."""
         if len(args) != self.numparam:
             raise HookError(
@@ -113,7 +114,7 @@ class ShellHook(Hook):
 class PreCommitShellHook(ShellHook):
     """pre-commit shell hook."""
 
-    def __init__(self, cwd, controldir) -> None:
+    def __init__(self, cwd: str, controldir: str) -> None:
         filepath = os.path.join(controldir, "hooks", "pre-commit")
 
         ShellHook.__init__(self, "pre-commit", filepath, 0, cwd=cwd)
@@ -122,7 +123,7 @@ class PreCommitShellHook(ShellHook):
 class PostCommitShellHook(ShellHook):
     """post-commit shell hook."""
 
-    def __init__(self, controldir) -> None:
+    def __init__(self, controldir: str) -> None:
         filepath = os.path.join(controldir, "hooks", "post-commit")
 
         ShellHook.__init__(self, "post-commit", filepath, 0, cwd=controldir)
@@ -131,10 +132,10 @@ class PostCommitShellHook(ShellHook):
 class CommitMsgShellHook(ShellHook):
     """commit-msg shell hook."""
 
-    def __init__(self, controldir) -> None:
+    def __init__(self, controldir: str) -> None:
         filepath = os.path.join(controldir, "hooks", "commit-msg")
 
-        def prepare_msg(*args):
+        def prepare_msg(*args: bytes) -> tuple[str, ...]:
             import tempfile
 
             (fd, path) = tempfile.mkstemp()
@@ -144,13 +145,14 @@ class CommitMsgShellHook(ShellHook):
 
             return (path,)
 
-        def clean_msg(success, *args):
+        def clean_msg(success: int, *args: str) -> Optional[bytes]:
             if success:
                 with open(args[0], "rb") as f:
                     new_msg = f.read()
                 os.unlink(args[0])
                 return new_msg
             os.unlink(args[0])
+            return None
 
         ShellHook.__init__(
             self, "commit-msg", filepath, 1, prepare_msg, clean_msg, controldir
@@ -160,12 +162,12 @@ class CommitMsgShellHook(ShellHook):
 class PostReceiveShellHook(ShellHook):
     """post-receive shell hook."""
 
-    def __init__(self, controldir) -> None:
+    def __init__(self, controldir: str) -> None:
         self.controldir = controldir
         filepath = os.path.join(controldir, "hooks", "post-receive")
         ShellHook.__init__(self, "post-receive", path=filepath, numparam=0)
 
-    def execute(self, client_refs):
+    def execute(self, client_refs: list[tuple[bytes, bytes, bytes]]) -> Optional[bytes]:
         # do nothing if the script doesn't exist
         if not os.path.exists(self.filepath):
             return None

+ 7 - 5
dulwich/lfs.py

@@ -31,6 +31,8 @@ from urllib.parse import urljoin, urlparse
 from urllib.request import Request, urlopen
 
 if TYPE_CHECKING:
+    import urllib3
+    
     from .config import Config
     from .repo import Repo
 
@@ -289,7 +291,7 @@ class LFSFilterDriver:
         return content
 
 
-def _get_lfs_user_agent(config):
+def _get_lfs_user_agent(config: Optional["Config"]) -> str:
     """Get User-Agent string for LFS requests, respecting git config."""
     try:
         if config:
@@ -365,12 +367,12 @@ class LFSClient:
         """Get the LFS server URL without trailing slash."""
         return self._base_url.rstrip("/")
 
-    def _get_pool_manager(self):
+    def _get_pool_manager(self) -> "urllib3.PoolManager":
         """Get urllib3 pool manager with git config applied."""
         if self._pool_manager is None:
             from dulwich.client import default_urllib3_manager
 
-            self._pool_manager = default_urllib3_manager(self.config)
+            self._pool_manager = default_urllib3_manager(self.config)  # type: ignore[assignment]
         return self._pool_manager
 
     def _make_request(
@@ -397,7 +399,7 @@ class LFSClient:
             raise ValueError(
                 f"HTTP {response.status}: {response.data.decode('utf-8', errors='ignore')}"
             )
-        return response.data
+        return response.data  # type: ignore[return-value]
 
     def batch(
         self,
@@ -513,7 +515,7 @@ class LFSClient:
         if actual_oid != oid:
             raise LFSError(f"Downloaded OID {actual_oid} != expected {oid}")
 
-        return content
+        return content  # type: ignore[return-value]
 
     def upload(
         self, oid: str, size: int, content: bytes, ref: Optional[str] = None

+ 15 - 8
dulwich/objects.py

@@ -45,7 +45,7 @@ else:
     from typing_extensions import Self
 
 if sys.version_info >= (3, 10):
-    from typing import TypeGuard  # type: ignore
+    from typing import TypeGuard
 else:
     from typing_extensions import TypeGuard
 
@@ -244,6 +244,13 @@ def check_identity(identity: Optional[bytes], error_msg: str) -> None:
         raise ObjectFormatException(error_msg)
 
 
+def _path_to_bytes(path: Union[str, bytes]) -> bytes:
+    """Convert a path to bytes for use in error messages."""
+    if isinstance(path, str):
+        return path.encode("utf-8", "surrogateescape")
+    return path
+
+
 def check_time(time_seconds: int) -> None:
     """Check if the specified time is not prone to overflow error.
 
@@ -270,7 +277,7 @@ class FixedSha:
 
     def __init__(self, hexsha: Union[str, bytes]) -> None:
         if isinstance(hexsha, str):
-            hexsha = hexsha.encode("ascii")  # type: ignore
+            hexsha = hexsha.encode("ascii")
         if not isinstance(hexsha, bytes):
             raise TypeError(f"Expected bytes for hexsha, got {hexsha!r}")
         self._hexsha = hexsha
@@ -432,7 +439,7 @@ class ShaFile:
         if sha is None:
             self._sha = None
         else:
-            self._sha = FixedSha(sha)  # type: ignore
+            self._sha = FixedSha(sha)
         self._needs_serialization = False
 
     @staticmethod
@@ -686,7 +693,7 @@ class Blob(ShaFile):
     def from_path(cls, path: Union[str, bytes]) -> "Blob":
         blob = ShaFile.from_path(path)
         if not isinstance(blob, cls):
-            raise NotBlobError(path)
+            raise NotBlobError(_path_to_bytes(path))
         return blob
 
     def check(self) -> None:
@@ -706,7 +713,7 @@ class Blob(ShaFile):
         if not chunks:
             return []
         if len(chunks) == 1:
-            return chunks[0].splitlines(True)
+            return chunks[0].splitlines(True)  # type: ignore[no-any-return]
         remaining = None
         ret = []
         for chunk in chunks:
@@ -834,7 +841,7 @@ class Tag(ShaFile):
     def from_path(cls, filename: Union[str, bytes]) -> "Tag":
         tag = ShaFile.from_path(filename)
         if not isinstance(tag, cls):
-            raise NotTagError(filename)
+            raise NotTagError(_path_to_bytes(filename))
         return tag
 
     def check(self) -> None:
@@ -1186,7 +1193,7 @@ class Tree(ShaFile):
     def from_path(cls, filename: Union[str, bytes]) -> "Tree":
         tree = ShaFile.from_path(filename)
         if not isinstance(tree, cls):
-            raise NotTreeError(filename)
+            raise NotTreeError(_path_to_bytes(filename))
         return tree
 
     def __contains__(self, name: bytes) -> bool:
@@ -1536,7 +1543,7 @@ class Commit(ShaFile):
     def from_path(cls, path: Union[str, bytes]) -> "Commit":
         commit = ShaFile.from_path(path)
         if not isinstance(commit, cls):
-            raise NotCommitError(path)
+            raise NotCommitError(_path_to_bytes(path))
         return commit
 
     def _deserialize(self, chunks: list[bytes]) -> None: