Ver Fonte

protocol: Add typing

Jelmer Vernooij há 3 meses atrás
pai
commit
7a486e03d0
1 ficheiros alterados com 65 adições e 39 exclusões
  1. 65 39
      dulwich/protocol.py

+ 65 - 39
dulwich/protocol.py

@@ -22,9 +22,11 @@
 
 """Generic functions for talking the git smart server protocol."""
 
+import types
+from collections.abc import Iterable
 from io import BytesIO
 from os import SEEK_END
-from typing import Optional
+from typing import Callable, Optional
 
 import dulwich
 
@@ -128,30 +130,30 @@ DEPTH_INFINITE = 0x7FFFFFFF
 NAK_LINE = b"NAK\n"
 
 
-def agent_string():
+def agent_string() -> bytes:
     return ("dulwich/" + ".".join(map(str, dulwich.__version__))).encode("ascii")
 
 
-def capability_agent():
+def capability_agent() -> bytes:
     return CAPABILITY_AGENT + b"=" + agent_string()
 
 
-def capability_symref(from_ref, to_ref):
+def capability_symref(from_ref: bytes, to_ref: bytes) -> bytes:
     return CAPABILITY_SYMREF + b"=" + from_ref + b":" + to_ref
 
 
-def extract_capability_names(capabilities):
+def extract_capability_names(capabilities: Iterable[bytes]) -> set[bytes]:
     return {parse_capability(c)[0] for c in capabilities}
 
 
-def parse_capability(capability):
+def parse_capability(capability: bytes) -> tuple[bytes, Optional[bytes]]:
     parts = capability.split(b"=", 1)
     if len(parts) == 1:
         return (parts[0], None)
-    return tuple(parts)
+    return (parts[0], parts[1])
 
 
-def symref_capabilities(symrefs):
+def symref_capabilities(symrefs: Iterable[tuple[bytes, bytes]]) -> list[bytes]:
     return [capability_symref(*k) for k in symrefs]
 
 
@@ -163,18 +165,18 @@ COMMAND_WANT = b"want"
 COMMAND_HAVE = b"have"
 
 
-def format_cmd_pkt(cmd, *args):
+def format_cmd_pkt(cmd: bytes, *args: bytes) -> bytes:
     return cmd + b" " + b"".join([(a + b"\0") for a in args])
 
 
-def parse_cmd_pkt(line):
+def parse_cmd_pkt(line: bytes) -> tuple[bytes, list[bytes]]:
     splice_at = line.find(b" ")
     cmd, args = line[:splice_at], line[splice_at + 1 :]
     assert args[-1:] == b"\x00"
     return cmd, args[:-1].split(b"\0")
 
 
-def pkt_line(data):
+def pkt_line(data: Optional[bytes]) -> bytes:
     """Wrap data in a pkt-line.
 
     Args:
@@ -187,7 +189,7 @@ def pkt_line(data):
     return ("%04x" % (len(data) + 4)).encode("ascii") + data
 
 
-def pkt_seq(*seq):
+def pkt_seq(*seq: Optional[bytes]) -> bytes:
     """Wrap a sequence of data in pkt-lines.
 
     Args:
@@ -196,7 +198,9 @@ def pkt_seq(*seq):
     return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)
 
 
-def filter_ref_prefix(refs, prefixes):
+def filter_ref_prefix(
+    refs: dict[bytes, bytes], prefixes: Iterable[bytes]
+) -> dict[bytes, bytes]:
     """Filter refs to only include those with a given prefix.
 
     Args:
@@ -218,7 +222,13 @@ class Protocol:
         Documentation/technical/protocol-common.txt
     """
 
-    def __init__(self, read, write, close=None, report_activity=None) -> None:
+    def __init__(
+        self,
+        read: Callable[[int], bytes],
+        write: Callable[[bytes], Optional[int]],
+        close: Optional[Callable[[], None]] = None,
+        report_activity: Optional[Callable[[int, str], None]] = None,
+    ) -> None:
         self.read = read
         self.write = write
         self._close = close
@@ -229,13 +239,18 @@ class Protocol:
         if self._close:
             self._close()
 
-    def __enter__(self):
+    def __enter__(self) -> "Protocol":
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    def __exit__(
+        self,
+        exc_type: Optional[type[BaseException]],
+        exc_val: Optional[BaseException],
+        exc_tb: Optional[types.TracebackType],
+    ) -> None:
         self.close()
 
-    def read_pkt_line(self):
+    def read_pkt_line(self) -> Optional[bytes]:
         """Reads a pkt-line from the remote git process.
 
         This method may read from the readahead buffer; see unread_pkt_line.
@@ -287,7 +302,7 @@ class Protocol:
         self.unread_pkt_line(next_line)
         return False
 
-    def unread_pkt_line(self, data) -> None:
+    def unread_pkt_line(self, data: Optional[bytes]) -> None:
         """Unread a single line of data into the readahead buffer.
 
         This method can be used to unread a single pkt-line into a fixed
@@ -303,7 +318,7 @@ class Protocol:
             raise ValueError("Attempted to unread multiple pkt-lines.")
         self._readahead = BytesIO(pkt_line(data))
 
-    def read_pkt_seq(self):
+    def read_pkt_seq(self) -> Iterable[bytes]:
         """Read a sequence of pkt-lines from the remote git process.
 
         Returns: Yields each line of data up to but not including the next
@@ -314,7 +329,7 @@ class Protocol:
             yield pkt
             pkt = self.read_pkt_line()
 
-    def write_pkt_line(self, line) -> None:
+    def write_pkt_line(self, line: Optional[bytes]) -> None:
         """Sends a pkt-line to the remote git process.
 
         Args:
@@ -329,7 +344,7 @@ class Protocol:
         except OSError as exc:
             raise GitProtocolError(str(exc)) from exc
 
