瀏覽代碼

Make AsyncProtocol fully async.

Jelmer Vernooij 5 年之前
父節點
當前提交
2e3c9c2941
共有 2 個文件被更改,包括 56 次插入54 次删除
  1. 42 41
      dulwich/client.py
  2. 14 13
      dulwich/protocol.py

+ 42 - 41
dulwich/client.py

@@ -528,23 +528,23 @@ class GitClient(object):
                             prefix=None):
         raise NotImplementedError(self.archive_async)
 
-    def _parse_status_report(self, proto):
-        unpack = proto.read_pkt_line().strip()
+    async def _parse_status_report(self, proto):
+        unpack = await proto.read_pkt_line().strip()
         if unpack != b'unpack ok':
             st = True
             # flush remaining error data
             while st is not None:
-                st = proto.read_pkt_line()
+                st = await proto.read_pkt_line()
             raise SendPackError(unpack)
         statuses = []
         errs = False
-        ref_status = proto.read_pkt_line()
+        ref_status = await proto.read_pkt_line()
         while ref_status:
             ref_status = ref_status.strip()
             statuses.append(ref_status)
             if not ref_status.startswith(b'ok '):
                 errs = True
-            ref_status = proto.read_pkt_line()
+            ref_status = await proto.read_pkt_line()
 
         if errs:
             ref_status = {}
@@ -618,16 +618,16 @@ class GitClient(object):
 
             if old_sha1 != new_sha1:
                 if sent_capabilities:
-                    proto.write_pkt_line(old_sha1 + b' ' + new_sha1 + b' ' +
-                                         refname)
+                    await proto.write_pkt_line(
+                        old_sha1 + b' ' + new_sha1 + b' ' + refname)
                 else:
-                    proto.write_pkt_line(
+                    await proto.write_pkt_line(
                         old_sha1 + b' ' + new_sha1 + b' ' + refname + b'\0' +
                         b' '.join(sorted(capabilities)))
                     sent_capabilities = True
             if new_sha1 not in have and new_sha1 != ZERO_SHA:
                 want.append(new_sha1)
-        proto.write_pkt_line(None)
+        await proto.write_pkt_line(None)
         return (have, want)
 
     def _negotiate_receive_pack_capabilities(self, server_capabilities):
@@ -704,32 +704,33 @@ class GitClient(object):
 
         """
         assert isinstance(wants, list) and isinstance(wants[0], bytes)
-        proto.write_pkt_line(COMMAND_WANT + b' ' + wants[0] + b' ' +
-                             b' '.join(sorted(capabilities)) + b'\n')
+        await proto.write_pkt_line(
+            COMMAND_WANT + b' ' + wants[0] + b' ' +
+            b' '.join(sorted(capabilities)) + b'\n')
         for want in wants[1:]:
-            proto.write_pkt_line(COMMAND_WANT + b' ' + want + b'\n')
+            await proto.write_pkt_line(COMMAND_WANT + b' ' + want + b'\n')
         if depth not in (0, None) or getattr(graph_walker, 'shallow', None):
             if CAPABILITY_SHALLOW not in capabilities:
                 raise GitProtocolError(
                     "server does not support shallow capability required for "
                     "depth")
             for sha in graph_walker.shallow:
-                proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha + b'\n')
-            proto.write_pkt_line(COMMAND_DEEPEN + b' ' +
-                                 str(depth).encode('ascii') + b'\n')
-            proto.write_pkt_line(None)
+                await proto.write_pkt_line(COMMAND_SHALLOW + b' ' + sha + b'\n')
+            await proto.write_pkt_line(
+                COMMAND_DEEPEN + b' ' + str(depth).encode('ascii') + b'\n')
+            await proto.write_pkt_line(None)
             if can_read is not None:
                 (new_shallow, new_unshallow) = await _read_shallow_updates(proto)
             else:
                 new_shallow = new_unshallow = None
         else:
             new_shallow = new_unshallow = set()
-            proto.write_pkt_line(None)
+            await proto.write_pkt_line(None)
         have = next(graph_walker)
         while have:
-            proto.write_pkt_line(COMMAND_HAVE + b' ' + have + b'\n')
+            await proto.write_pkt_line(COMMAND_HAVE + b' ' + have + b'\n')
             if can_read is not None and can_read():
-                pkt = proto.read_pkt_line()
+                pkt = await proto.read_pkt_line()
                 parts = pkt.rstrip(b'\n').split(b' ')
                 if parts[0] == b'ACK':
                     graph_walker.ack(parts[1])
@@ -742,7 +743,7 @@ class GitClient(object):
                             "%s not in ('continue', 'ready', 'common)" %
                             parts[2])
             have = next(graph_walker)
-        proto.write_pkt_line(COMMAND_DONE + b'\n')
+        await proto.write_pkt_line(COMMAND_DONE + b'\n')
         return (new_shallow, new_unshallow)
 
     async def _handle_upload_pack_tail(
@@ -761,7 +762,7 @@ class GitClient(object):
         Returns:
 
         """
-        pkt = proto.read_pkt_line()
+        pkt = await proto.read_pkt_line()
         while pkt:
             parts = pkt.rstrip(b'\n').split(b' ')
             if parts[0] == b'ACK':
@@ -769,7 +770,7 @@ class GitClient(object):
             if len(parts) < 3 or parts[2] not in (
                     b'ready', b'continue', b'common'):
                 break
-            pkt = proto.read_pkt_line()
+            pkt = await proto.read_pkt_line()
         if CAPABILITY_SIDE_BAND_64K in capabilities:
             if progress is None:
                 # Just ignore progress data
@@ -865,7 +866,7 @@ class TraditionalGitClient(GitClient):
         """
         proto, unused_can_read, stderr = await self._connect(
             b'receive-pack', path)
-        with proto:
+        async with proto:
             try:
                 old_refs, server_capabilities = await read_pkt_refs(proto)
             except HangupException:
@@ -879,7 +880,7 @@ class TraditionalGitClient(GitClient):
             try:
                 new_refs = orig_new_refs = update_refs(dict(old_refs))
             except BaseException:
-                proto.write_pkt_line(None)
+                await proto.write_pkt_line(None)
                 raise
 
             if CAPABILITY_DELETE_REFS not in server_capabilities:
@@ -895,12 +896,12 @@ class TraditionalGitClient(GitClient):
                         del new_refs[ref]
 
             if new_refs is None:
-                proto.write_pkt_line(None)
+                await proto.write_pkt_line(None)
                 return old_refs
 
             if len(new_refs) == 0 and len(orig_new_refs):
                 # NOOP - Original new refs filtered out by policy
-                proto.write_pkt_line(None)
+                await proto.write_pkt_line(None)
                 if report_status_parser is not None:
                     report_status_parser.check()
                 return old_refs
@@ -945,7 +946,7 @@ class TraditionalGitClient(GitClient):
         """
         proto, can_read, stderr = await self._connect(
             b'upload-pack', path)
-        with proto:
+        async with proto:
             try:
                 refs, server_capabilities = await read_pkt_refs(proto)
             except HangupException:
@@ -955,18 +956,18 @@ class TraditionalGitClient(GitClient):
                             server_capabilities))
 
             if refs is None:
-                proto.write_pkt_line(None)
+                await proto.write_pkt_line(None)
                 return FetchPackResult(refs, symrefs, agent)
 
             try:
                 wants = determine_wants(refs)
             except BaseException:
-                proto.write_pkt_line(None)
+                await proto.write_pkt_line(None)
                 raise
             if wants is not None:
                 wants = [cid for cid in wants if cid != ZERO_SHA]
             if not wants:
