Jelajahi Sumber

Add type annotations to dulwich/errors.py

Jelmer Vernooij 5 bulan lalu
induk
melakukan
19acee77ad
1 mengubah file dengan 28 tambahan dan 23 penghapusan
  1. 28 23
      dulwich/errors.py

+ 28 - 23
dulwich/errors.py

@@ -27,20 +27,25 @@
 # that raises the error.
 
 import binascii
+from typing import Optional, Union
 
 
 class ChecksumMismatch(Exception):
     """A checksum didn't match the expected contents."""
 
-    def __init__(self, expected, got, extra=None) -> None:
-        if len(expected) == 20:
-            expected = binascii.hexlify(expected)
-        if len(got) == 20:
-            got = binascii.hexlify(got)
-        self.expected = expected
-        self.got = got
+    def __init__(self, expected: Union[bytes, str], got: Union[bytes, str], extra: Optional[str] = None) -> None:
+        if isinstance(expected, bytes) and len(expected) == 20:
+            expected_str = binascii.hexlify(expected).decode('ascii')
+        else:
+            expected_str = expected if isinstance(expected, str) else expected.decode('ascii')
+        if isinstance(got, bytes) and len(got) == 20:
+            got_str = binascii.hexlify(got).decode('ascii')
+        else:
+            got_str = got if isinstance(got, str) else got.decode('ascii')
+        self.expected = expected_str
+        self.got = got_str
         self.extra = extra
-        message = f"Checksum mismatch: Expected {expected}, got {got}"
+        message = f"Checksum mismatch: Expected {expected_str}, got {got_str}"
         if self.extra is not None:
             message += f"; {extra}"
         Exception.__init__(self, message)
@@ -57,8 +62,8 @@ class WrongObjectException(Exception):
 
     type_name: str
 
-    def __init__(self, sha, *args, **kwargs) -> None:
-        Exception.__init__(self, f"{sha} is not a {self.type_name}")
+    def __init__(self, sha: bytes, *args: object, **kwargs: object) -> None:
+        Exception.__init__(self, f"{sha.decode('ascii')} is not a {self.type_name}")
 
 
 class NotCommitError(WrongObjectException):
@@ -88,40 +93,40 @@ class NotBlobError(WrongObjectException):
 class MissingCommitError(Exception):
     """Indicates that a commit was not found in the repository."""
 
-    def __init__(self, sha, *args, **kwargs) -> None:
+    def __init__(self, sha: bytes, *args: object, **kwargs: object) -> None:
         self.sha = sha
-        Exception.__init__(self, f"{sha} is not in the revision store")
+        Exception.__init__(self, f"{sha.decode('ascii')} is not in the revision store")
 
 
 class ObjectMissing(Exception):
     """Indicates that a requested object is missing."""
 
-    def __init__(self, sha, *args, **kwargs) -> None:
-        Exception.__init__(self, f"{sha} is not in the pack")
+    def __init__(self, sha: bytes, *args: object, **kwargs: object) -> None:
+        Exception.__init__(self, f"{sha.decode('ascii')} is not in the pack")
 
 
 class ApplyDeltaError(Exception):
     """Indicates that applying a delta failed."""
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: object, **kwargs: object) -> None:
         Exception.__init__(self, *args, **kwargs)
 
 
 class NotGitRepository(Exception):
     """Indicates that no Git repository was found."""
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: object, **kwargs: object) -> None:
         Exception.__init__(self, *args, **kwargs)
 
 
 class GitProtocolError(Exception):
     """Git protocol exception."""
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(self, *args: object, **kwargs: object) -> None:
         Exception.__init__(self, *args, **kwargs)
 
-    def __eq__(self, other):
-        return isinstance(self, type(other)) and self.args == other.args
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, GitProtocolError) and self.args == other.args
 
 
 class SendPackError(GitProtocolError):
@@ -131,7 +136,7 @@ class SendPackError(GitProtocolError):
 class HangupException(GitProtocolError):
     """Hangup exception."""
 
-    def __init__(self, stderr_lines=None) -> None:
+    def __init__(self, stderr_lines: Optional[list[bytes]] = None) -> None:
         if stderr_lines:
             super().__init__(
                 "\n".join(
@@ -142,14 +147,14 @@ class HangupException(GitProtocolError):
             super().__init__("The remote server unexpectedly closed the connection.")
         self.stderr_lines = stderr_lines
 
-    def __eq__(self, other):
-        return isinstance(self, type(other)) and self.stderr_lines == other.stderr_lines
+    def __eq__(self, other: object) -> bool:
+        return isinstance(other, HangupException) and self.stderr_lines == other.stderr_lines
 
 
 class UnexpectedCommandError(GitProtocolError):
     """Unexpected command received in a proto line."""
 
-    def __init__(self, command) -> None:
+    def __init__(self, command: Optional[str]) -> None:
         command_str = "flush-pkt" if command is None else f"command {command}"
         super().__init__(f"Protocol got unexpected {command_str}")