-    def write_sideband(self, channel, blob) -> None:
+    def write_sideband(self, channel: int, blob: bytes) -> None:
         """Write multiplexed data to the sideband.
 
         Args:
@@ -343,7 +358,7 @@ class Protocol:
             self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515])
             blob = blob[65515:]
 
-    def send_cmd(self, cmd, *args) -> None:
+    def send_cmd(self, cmd: bytes, *args: bytes) -> None:
         """Send a command and some arguments to a git server.
 
         Only used for the TCP git protocol (git://).
@@ -354,7 +369,7 @@ class Protocol:
         """
         self.write_pkt_line(format_cmd_pkt(cmd, *args))
 
-    def read_cmd(self):
+    def read_cmd(self) -> tuple[bytes, list[bytes]]:
         """Read a command and some arguments from the git client.
 
         Only used for the TCP git protocol (git://).
@@ -362,6 +377,8 @@ class Protocol:
         Returns: A tuple of (command, [list of arguments]).
         """
         line = self.read_pkt_line()
+        if line is None:
+            raise GitProtocolError("Expected command, got flush packet")
         return parse_cmd_pkt(line)
 
 
@@ -381,14 +398,19 @@ class ReceivableProtocol(Protocol):
     """
 
     def __init__(
-        self, recv, write, close=None, report_activity=None, rbufsize=_RBUFSIZE
+        self,
+        recv: Callable[[int], bytes],
+        write: Callable[[bytes], Optional[int]],
+        close: Optional[Callable[[], None]] = None,
+        report_activity: Optional[Callable[[int, str], None]] = None,
+        rbufsize: int = _RBUFSIZE,
     ) -> None:
         super().__init__(self.read, write, close=close, report_activity=report_activity)
         self._recv = recv
         self._rbuf = BytesIO()
         self._rbufsize = rbufsize
 
-    def read(self, size):
+    def read(self, size: int) -> bytes:
         # From _fileobj.read in socket.py in the Python 2.6.5 standard library,
         # with the following modifications:
         #  - omit the size <= 0 branch
@@ -449,7 +471,7 @@ class ReceivableProtocol(Protocol):
         buf.seek(start)
         return buf.read()
 
-    def recv(self, size):
+    def recv(self, size: int) -> bytes:
         assert size > 0
 
         buf = self._rbuf
@@ -473,7 +495,7 @@ class ReceivableProtocol(Protocol):
         return buf.read(size)
 
 
-def extract_capabilities(text):
+def extract_capabilities(text: bytes) -> tuple[bytes, list[bytes]]:
     """Extract a capabilities list from a string, if present.
 
     Args:
@@ -486,7 +508,7 @@ def extract_capabilities(text):
     return (text, capabilities.strip().split(b" "))
 
 
-def extract_want_line_capabilities(text):
+def extract_want_line_capabilities(text: bytes) -> tuple[bytes, list[bytes]]:
     """Extract a capabilities list from a want line, if present.
 
     Note that want lines have capabilities separated from the rest of the line
@@ -504,7 +526,7 @@ def extract_want_line_capabilities(text):
     return (b" ".join(split_text[:2]), split_text[2:])
 
 
-def ack_type(capabilities):
+def ack_type(capabilities: Iterable[bytes]) -> int:
     """Extract the ack type from a capabilities list."""
     if b"multi_ack_detailed" in capabilities:
         return MULTI_ACK_DETAILED
@@ -521,7 +543,9 @@ class BufferedPktLineWriter:
     (including length prefix) reach the buffer size.
     """
 
-    def __init__(self, write, bufsize=65515) -> None:
+    def __init__(
+        self, write: Callable[[bytes], Optional[int]], bufsize: int = 65515
+    ) -> None:
         """Initialize the BufferedPktLineWriter.
 
         Args:
@@ -533,7 +557,7 @@ class BufferedPktLineWriter:
         self._wbuf = BytesIO()
         self._buflen = 0
 
-    def write(self, data) -> None:
+    def write(self, data: bytes) -> None:
         """Write data, wrapping it in a pkt-line."""
         line = pkt_line(data)
         line_len = len(line)
@@ -560,11 +584,11 @@ class BufferedPktLineWriter:
 class PktLineParser:
     """Packet line parser that hands completed packets off to a callback."""
 
-    def __init__(self, handle_pkt) -> None:
+    def __init__(self, handle_pkt: Callable[[Optional[bytes]], None]) -> None:
         self.handle_pkt = handle_pkt
         self._readahead = BytesIO()
 
-    def parse(self, data) -> None:
+    def parse(self, data: bytes) -> None:
         """Parse a fragment of data and call back for any completed packets."""
         self._readahead.write(data)
         buf = self._readahead.getvalue()
@@ -583,31 +607,33 @@ class PktLineParser:
         self._readahead = BytesIO()
         self._readahead.write(buf)
 
-    def get_tail(self):
+    def get_tail(self) -> bytes:
         """Read back any unused data."""
         return self._readahead.getvalue()
 
 
-def format_capability_line(capabilities):
+def format_capability_line(capabilities: Iterable[bytes]) -> bytes:
     return b"".join([b" " + c for c in capabilities])
 
 
-def format_ref_line(ref, sha, capabilities=None):
+def format_ref_line(
+    ref: bytes, sha: bytes, capabilities: Optional[list[bytes]] = None
+) -> bytes:
     if capabilities is None:
         return sha + b" " + ref + b"\n"
     else:
         return sha + b" " + ref + b"\0" + format_capability_line(capabilities) + b"\n"
 
 
-def format_shallow_line(sha):
+def format_shallow_line(sha: bytes) -> bytes:
     return COMMAND_SHALLOW + b" " + sha
 
 
-def format_unshallow_line(sha):
+def format_unshallow_line(sha: bytes) -> bytes:
     return COMMAND_UNSHALLOW + b" " + sha
 
 
-def format_ack_line(sha, ack_type=b""):
+def format_ack_line(sha: bytes, ack_type: bytes = b"") -> bytes:
     if ack_type:
         ack_type = b" " + ack_type
     return b"ACK " + sha + ack_type + b"\n"