-                proto.write_pkt_line(None)
+                await proto.write_pkt_line(None)
                 return FetchPackResult(refs, symrefs, agent)
             (new_shallow, new_unshallow) = await self._handle_upload_pack_head(
                 proto, negotiated_capabilities, graph_walker, wants, can_read,
@@ -982,30 +983,30 @@ class TraditionalGitClient(GitClient):
         """
         # stock `git ls-remote` uses upload-pack
         proto, _, stderr = await self._connect(b'upload-pack', path)
-        with proto:
+        async with proto:
             try:
                 refs, _ = await read_pkt_refs(proto)
             except HangupException:
                 raise remote_error_from_stderr(stderr)
-            proto.write_pkt_line(None)
+            await proto.write_pkt_line(None)
             return refs
 
     async def archive_async(self, path, committish, write_data, progress=None,
                             write_error=None, format=None, subdirs=None,
                             prefix=None):
         proto, can_read, stderr = await self._connect(b'upload-archive', path)
-        with proto:
+        async with proto:
             if format is not None:
-                proto.write_pkt_line(b"argument --format=" + format)
-            proto.write_pkt_line(b"argument " + committish)
+                await proto.write_pkt_line(b"argument --format=" + format)
+            await proto.write_pkt_line(b"argument " + committish)
             if subdirs is not None:
                 for subdir in subdirs:
-                    proto.write_pkt_line(b"argument " + subdir)
+                    await proto.write_pkt_line(b"argument " + subdir)
             if prefix is not None:
-                proto.write_pkt_line(b"argument --prefix=" + prefix)
-            proto.write_pkt_line(None)
+                await proto.write_pkt_line(b"argument --prefix=" + prefix)
+            await proto.write_pkt_line(None)
             try:
-                pkt = proto.read_pkt_line()
+                pkt = await proto.read_pkt_line()
             except HangupException:
                 raise remote_error_from_stderr(stderr)
             if pkt == b"NACK\n":
@@ -1017,7 +1018,7 @@ class TraditionalGitClient(GitClient):
                         pkt[4:].rstrip(b"\n").decode('utf-8', 'replace'))
             else:
                 raise AssertionError("invalid response %r" % pkt)
-            ret = proto.read_pkt_line()
+            ret = await proto.read_pkt_line()
             if ret is not None:
                 raise AssertionError("expected pkt tail")
             await self._read_side_band64k_data(proto, {
@@ -1084,7 +1085,7 @@ class TCPGitClient(TraditionalGitClient):
         if path.startswith(b"/~"):
             path = path[1:]
         # TODO(jelmer): Alternative to ascii?
-        proto.send_cmd(
+        await proto.send_cmd(
             b'git-' + cmd, path, b'host=' + self._host.encode('ascii'))
         return proto, lambda: _fileno_can_read(s), None
 

+ 14 - 13
dulwich/protocol.py

@@ -362,13 +362,13 @@ class AsyncProtocol(object):
         if self._close:
             self._close()
 
-    def __enter__(self):
+    async def __aenter__(self):
         return self
 
-    def __exit__(self, exc_type, exc_val, exc_tb):
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
         self.close()
 
-    def read_pkt_line(self):
+    async def read_pkt_line(self):
         """Reads a pkt-line from the remote git process.
 
         This method may read from the readahead buffer; see unread_pkt_line.
@@ -412,7 +412,7 @@ class AsyncProtocol(object):
         Returns: True if the stream is at EOF, False otherwise.
         """
         try:
-            next_line = self.read_pkt_line()
+            next_line = await self.read_pkt_line()
         except HangupException:
             return True
         self.unread_pkt_line(next_line)
@@ -439,12 +439,12 @@ class AsyncProtocol(object):
         Returns: Yields each line of data up to but not including the next
             flush-pkt.
         """
-        pkt = self.read_pkt_line()
+        pkt = await self.read_pkt_line()
         while pkt:
             yield pkt
-            pkt = self.read_pkt_line()
+            pkt = await self.read_pkt_line()
 
-    def write_pkt_line(self, line):
+    async def write_pkt_line(self, line):
         """Sends a pkt-line to the remote git process.
 
         Args:
@@ -480,7 +480,7 @@ class AsyncProtocol(object):
 
         return ProtocolFile(self)
 
-    def write_sideband(self, channel, blob):
+    async def write_sideband(self, channel, blob):
         """Write multiplexed data to the sideband.
 
         Args:
@@ -491,10 +491,11 @@ class AsyncProtocol(object):
         # 65520-5 = 65515
         # WTF: Why have the len in ASCII, but the channel in binary.
         while blob:
-            self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515])
+            await self.write_pkt_line(
+                bytes(bytearray([channel])) + blob[:65515])
             blob = blob[65515:]
 
-    def send_cmd(self, cmd, *args):
+    async def send_cmd(self, cmd, *args):
         """Send a command and some arguments to a git server.
 
         Only used for the TCP git protocol (git://).
@@ -503,16 +504,16 @@ class AsyncProtocol(object):
           cmd: The remote service to access.
           args: List of arguments to send to remove service.
         """
-        self.write_pkt_line(cmd + b" " + b"".join([(a + b"\0") for a in args]))
+        await self.write_pkt_line(cmd + b" " + b"".join([(a + b"\0") for a in args]))
 
-    def read_cmd(self):
+    async def read_cmd(self):
         """Read a command and some arguments from the git client
 
         Only used for the TCP git protocol (git://).
 
         Returns: A tuple of (command, [list of arguments]).
         """
-        line = self.read_pkt_line()
+        line = await self.read_pkt_line()
         splice_at = line.find(b" ")
         cmd, args = line[:splice_at], line[splice_at+1:]
         assert args[-1:] == b"\x00"