Explorar el Código

Add type annotations to RefsContainer and SymrefLoop classes

Jelmer Vernooij hace 5 meses
padre
commit
8a2fcd3b7c
Se han modificado 1 ficheros con 24 adiciones y 36 borrados
  1. 24 36
      dulwich/refs.py

+ 24 - 36
dulwich/refs.py

@@ -27,7 +27,7 @@ import types
 import warnings
 from collections.abc import Iterator
 from contextlib import suppress
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Optional, Union
 
 if TYPE_CHECKING:
     from .file import _GitFile
@@ -55,13 +55,7 @@ ANNOTATED_TAG_SUFFIX = PEELED_TAG_SUFFIX
 class SymrefLoop(Exception):
     """There is a loop between one or more symrefs."""
 
-    def __init__(self, ref, depth) -> None:
-        """Initialize a SymrefLoop exception.
-
-        Args:
-          ref: The ref that caused the loop
-          depth: Depth at which the loop was detected
-        """
+    def __init__(self, ref: bytes, depth: int) -> None:
         self.ref = ref
         self.depth = depth
 
@@ -142,23 +136,18 @@ def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
 class RefsContainer:
     """A container for refs."""
 
-    def __init__(self, logger=None) -> None:
-        """Initialize a RefsContainer.
-
-        Args:
-          logger: Optional logger for reflog updates
-        """
+    def __init__(self, logger: Optional[Callable[[bytes, Optional[bytes], Optional[bytes], Optional[bytes], Optional[int], Optional[int], Optional[bytes]], None]] = None) -> None:
         self._logger = logger
 
     def _log(
         self,
-        ref,
-        old_sha,
-        new_sha,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        ref: bytes,
+        old_sha: Optional[bytes],
+        new_sha: Optional[bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         if self._logger is None:
             return
@@ -168,12 +157,12 @@ class RefsContainer:
 
     def set_symbolic_ref(
         self,
-        name,
-        other,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
+        name: bytes,
+        other: bytes,
+        committer: Optional[bytes] = None,
+        timestamp: Optional[int] = None,
+        timezone: Optional[int] = None,
+        message: Optional[bytes] = None,
     ) -> None:
         """Make a ref point at another ref.
 
@@ -206,7 +195,7 @@ class RefsContainer:
         """
         raise NotImplementedError(self.add_packed_refs)
 
-    def get_peeled(self, name) -> Optional[ObjectID]:
+    def get_peeled(self, name: bytes) -> Optional[ObjectID]:
         """Return the cached peeled value of a ref, if available.
 
         Args:
@@ -261,11 +250,10 @@ class RefsContainer:
         """All refs present in this container."""
         raise NotImplementedError(self.allkeys)
 
-    def __iter__(self):
-        """Iterate over all ref names."""
+    def __iter__(self) -> Iterator[Ref]:
         return iter(self.allkeys())
 
-    def keys(self, base=None):
+    def keys(self, base: Optional[bytes] = None) -> Union[Iterator[Ref], set[bytes]]:
         """Refs present in this container.
 
         Args:
@@ -278,7 +266,7 @@ class RefsContainer:
         else:
             return self.allkeys()
 
-    def subkeys(self, base):
+    def subkeys(self, base: bytes) -> set[bytes]:
         """Refs present in this container under a base.
 
         Args:
@@ -293,7 +281,7 @@ class RefsContainer:
                 keys.add(refname[base_len:])
         return keys
 
-    def as_dict(self, base=None) -> dict[Ref, ObjectID]:
+    def as_dict(self, base: Optional[bytes] = None) -> dict[Ref, ObjectID]:
         """Return the contents of this container as a dictionary."""
         ret = {}
         keys = self.keys(base)
@@ -309,7 +297,7 @@ class RefsContainer:
 
         return ret
 
-    def _check_refname(self, name) -> None:
+    def _check_refname(self, name: bytes) -> None:
         """Ensure a refname is valid and lives in refs or is HEAD.
 
         HEAD is not a valid refname according to git-check-ref-format, but this
@@ -328,7 +316,7 @@ class RefsContainer:
         if not name.startswith(b"refs/") or not check_ref_format(name[5:]):
             raise RefFormatError(name)
 
-    def read_ref(self, refname):
+    def read_ref(self, refname: bytes) -> Optional[bytes]:
         """Read a reference without following any references.
 
         Args:
@@ -341,7 +329,7 @@ class RefsContainer:
             contents = self.get_packed_refs().get(refname, None)
         return contents
 
-    def read_loose_ref(self, name) -> bytes:
+    def read_loose_ref(self, name: bytes) -> Optional[bytes]:
         """Read a loose reference and return its contents.
 
         Args: