소스 검색

Add AsyncProtocol.

Jelmer Vernooij 5 년 전
부모
커밋
71f0064457
3개의 변경된 파일213개의 추가작업 그리고 29개의 파일을 삭제
  1. 30 27
      dulwich/client.py
  2. 181 0
      dulwich/protocol.py
  3. 2 2
      dulwich/tests/test_client.py

+ 30 - 27
dulwich/client.py

@@ -95,7 +95,7 @@ from dulwich.protocol import (
     SIDE_BAND_CHANNEL_PROGRESS,
     SIDE_BAND_CHANNEL_FATAL,
     PktLineParser,
-    Protocol,
+    AsyncProtocol,
     ProtocolFile,
     TCP_GIT_PORT,
     ZERO_SHA,
@@ -210,11 +210,11 @@ class ReportStatusParser(object):
                 self._ref_status_ok = False
 
 
-def read_pkt_refs(proto):
+async def read_pkt_refs(proto):
     server_capabilities = None
     refs = {}
     # Receive refs from server
-    for pkt in proto.read_pkt_seq():
+    async for pkt in proto.read_pkt_seq():
         (sha, ref) = pkt.rstrip(b'\n').split(None, 1)
         if sha == b'ERR':
             raise GitProtocolError(ref.decode('utf-8', 'replace'))
@@ -293,10 +293,10 @@ class FetchPackResult(object):
                 self.__class__.__name__, self.refs, self.symrefs, self.agent)
 
 
-def _read_shallow_updates(proto):
+async def _read_shallow_updates(proto):
     new_shallow = set()
     new_unshallow = set()
-    for pkt in proto.read_pkt_seq():
+    async for pkt in proto.read_pkt_seq():
         cmd, sha = pkt.split(b' ', 1)
         if cmd == COMMAND_SHALLOW:
             new_shallow.add(sha.strip())
@@ -575,7 +575,7 @@ class GitClient(object):
           channel_callbacks: Dictionary mapping channels to packet
             handlers to use. None for a callback discards channel data.
         """
-        for pkt in proto.read_pkt_seq():
+        async for pkt in proto.read_pkt_seq():
             channel = ord(pkt[:1])
             pkt = pkt[1:]
             try:
@@ -662,7 +662,7 @@ class GitClient(object):
             await self._read_side_band64k_data(proto, channel_callbacks)
         else:
             if CAPABILITY_REPORT_STATUS in capabilities:
-                for pkt in proto.read_pkt_seq():
+                async for pkt in proto.read_pkt_seq():
                     self._report_status_parser.handle_packet(pkt)
         if self._report_status_parser is not None:
             self._report_status_parser.check()
@@ -719,7 +719,7 @@ class GitClient(object):
                                  str(depth).encode('ascii') + b'\n')
             proto.write_pkt_line(None)
             if can_read is not None:
-                (new_shallow, new_unshallow) = _read_shallow_updates(proto)
+                (new_shallow, new_unshallow) = await _read_shallow_updates(proto)
             else:
                 new_shallow = new_unshallow = None
         else:
@@ -867,7 +867,7 @@ class TraditionalGitClient(GitClient):
             b'receive-pack', path)
         with proto:
             try:
-                old_refs, server_capabilities = read_pkt_refs(proto)
+                old_refs, server_capabilities = await read_pkt_refs(proto)
             except HangupException:
                 raise remote_error_from_stderr(stderr)
             negotiated_capabilities = \
@@ -947,7 +947,7 @@ class TraditionalGitClient(GitClient):
             b'upload-pack', path)
         with proto:
             try:
-                refs, server_capabilities = read_pkt_refs(proto)
+                refs, server_capabilities = await read_pkt_refs(proto)
             except HangupException:
                 raise remote_error_from_stderr(stderr)
             negotiated_capabilities, symrefs, agent = (
@@ -984,7 +984,7 @@ class TraditionalGitClient(GitClient):
         proto, _, stderr = await self._connect(b'upload-pack', path)
         with proto:
             try:
-                refs, _ = read_pkt_refs(proto)
+                refs, _ = await read_pkt_refs(proto)
             except HangupException:
                 raise remote_error_from_stderr(stderr)
             proto.write_pkt_line(None)
@@ -1078,8 +1078,9 @@ class TCPGitClient(TraditionalGitClient):
             wfile.close()
             s.close()
 
-        proto = Protocol(rfile.read, wfile.write, close,
-                         report_activity=self._report_activity)
+        proto = AsyncProtocol(
+            rfile.read, wfile.write, close,
+            report_activity=self._report_activity)
         if path.startswith(b"/~"):
             path = path[1:]
         # TODO(jelmer): Alternative to ascii?
@@ -1151,9 +1152,9 @@ class SubprocessGitClient(TraditionalGitClient):
                              stdout=subprocess.PIPE,
                              stderr=subprocess.PIPE)
         pw = SubprocessWrapper(p)
-        return (Protocol(pw.read, pw.write, pw.close,
-                         report_activity=self._report_activity),
-                pw.can_read, p.stderr)
+        return (AsyncProtocol(
+            pw.read, pw.write, pw.close,
+            report_activity=self._report_activity), pw.can_read, p.stderr)
 
 
 class LocalGitClient(GitClient):
@@ -1476,9 +1477,11 @@ class SSHGitClient(TraditionalGitClient):
         con = await self.ssh_vendor.run_command(
             self.host, argv, port=self.port, username=self.username,
             **kwargs)
-        return (Protocol(con.read, con.write, con.close,
-                         report_activity=self._report_activity),
-                con.can_read, getattr(con, 'stderr', None))
+        return (
+            AsyncProtocol(
+                con.read, con.write, con.close,
+                report_activity=self._report_activity),
+            con.can_read, getattr(con, 'stderr', None))
 
 
 def default_user_agent_string():
@@ -1693,17 +1696,17 @@ class HttpGitClient(GitClient):
         try:
             self.dumb = not resp.content_type.startswith("application/x-git-")
             if not self.dumb:
-                proto = Protocol(read, None)
+                proto = AsyncProtocol(read, None)
                 # The first line should mention the service
                 try:
-                    [pkt] = list(proto.read_pkt_seq())
+                    [pkt] = [pkt async for pkt in proto.read_pkt_seq()]
                 except ValueError:
                     raise GitProtocolError(
                         "unexpected number of packets received")
                 if pkt.rstrip(b'\n') != (b'# service=' + service):
                     raise GitProtocolError(
                         "unexpected first line %r from smart server" % pkt)
-                return read_pkt_refs(proto) + (base_url, )
+                return await read_pkt_refs(proto) + (base_url, )
             else:
                 return read_info_refs(resp), set(), base_url
         finally:
@@ -1764,7 +1767,7 @@ class HttpGitClient(GitClient):
         if self.dumb:
             raise NotImplementedError(self.fetch_pack)
         req_data = BytesIO()
-        req_proto = Protocol(None, req_data.write)
+        req_proto = AsyncProtocol(None, req_data.write)
         (have, want) = await self._handle_receive_pack_head(
             req_proto, negotiated_capabilities, old_refs, new_refs)
         if not want and set(new_refs.items()).issubset(set(old_refs.items())):
@@ -1777,7 +1780,7 @@ class HttpGitClient(GitClient):
         resp, read = await self._smart_request("git-receive-pack", url,
                                          data=req_data.getvalue())
         try:
-            resp_proto = Protocol(read, None)
+            resp_proto = AsyncProtocol(read, None)
             await self._handle_receive_pack_tail(
                 resp_proto, negotiated_capabilities, progress)
             return new_refs
@@ -1814,16 +1817,16 @@ class HttpGitClient(GitClient):
         if self.dumb:
             raise NotImplementedError(self.send_pack)
         req_data = BytesIO()
-        req_proto = Protocol(None, req_data.write)
+        req_proto = AsyncProtocol(None, req_data.write)
         (new_shallow, new_unshallow) = await self._handle_upload_pack_head(
                 req_proto, negotiated_capabilities, graph_walker, wants,
                 can_read=None, depth=depth)
         resp, read = await self._smart_request(
             "git-upload-pack", url, data=req_data.getvalue())
         try:
-            resp_proto = Protocol(read, None)
+            resp_proto = AsyncProtocol(read, None)
             if new_shallow is None and new_unshallow is None:
-                (new_shallow, new_unshallow) = _read_shallow_updates(
+                (new_shallow, new_unshallow) = await _read_shallow_updates(
                         resp_proto)
             await self._handle_upload_pack_tail(
                 resp_proto, negotiated_capabilities, graph_walker, pack_data,

+ 181 - 0
dulwich/protocol.py

@@ -339,6 +339,187 @@ class Protocol(object):
         return cmd, args[:-1].split(b"\0")
 
 
+class AsyncProtocol(object):
+    """Async variant of Protocol.
+
+    Parts of the git wire protocol use 'pkt-lines' to communicate. A pkt-line
+    consists of the length of the line as a 4-byte hex string, followed by the
+    payload data. The length includes the 4-byte header. The special line
+    '0000' indicates the end of a section of input and is called a 'flush-pkt'.
+
+    For details on the pkt-line format, see the cgit distribution:
+        Documentation/technical/protocol-common.txt
+    """
+
+    def __init__(self, read, write, close=None, report_activity=None):
+        self.read = read
+        self.write = write
+        self._close = close
+        self.report_activity = report_activity
+        self._readahead = None
+
+    def close(self):
+        if self._close:
+            self._close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    def read_pkt_line(self):
+        """Reads a pkt-line from the remote git process.
+
+        This method may read from the readahead buffer; see unread_pkt_line.
+
+        Returns: The next string from the stream, without the length prefix, or
+            None for a flush-pkt ('0000').
+        """
+        if self._readahead is None:
+            read = self.read
+        else:
+            read = self._readahead.read
+            self._readahead = None
+
+        try:
+            sizestr = read(4)
+            if not sizestr:
+                raise HangupException()
+            size = int(sizestr, 16)
+            if size == 0:
+                if self.report_activity:
+                    self.report_activity(4, 'read')
+                return None
+            if self.report_activity:
+                self.report_activity(size, 'read')
+            pkt_contents = read(size-4)
+        except socket.error as e:
+            raise GitProtocolError(e)
+        else:
+            if len(pkt_contents) + 4 != size:
+                raise GitProtocolError(
+                    'Length of pkt read %04x does not match length prefix %04x'
+                    % (len(pkt_contents) + 4, size))
+            return pkt_contents
+
+    async def eof(self):
+        """Test whether the protocol stream has reached EOF.
+
+        Note that this refers to the actual stream EOF and not just a
+        flush-pkt.
+
+        Returns: True if the stream is at EOF, False otherwise.
+        """
+        try:
+            next_line = self.read_pkt_line()
+        except HangupException:
+            return True
+        self.unread_pkt_line(next_line)
+        return False
+
+    def unread_pkt_line(self, data):
+        """Unread a single line of data into the readahead buffer.
+
+        This method can be used to unread a single pkt-line into a fixed
+        readahead buffer.
+
+        Args:
+          data: The data to unread, without the length prefix.
+        Raises:
+          ValueError: If more than one pkt-line is unread.
+        """
+        if self._readahead is not None:
+            raise ValueError('Attempted to unread multiple pkt-lines.')
+        self._readahead = BytesIO(pkt_line(data))
+
+    async def read_pkt_seq(self):
+        """Read a sequence of pkt-lines from the remote git process.
+
+        Returns: Yields each line of data up to but not including the next
+            flush-pkt.
+        """
+        pkt = self.read_pkt_line()
+        while pkt:
+            yield pkt
+            pkt = self.read_pkt_line()
+
+    def write_pkt_line(self, line):
+        """Sends a pkt-line to the remote git process.
+
+        Args:
+          line: A string containing the data to send, without the length
+            prefix.
+        """
+        try:
+            line = pkt_line(line)
+            self.write(line)
+            if self.report_activity:
+                self.report_activity(len(line), 'write')
+        except socket.error as e:
+            raise GitProtocolError(e)
+
+    def write_file(self):
+        """Return a writable file-like object for this protocol."""
+
+        class ProtocolFile(object):
+
+            def __init__(self, proto):
+                self._proto = proto
+                self._offset = 0
+
+            def write(self, data):
+                self._proto.write(data)
+                self._offset += len(data)
+
+            def tell(self):
+                return self._offset
+
+            def close(self):
+                pass
+
+        return ProtocolFile(self)
+
+    def write_sideband(self, channel, blob):
+        """Write multiplexed data to the sideband.
+
+        Args:
+          channel: An int specifying the channel to write to.
+          blob: A blob of data (as a string) to send on this channel.
+        """
+        # a pktline can be a max of 65520. a sideband line can therefore be
+        # 65520-5 = 65515
+        # WTF: Why have the len in ASCII, but the channel in binary.
+        while blob:
+            self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515])
+            blob = blob[65515:]
+
+    def send_cmd(self, cmd, *args):
+        """Send a command and some arguments to a git server.
+
+        Only used for the TCP git protocol (git://).
+
+        Args:
+          cmd: The remote service to access.
+          args: List of arguments to send to remove service.
+        """
+        self.write_pkt_line(cmd + b" " + b"".join([(a + b"\0") for a in args]))
+
+    def read_cmd(self):
+        """Read a command and some arguments from the git client
+
+        Only used for the TCP git protocol (git://).
+
+        Returns: A tuple of (command, [list of arguments]).
+        """
+        line = self.read_pkt_line()
+        splice_at = line.find(b" ")
+        cmd, args = line[:splice_at], line[splice_at+1:]
+        assert args[-1:] == b"\x00"
+        return cmd, args[:-1].split(b"\0")
+
+
+
 _RBUFSIZE = 8192  # Default read buffer size.
 
 

+ 2 - 2
dulwich/tests/test_client.py

@@ -70,7 +70,7 @@ from dulwich.tests import (
     )
 from dulwich.protocol import (
     TCP_GIT_PORT,
-    Protocol,
+    AsyncProtocol,
     )
 from dulwich.pack import (
     pack_objects_to_data,
@@ -102,7 +102,7 @@ class DummyClient(TraditionalGitClient):
         TraditionalGitClient.__init__(self)
 
     async def _connect(self, service, path):
-        return Protocol(self.read, self.write), self.can_read, None
+        return AsyncProtocol(self.read, self.write), self.can_read, None
 
 
 class DummyPopen():