فهرست منبع

Add basic aiohttp server implementation.

Jelmer Vernooij 2 سال پیش
والد
کامیت
176382a228
7فایلهای تغییر یافته به همراه1145 افزوده شده و 10 حذف شده
  1. 271 0
      dulwich/aiohttp.py
  2. 16 0
      dulwich/object_store.py
  3. 62 3
      dulwich/pack.py
  4. 147 0
      dulwich/protocol.py
  5. 87 0
      dulwich/repo.py
  6. 561 7
      dulwich/server.py
  7. 1 0
      setup.py

+ 271 - 0
dulwich/aiohttp.py

@@ -0,0 +1,271 @@
+# aiohttp.py -- aiohttp smart client/server
+# Copyright (C) 2022 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as public by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""aiohttp client/server support."""
+
+import os
+import sys
+import time
+
+from aiohttp import web
+
+from . import log_utils
+from .protocol import AsyncProtocol
+from .server import (
+    DEFAULT_HANDLERS,
+    DictBackend,
+    generate_objects_info_packs,
+    generate_info_refs,
+)
+from .repo import Repo, NotGitRepository
+from .web import date_time_string, NO_CACHE_HEADERS
+
+
+logger = log_utils.getLogger(__name__)
+
+
+def cache_forever_headers():
+    now = time.time()
+    return {
+        "Date": date_time_string(now),
+        "Expires": date_time_string(now + 31536000),
+        "Cache-Control": "public, max-age=31536000",
+    }
+
+
+async def send_file(req, f, headers):
+    """Send a file-like object to the request output.
+
+    Args:
+      req: The HTTPGitRequest object to send output to.
+      f: An open file-like object to send; will be closed.
+      headers: Headers to send
+    Returns: Iterator over the contents of the file, as chunks.
+    """
+    if f is None:
+        raise web.HTTPNotFound(text="File not found")
+    response = web.StreamResponse(status=200, reason='OK', headers=headers)
+    await response.prepare(req)
+    try:
+        while True:
+            data = f.read(10240)
+            if not data:
+                break
+            await response.write(data)
+    except IOError:
+        raise web.HTTPInternalServerError(text="Error reading file")
+    finally:
+        f.close()
+    await response.write_eof()
+    return response
+
+
+class HTTPGitApplication(web.Application):
+
+    async def _get_loose_object(self, request):
+        sha = (request.match_info['dir']
+               + request.match_info['filename']).encode("ascii")
+        logger.info("Sending loose object %s", sha)
+        try:
+            object_store = self.backend.open_repository('/').object_store
+        except NotGitRepository as e:
+            raise web.HTTPNotFound(text=str(e))
+        if not object_store.contains_loose(sha):
+            raise web.HTTPNotFound(text="Object not found")
+        try:
+            data = object_store[sha].as_legacy_object()
+        except IOError:
+            raise web.HTTPInternalServerError(text="Error reading object")
+        headers = {
+            'Content-Type': "application/x-git-loose-object"
+        }
+        headers.update(cache_forever_headers())
+        return web.Response(status=200, headers=headers, body=data)
+
+    async def _get_text_file(self, request):
+        headers = {
+            'Content-Type': 'text/plain',
+        }
+        headers.update(NO_CACHE_HEADERS)
+        path = request.match_info['file']
+        logger.info("Sending plain text file %s", path)
+        repo = self.backend.open_repository('/')
+        return await send_file(
+            request, repo.get_named_file(path), headers)
+
+    async def _get_info_refs(self, request):
+        service = request.query.get("service")
+        try:
+            repo = self.backend.open_repository('/')
+        except NotGitRepository as e:
+            raise web.HTTPNotFound(text=str(e))
+        if service:
+            handler_cls = self.handlers.get(service.encode("ascii"), None)
+            if handler_cls is None:
+                raise web.HTTPForbidden(text="Unsupported service")
+            headers = {
+                'Content-Type': "application/x-%s-advertisement" % service}
+            headers.update(NO_CACHE_HEADERS)
+            response = web.StreamResponse(headers=headers, status=200)
+            await response.prepare(request)
+            proto = AsyncProtocol(request.content.readexactly, response.write)
+            handler = handler_cls(
+                self.backend,
+                ['/'],
+                proto,
+                stateless_rpc=True,
+                advertise_refs=True,
+            )
+            await handler.proto.write_pkt_line(
+                b"# service=" + service.encode("ascii") + b"\n")
+            await handler.proto.write_pkt_line(None)
+            await handler.handle_async()
+            await response.write_eof()
+            return response
+        else:
+            # non-smart fallback
+            headers = {'Content-Type': 'text/plain'}
+            headers.update(NO_CACHE_HEADERS)
+            logger.info("Emulating dumb info/refs")
+            return web.Response(body=b''.join(generate_info_refs(repo)), headers=headers)
+
+    async def _get_info_packs(self, request):
+        headers = {'Content-Type': 'text/plain'}
+        headers.update(NO_CACHE_HEADERS)
+        logger.info("Emulating dumb info/packs")
+        try:
+            repo = self.backend.open_repository('/')
+        except NotGitRepository as e:
+            raise web.HTTPNotFound(text=str(e))
+        return web.Response(
+            body=b''.join(generate_objects_info_packs(repo)), headers=headers)
+
+    async def _get_pack_file(self, request):
+        headers = {'Content-Type': "application/x-git-packed-objects"}
+        headers.update(cache_forever_headers())
+        sha = request.match_info['sha']
+        path = 'objects/pack/pack-%s.pack' % sha
+        logger.info("Sending pack file %s", path)
+        repo = self.backend.open_repository('/')
+        return await send_file(
+            request,
+            repo.get_named_file(path),
+            headers=headers,
+        )
+
+    async def _get_index_file(self, request):
+        headers = {
+            'Content-Type': "application/x-git-packed-objects-toc"
+        }
+        headers.update(cache_forever_headers())
+        sha = request.match_info['sha']
+        path = 'objects/pack/pack-%s.idx' % sha
+        logger.info("Sending pack file %s", path)
+        repo = self.backend.open_repository('/')
+        return await send_file(
+            request,
+            repo.get_named_file(path),
+            headers=headers
+        )
+
+    async def _handle_service_request(self, request):
+        service = request.match_info['service']
+        logger.info("Handling service request for %s", service)
+        handler_cls = self.handlers.get(service.encode("ascii"), None)
+        if handler_cls is None:
+            raise web.HTTPForbidden(text="Unsupported service")
+        headers = {
+            'Content-Type': "application/x-%s-result" % service
+        }
+        headers.update(NO_CACHE_HEADERS)
+        response = web.StreamResponse(status=200, headers=headers)
+        await response.prepare(request)
+        proto = AsyncProtocol(request.content.readexactly, response.write)
+        handler = handler_cls(self.backend, ['/'], proto, stateless_rpc=True)
+        await handler.handle_async()
+        await response.write_eof()
+        return response
+
+    def __init__(self, backend):
+        super(HTTPGitApplication, self).__init__()
+        self.backend = backend
+        self.handlers = dict(DEFAULT_HANDLERS)
+        self.router.add_get(
+            '/{file:HEAD}', self._get_text_file)
+        self.router.add_get(
+            '/info/refs', self._get_info_refs)
+        self.router.add_get(
+            '/{file:objects/info/alternates}', self._get_text_file)
+        self.router.add_get(
+            '/{file:objects/info/http-alternates}', self._get_text_file)
+        self.router.add_get(
+            '/objects/info/packs', self._get_info_packs)
+        self.router.add_get(
+            '/objects/{dir:[0-9a-f]{2}}/{file:[0-9a-f]{38}}',
+            self._get_loose_object)
+        self.router.add_get(
+            '/objects/pack/pack-{sha:[0-9a-f]{40}}\\.pack', self._get_pack_file)
+        self.router.add_get(
+            '/objects/pack/pack-{sha:[0-9a-f]{40}}\\.idx', self._get_index_file)
+        self.router.add_post(
+            '/{service:git-upload-pack|git-receive-pack}',
+            self._handle_service_request)
+
+
+def main(argv=sys.argv):
+    """Entry point for starting an HTTP git server."""
+    import optparse
+
+    parser = optparse.OptionParser()
+    parser.add_option(
+        "-l",
+        "--listen_address",
+        dest="listen_address",
+        default="localhost",
+        help="Binding IP address.",
+    )
+    parser.add_option(
+        "-p",
+        "--port",
+        dest="port",
+        type=int,
+        default=8000,
+        help="Port to listen on.",
+    )
+    options, args = parser.parse_args(argv)
+
+    if len(args) > 1:
+        gitdir = args[1]
+    else:
+        gitdir = os.getcwd()
+
+    log_utils.default_logging_config()
+    backend = DictBackend({"/": Repo(gitdir)})
+    app = HTTPGitApplication(backend)
+    logger.info(
+        "Listening for HTTP connections on %s:%d",
+        options.listen_address,
+        options.port,
+    )
+    web.run_app(app, port=options.port, host=options.listen_address)
+
+
+if __name__ == "__main__":
+    main()

+ 16 - 0
dulwich/object_store.py

@@ -295,6 +295,22 @@ class BaseObjectStore(object):
             sha = next(graphwalker)
         return haves
 
+    async def find_common_revisions_async(self, graphwalker):
+        """Find which revisions this store has in common using graphwalker.
+
+        Args:
+          graphwalker: A graphwalker object.
+        Returns: List of SHAs that are in common
+        """
+        haves = []
+        sha = await graphwalker.next()
+        while sha:
+            if sha in self:
+                haves.append(sha)
+                await graphwalker.ack(sha)
+            sha = await graphwalker.next()
+        return haves
+
     def generate_pack_contents(self, have, want, shallow=None, progress=None):
         """Iterate over the contents of a pack file.
 

+ 62 - 3
dulwich/pack.py

@@ -1674,7 +1674,42 @@ def write_pack_objects(
     """Write a new pack data file.
 
     Args:
-      write: write function to use
+      write: Write function to use
+      objects: Iterable of (object, path) tuples to write. Should provide
+         __len__
+      delta_window_size: Sliding window size for searching for deltas;
+                         Set to None for default window size.
+      deltify: Whether to deltify objects
+      compression_level: the zlib compression level to use
+    Returns: Dict mapping id -> (offset, crc32 checksum), pack checksum
+    """
+    if hasattr(write, 'write'):
+        write = write.write
+    if deltify is None:
+        # PERFORMANCE/TODO(jelmer): This should be enabled but is *much* too
+        # slow at the moment.
+        deltify = False
+    if deltify:
+        pack_contents = deltify_pack_objects(objects, delta_window_size)
+        pack_contents_count = len(objects)
+    else:
+        pack_contents_count, pack_contents = pack_objects_to_data(objects)
+
+    return write_pack_data(
+        write,
+        pack_contents_count,
+        pack_contents,
+        compression_level=compression_level,
+    )
+
+
+async def write_pack_objects_async(
+    write, objects, delta_window_size=None, deltify=None, compression_level=-1
+):
+    """Write a new pack data file.
+
+    Args:
+      write: Write function to use
       objects: Iterable of (object, path) tuples to write. Should provide
          __len__
       delta_window_size: Sliding window size for searching for deltas;
@@ -1699,7 +1734,7 @@ def write_pack_objects(
     else:
         pack_contents_count, pack_contents = pack_objects_to_data(objects)
 
-    return write_pack_data(
+    return await write_pack_data_async(
         write,
         pack_contents_count,
         pack_contents,
@@ -1794,6 +1829,26 @@ def write_pack_data(write, num_records=None, records=None, progress=None, compre
     return chunk_generator.entries, chunk_generator.sha1digest()
 
 
+async def write_pack_data_async(write, num_records=None, records=None, progress=None, compression_level=-1):
+    """Write a new pack data file.
+
+    Args:
+      write: Write function to use
+      num_records: Number of records (defaults to len(records) if None)
+      records: Iterator over type_num, object_id, delta_base, raw
+      progress: Function to report progress to
+      compression_level: the zlib compression level
+    Returns: Dict mapping id -> (offset, crc32 checksum), pack checksum
+    """
+    chunk_generator = PackChunkGenerator(
+        num_records=num_records, records=records, progress=progress,
+        compression_level=compression_level)
+    for chunk in chunk_generator:
+        await write(chunk)
+    return chunk_generator.entries, chunk_generator.sha1digest()
+
+
+
 def write_pack_index_v1(f, entries, pack_checksum):
     """Write a new pack index file.
 
@@ -1867,7 +1922,11 @@ def create_delta(base_buf, target_buf):
     out_buf += _delta_encode_size(len(base_buf))
     out_buf += _delta_encode_size(len(target_buf))
     # write out delta opcodes
-    seq = difflib.SequenceMatcher(a=base_buf, b=target_buf)
+    try:
+        from patiencediff import PatienceSequenceMatcher as SequenceMatcher
+    except ImportError:
+        from difflib import SequenceMatcher
+    seq = SequenceMatcher(a=base_buf, b=target_buf)
     for opcode, i1, i2, j1, j2 in seq.get_opcodes():
         # Git patch opcodes don't care about deletes!
         # if opcode == 'replace' or opcode == 'delete':

+ 147 - 0
dulwich/protocol.py

@@ -440,6 +440,153 @@ class ReceivableProtocol(Protocol):
         return buf.read(size)
 
 
+class AsyncProtocol(object):
+    """Async version of protocol.
+
+    """
+
+    def __init__(self, read, write, close=None, report_activity=None):
+        self.read = read
+        self.write = write
+        self._close = close
+        self.report_activity = report_activity
+        self._readahead = []
+
+    def close(self):
+        if self._close:
+            self._close()
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.close()
+
+    async def read_pkt_line(self):
+        """Reads a pkt-line from the remote git process.
+
+        This method may read from the readahead buffer; see unread_pkt_line.
+
+        Returns: The next string from the stream, without the length prefix, or
+            None for a flush-pkt ('0000').
+        """
+        if self._readahead:
+            return self._readahead.pop(0)
+
+        try:
+            sizestr = await self.read(4)
+            if not sizestr:
+                raise HangupException()
+            size = int(sizestr, 16)
+            if size == 0:
+                if self.report_activity:
+                    self.report_activity(4, "read")
+                return None
+            if self.report_activity:
+                self.report_activity(size, "read")
+            pkt_contents = await self.read(size - 4)
+        except ConnectionResetError:
+            raise HangupException()
+        except socket.error as e:
+            raise GitProtocolError(e)
+        else:
+            if len(pkt_contents) + 4 != size:
+                raise GitProtocolError(
+                    "Length of pkt read %04x does not match length prefix %04x"
+                    % (len(pkt_contents) + 4, size)
+                )
+            return pkt_contents
+
+    async def eof(self):
+        """Test whether the protocol stream has reached EOF.
+
+        Note that this refers to the actual stream EOF and not just a
+        flush-pkt.
+
+        Returns: True if the stream is at EOF, False otherwise.
+        """
+        try:
+            next_line = await self.read_pkt_line()
+        except HangupException:
+            return True
+        self.unread_pkt_line(next_line)
+        return False
+
+    def unread_pkt_line(self, data):
+        """Unread a single line of data into the readahead buffer.
+
+        This method can be used to unread a single pkt-line into a fixed
+        readahead buffer.
+
+        Args:
+          data: The data to unread, without the length prefix.
+        Raises:
+          ValueError: If more than one pkt-line is unread.
+        """
+        self._readahead.append(data)
+
+    async def read_pkt_seq(self):
+        """Read a sequence of pkt-lines from the remote git process.
+
+        Returns: Yields each line of data up to but not including the next
+            flush-pkt.
+        """
+        pkt = self.read_pkt_line()
+        while pkt:
+            yield pkt
+            pkt = await self.read_pkt_line()
+
+    async def write_pkt_line(self, line):
+        """Sends a pkt-line to the remote git process.
+
+        Args:
+          line: A string containing the data to send, without the length
+            prefix.
+        """
+        try:
+            line = pkt_line(line)
+            await self.write(line)
+            if self.report_activity:
+                self.report_activity(len(line), "write")
+        except socket.error as e:
+            raise GitProtocolError(e)
+
+    async def write_sideband(self, channel, blob):
+        """Write multiplexed data to the sideband.
+
+        Args:
+          channel: An int specifying the channel to write to.
+          blob: A blob of data (as a string) to send on this channel.
+        """
+        # a pktline can be a max of 65520. a sideband line can therefore be
+        # 65520-5 = 65515
+        # WTF: Why have the len in ASCII, but the channel in binary.
+        while blob:
+            await self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515])
+            blob = blob[65515:]
+
+    async def send_cmd(self, cmd, *args):
+        """Send a command and some arguments to a git server.
+
+        Only used for the TCP git protocol (git://).
+
+        Args:
+          cmd: The remote service to access.
+          args: List of arguments to send to remove service.
+        """
+        await self.write_pkt_line(format_cmd_pkt(cmd, *args))
+
+    async def read_cmd(self):
+        """Read a command and some arguments from the git client
+
+        Only used for the TCP git protocol (git://).
+
+        Returns: A tuple of (command, [list of arguments]).
+        """
+        line = await self.read_pkt_line()
+        return parse_cmd_pkt(line)
+
+
 def extract_capabilities(text):
     """Extract a capabilities list from a string, if present.
 

+ 87 - 0
dulwich/repo.py

@@ -549,6 +549,93 @@ class BaseRepo(object):
             )
         )
 
+    async def fetch_objects_async(
+        self,
+        determine_wants,
+        graph_walker,
+        progress,
+        get_tagged=None,
+        depth=None,
+    ):
+        """Fetch the missing objects required for a set of revisions.
+
+        Args:
+          determine_wants: Function that takes a dictionary with heads
+            and returns the list of heads to fetch.
+          graph_walker: Object that can iterate over the list of revisions
+            to fetch and has an "ack" method that will be called to acknowledge
+            that a revision is present.
+          progress: Simple progress function that will be called with
+            updated progress strings.
+          get_tagged: Function that returns a dict of pointed-to sha ->
+            tag sha for including tags.
+          depth: Shallow fetch depth
+        Returns: iterator over objects, with __len__ implemented
+        """
+        if depth not in (None, 0):
+            raise NotImplementedError("depth not supported yet")
+
+        refs = {}
+        for ref, sha in self.get_refs().items():
+            try:
+                obj = self.object_store[sha]
+            except KeyError:
+                warnings.warn(
+                    "ref %s points at non-present sha %s"
+                    % (ref.decode("utf-8", "replace"), sha.decode("ascii")),
+                    UserWarning,
+                )
+                continue
+            else:
+                if isinstance(obj, Tag):
+                    refs[ref + ANNOTATED_TAG_SUFFIX] = obj.object[1]
+                refs[ref] = sha
+
+        wants = await determine_wants(refs)
+        if not isinstance(wants, list):
+            raise TypeError("determine_wants() did not return a list")
+
+        shallows = getattr(graph_walker, "shallow", frozenset())
+        unshallows = getattr(graph_walker, "unshallow", frozenset())
+
+        if wants == []:
+            # TODO(dborowitz): find a way to short-circuit that doesn't change
+            # this interface.
+
+            if shallows or unshallows:
+                # Do not send a pack in shallow short-circuit path
+                return None
+
+            return []
+
+        # If the graph walker is set up with an implementation that can
+        # ACK/NAK to the wire, it will write data to the client through
+        # this call as a side-effect.
+        haves = await self.object_store.find_common_revisions_async(graph_walker)
+
+        # Deal with shallow requests separately because the haves do
+        # not reflect what objects are missing
+        if shallows or unshallows:
+            # TODO: filter the haves commits from iter_shas. the specific
+            # commits aren't missing.
+            haves = []
+
+        parents_provider = ParentsProvider(self.object_store, shallows=shallows)
+
+        def get_parents(commit):
+            return parents_provider.get_parents(commit.id, commit)
+
+        return self.object_store.iter_shas(
+            self.object_store.find_missing_objects(
+                haves,
+                wants,
+                self.get_shallow(),
+                progress,
+                get_tagged,
+                get_parents=get_parents,
+            )
+        )
+
     def generate_pack_data(self, have, want, progress=None, ofs_delta=None):
         """Generate pack data objects for a set of wants/haves.
 

+ 561 - 7
dulwich/server.py

@@ -69,6 +69,7 @@ from dulwich.objects import (
 )
 from dulwich.pack import (
     write_pack_objects,
+    write_pack_objects_async,
 )
 from dulwich.protocol import (
     BufferedPktLineWriter,
@@ -182,6 +183,17 @@ class BackendRepo(object):
         """
         raise NotImplementedError
 
+    async def fetch_objects_async(self, determine_wants, graph_walker, progress, get_tagged=None):
+        """
+        Yield the objects required for a list of commits.
+
+        Args:
+          progress: is a callback to send progress messages to the client
+          get_tagged: Function that returns a dict of pointed-to sha ->
+            tag sha for including tags.
+        """
+        raise NotImplementedError
+
 
 class DictBackend(Backend):
     """Trivial backend that looks up Git repositories in a dictionary."""
@@ -227,6 +239,9 @@ class Handler(object):
     def handle(self):
         raise NotImplementedError(self.handle)
 
+    async def handle_async(self):
+        raise NotImplementedError(self.handle_async)
+
 
 class PackHandler(Handler):
     """Protocol handler for packs."""
@@ -327,6 +342,11 @@ class UploadPackHandler(PackHandler):
             return
         self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
 
+    async def progress_async(self, message):
+        if self.has_capability(CAPABILITY_NO_PROGRESS) or self._processing_have_lines:
+            return
+        await self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, message)
+
     def get_tagged(self, refs=None, repo=None):
         """Get a dict of peeled values of tags to their original tag shas.
 
@@ -412,6 +432,59 @@ class UploadPackHandler(PackHandler):
         # we are done
         self.proto.write_pkt_line(None)
 
+    async def handle_async(self):
+        async def write(x):
+            return await self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
+
+        graph_walker = _AsyncProtocolGraphWalker(
+            self,
+            self.repo.object_store,
+            self.repo.get_peeled,
+            self.repo.refs.get_symrefs,
+        )
+        wants = []
+
+        async def wants_wrapper(refs, **kwargs):
+            wants.extend(await graph_walker.determine_wants(refs, **kwargs))
+            return wants
+
+        objects_iter = await self.repo.fetch_objects_async(
+            wants_wrapper,
+            graph_walker,
+            self.progress_async,
+            get_tagged=self.get_tagged,
+        )
+
+        # Note the fact that client is only processing responses related
+        # to the have lines it sent, and any other data (including side-
+        # band) will be be considered a fatal error.
+        self._processing_have_lines = True
+
+        # Did the process short-circuit (e.g. in a stateless RPC call)? Note
+        # that the client still expects a 0-object pack in most cases.
+        # Also, if it also happens that the object_iter is instantiated
+        # with a graph walker with an implementation that talks over the
+        # wire (which is this instance of this class) this will actually
+        # iterate through everything and write things out to the wire.
+        if len(wants) == 0:
+            return
+
+        # The provided haves are processed, and it is safe to send side-
+        # band data now.
+        self._processing_have_lines = False
+
+        if not await graph_walker.handle_done(
+            not self.has_capability(CAPABILITY_NO_DONE), self._done_received
+        ):
+            return
+
+        await self.progress_async(
+            ("counting objects: %d, done.\n" % len(objects_iter)).encode("ascii")
+        )
+        await write_pack_objects_async(write, objects_iter)
+        # we are done
+        await self.proto.write_pkt_line(None)
+
 
 def _split_proto_line(line, allowed):
     """Split a line read from the wire.
@@ -796,6 +869,49 @@ class SingleAckGraphWalkerImpl(object):
         return True
 
 
+class AsyncSingleAckGraphWalkerImpl(object):
+    """Graph walker implementation that speaks the single-ack protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._common = []
+
+    async def ack(self, have_ref):
+        if not self._common:
+            await self.walker.send_ack(have_ref)
+            self._common.append(have_ref)
+
+    async def next(self):
+        command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
+        if command in (None, COMMAND_DONE):
+            # defer the handling of done
+            self.walker.notify_done()
+            return None
+        elif command == COMMAND_HAVE:
+            return sha
+
+    async def handle_done(self, done_required, done_received):
+        if not self._common:
+            await self.walker.send_nak()
+
+        if done_required and not done_received:
+            # we are not done, especially when done is required; skip
+            # the pack for this request and especially do not handle
+            # the done.
+            return False
+
+        if not done_received and not self._common:
+            # Okay we are not actually done then since the walker picked
+            # up no haves.  This is usually triggered when client attempts
+            # to pull from a source that has no common base_commit.
+            # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
+            #          test_multi_ack_stateless_nodone
+            return False
+
+        return True
+
+
+
 class MultiAckGraphWalkerImpl(object):
     """Graph walker implementation that speaks the multi-ack protocol."""
 
@@ -855,6 +971,63 @@ class MultiAckGraphWalkerImpl(object):
         return True
 
 
+class AsyncMultiAckGraphWalkerImpl(object):
+    """Graph walker implementation that speaks the multi-ack protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._found_base = False
+        self._common = []
+
+    async def ack(self, have_ref):
+        self._common.append(have_ref)
+        if not self._found_base:
+            await self.walker.send_ack(have_ref, b"continue")
+            if self.walker.all_wants_satisfied(self._common):
+                self._found_base = True
+        # else we blind ack within next
+
+    async def next(self):
+        while True:
+            command, sha = await self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
+            if command is None:
+                await self.walker.send_nak()
+                # in multi-ack mode, a flush-pkt indicates the client wants to
+                # flush but more have lines are still coming
+                continue
+            elif command == COMMAND_DONE:
+                self.walker.notify_done()
+                return None
+            elif command == COMMAND_HAVE:
+                if self._found_base:
+                    # blind ack
+                    await self.walker.send_ack(sha, b"continue")
+                return sha
+
+    async def handle_done(self, done_required, done_received):
+        if done_required and not done_received:
+            # we are not done, especially when done is required; skip
+            # the pack for this request and especially do not handle
+            # the done.
+            return False
+
+        if not done_received and not self._common:
+            # Okay we are not actually done then since the walker picked
+            # up no haves.  This is usually triggered when client attempts
+            # to pull from a source that has no common base_commit.
+            # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
+            #          test_multi_ack_stateless_nodone
+            return False
+
+        # don't nak unless no common commits were found, even if not
+        # everything is satisfied
+        if self._common:
+            await self.walker.send_ack(self._common[-1])
+        else:
+            await self.walker.send_nak()
+        return True
+
+
 class MultiAckDetailedGraphWalkerImpl(object):
     """Graph walker implementation speaking the multi-ack-detailed protocol."""
 
@@ -862,18 +1035,18 @@ class MultiAckDetailedGraphWalkerImpl(object):
         self.walker = walker
         self._common = []
 
-    def ack(self, have_ref):
+    async def ack(self, have_ref):
         # Should only be called iff have_ref is common
         self._common.append(have_ref)
-        self.walker.send_ack(have_ref, b"common")
+        await self.walker.send_ack(have_ref, b"common")
 
-    def next(self):
+    async def next(self):
         while True:
-            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
+            command, sha = await self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
                 if self.walker.all_wants_satisfied(self._common):
                     self.walker.send_ack(self._common[-1], b"ready")
-                self.walker.send_nak()
+                await self.walker.send_nak()
                 if self.walker.stateless_rpc:
                     # The HTTP version of this request a flush-pkt always
                     # signifies an end of request, so we also return
@@ -894,8 +1067,6 @@ class MultiAckDetailedGraphWalkerImpl(object):
         # don't nak unless no common commits were found, even if not
         # everything is satisfied
 
-    __next__ = next
-
     def handle_done(self, done_required, done_received):
         if done_required and not done_received:
             # we are not done, especially when done is required; skip
@@ -920,6 +1091,275 @@ class MultiAckDetailedGraphWalkerImpl(object):
         return True
 
 
+class AsyncMultiAckDetailedGraphWalkerImpl(object):
+    """Graph walker implementation speaking the multi-ack-detailed protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._common = []
+
+    async def ack(self, have_ref):
+        # Should only be called iff have_ref is common
+        self._common.append(have_ref)
+        await self.walker.send_ack(have_ref, b"common")
+
+    async def next(self):
+        while True:
+            command, sha = await self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
+            if command is None:
+                if self.walker.all_wants_satisfied(self._common):
+                    await self.walker.send_ack(self._common[-1], b"ready")
+                await self.walker.send_nak()
+                if self.walker.stateless_rpc:
+                    # The HTTP version of this request a flush-pkt always
+                    # signifies an end of request, so we also return
+                    # nothing here as if we are done (but not really, as
+                    # it depends on whether no-done capability was
+                    # specified and that's handled in handle_done which
+                    # may or may not call post_nodone_check depending on
+                    # that).
+                    return None
+            elif command == COMMAND_DONE:
+                # Let the walker know that we got a done.
+                self.walker.notify_done()
+                break
+            elif command == COMMAND_HAVE:
+                # return the sha and let the caller ACK it with the
+                # above ack method.
+                return sha
+        # don't nak unless no common commits were found, even if not
+        # everything is satisfied
+
+    async def handle_done(self, done_required, done_received):
+        if done_required and not done_received:
+            # we are not done, especially when done is required; skip
+            # the pack for this request and especially do not handle
+            # the done.
+            return False
+
+        if not done_received and not self._common:
+            # Okay we are not actually done then since the walker picked
+            # up no haves.  This is usually triggered when client attempts
+            # to pull from a source that has no common base_commit.
+            # See: test_server.MultiAckDetailedGraphWalkerImplTestCase.\
+            #          test_multi_ack_stateless_nodone
+            return False
+
+        # don't nak unless no common commits were found, even if not
+        # everything is satisfied
+        if self._common:
+            await self.walker.send_ack(self._common[-1])
+        else:
+            await self.walker.send_nak()
+        return True
+
+
+class _AsyncProtocolGraphWalker(object):
+    """A graph walker that knows the git protocol.
+
+    As a graph walker, this class implements ack(), next(), and reset(). It
+    also contains some base methods for interacting with the wire and walking
+    the commit tree.
+
+    The work of determining which acks to send is passed on to the
+    implementation instance stored in _impl. The reason for this is that we do
+    not know at object creation time what ack level the protocol requires. A
+    call to set_ack_type() is required to set up the implementation, before
+    any calls to next() or ack() are made.
+    """
+
+    def __init__(self, handler, object_store, get_peeled, get_symrefs):
+        self.handler = handler
+        self.store = object_store
+        self.get_peeled = get_peeled
+        self.get_symrefs = get_symrefs
+        self.proto = handler.proto
+        self.stateless_rpc = handler.stateless_rpc
+        self.advertise_refs = handler.advertise_refs
+        self._wants = []
+        self.shallow = set()
+        self.client_shallow = set()
+        self.unshallow = set()
+        self._cached = False
+        self._cache = []
+        self._cache_index = 0
+        self._impl = None
+
+    async def determine_wants(self, heads, depth=None):
+        """Determine the wants for a set of heads.
+
+        The given heads are advertised to the client, who then specifies which
+        refs they want using 'want' lines. This portion of the protocol is the
+        same regardless of ack type, and in fact is used to set the ack type of
+        the ProtocolGraphWalker.
+
+        If the client has the 'shallow' capability, this method also reads and
+        responds to the 'shallow' and 'deepen' lines from the client. These are
+        not part of the wants per se, but they set up necessary state for
+        walking the graph. Additionally, later code depends on this method
+        consuming everything up to the first 'have' line.
+
+        Args:
+          heads: a dict of refname->SHA1 to advertise
+        Returns: a list of SHA1s requested by the client
+        """
+        symrefs = self.get_symrefs()
+        values = set(heads.values())
+        if self.advertise_refs or not self.stateless_rpc:
+            for i, (ref, sha) in enumerate(sorted(heads.items())):
+                try:
+                    peeled_sha = self.get_peeled(ref)
+                except KeyError:
+                    # Skip refs that are inaccessible
+                    # TODO(jelmer): Integrate with Repo.fetch_objects refs
+                    # logic.
+                    continue
+                if i == 0:
+                    logger.info(
+                        "Sending capabilities: %s", self.handler.capabilities())
+                    line = format_ref_line(
+                        ref, sha,
+                        self.handler.capabilities()
+                        + symref_capabilities(symrefs.items()))
+                else:
+                    line = format_ref_line(ref, sha)
+                await self.proto.write_pkt_line(line)
+                if peeled_sha != sha:
+                    await self.proto.write_pkt_line(
+                        format_ref_line(ref + ANNOTATED_TAG_SUFFIX, peeled_sha))
+
+            # i'm done..
+            await self.proto.write_pkt_line(None)
+
+            if self.advertise_refs:
+                return []
+
+        # Now client will sending want want want commands
+        want = await self.proto.read_pkt_line()
+        if not want:
+            return []
+        line, caps = extract_want_line_capabilities(want)
+        self.handler.set_client_capabilities(caps)
+        self.set_ack_type(ack_type(caps))
+        allowed = (COMMAND_WANT, COMMAND_SHALLOW, COMMAND_DEEPEN, None)
+        command, sha = _split_proto_line(line, allowed)
+
+        want_revs = []
+        while command == COMMAND_WANT:
+            if sha not in values:
+                raise GitProtocolError("Client wants invalid object %s" % sha)
+            want_revs.append(sha)
+            command, sha = await self.read_proto_line(allowed)
+
+        self.set_wants(want_revs)
+        if command in (COMMAND_SHALLOW, COMMAND_DEEPEN):
+            self.unread_proto_line(command, sha)
+            await self._handle_shallow_request(want_revs)
+
+        if self.stateless_rpc and await self.proto.eof():
+            # The client may close the socket at this point, expecting a
+            # flush-pkt from the server. We might be ready to send a packfile
+            # at this point, so we need to explicitly short-circuit in this
+            # case.
+            return []
+
+        return want_revs
+
+    def unread_proto_line(self, command, value):
+        if isinstance(value, int):
+            value = str(value).encode("ascii")
+        self.proto.unread_pkt_line(command + b" " + value)
+
+    async def ack(self, have_ref):
+        if len(have_ref) != 40:
+            raise ValueError("invalid sha %r" % have_ref)
+        return await self._impl.ack(have_ref)
+
+    def reset(self):
+        self._cached = True
+        self._cache_index = 0
+
+    async def next(self):
+        if not self._cached:
+            if not self._impl and self.stateless_rpc:
+                return None
+            return await self._impl.next()
+        self._cache_index += 1
+        if self._cache_index > len(self._cache):
+            return None
+        return self._cache[self._cache_index]
+
+    async def read_proto_line(self, allowed):
+        """Read a line from the wire.
+
+        Args:
+          allowed: An iterable of command names that should be allowed.
+        Returns: A tuple of (command, value); see _split_proto_line.
+        Raises:
+          UnexpectedCommandError: If an error occurred reading the line.
+        """
+        return _split_proto_line(await self.proto.read_pkt_line(), allowed)
+
+    async def _handle_shallow_request(self, wants):
+        while True:
+            command, val = await self.read_proto_line((COMMAND_DEEPEN, COMMAND_SHALLOW))
+            if command == COMMAND_DEEPEN:
+                depth = val
+                break
+            self.client_shallow.add(val)
+        await self.read_proto_line((None,))  # consume client's flush-pkt
+
+        shallow, not_shallow = _find_shallow(self.store, wants, depth)
+
+        # Update self.shallow instead of reassigning it since we passed a
+        # reference to it before this method was called.
+        self.shallow.update(shallow - not_shallow)
+        new_shallow = self.shallow - self.client_shallow
+        unshallow = self.unshallow = not_shallow & self.client_shallow
+
+        for sha in sorted(new_shallow):
+            await self.proto.write_pkt_line(format_shallow_line(sha))
+        for sha in sorted(unshallow):
+            await self.proto.write_pkt_line(format_unshallow_line(sha))
+
+        await self.proto.write_pkt_line(None)
+
+    def notify_done(self):
+        # relay the message down to the handler.
+        self.handler.notify_done()
+
+    async def send_ack(self, sha, ack_type=b""):
+        await self.proto.write_pkt_line(format_ack_line(sha, ack_type))
+
+    async def send_nak(self):
+        await self.proto.write_pkt_line(NAK_LINE)
+
+    async def handle_done(self, done_required, done_received):
+        # Delegate this to the implementation.
+        return await self._impl.handle_done(done_required, done_received)
+
+    def set_wants(self, wants):
+        self._wants = wants
+
+    def all_wants_satisfied(self, haves):
+        """Check whether all the current wants are satisfied by a set of haves.
+
+        Args:
+          haves: A set of commits we know the client has.
+        Note: Wants are specified with set_wants rather than passed in since
+            in the current interface they are determined outside this class.
+        """
+        return _all_wants_satisfied(self.store, haves, self._wants)
+
+    def set_ack_type(self, ack_type):
+        impl_classes = {
+            MULTI_ACK: AsyncMultiAckGraphWalkerImpl,
+            MULTI_ACK_DETAILED: AsyncMultiAckDetailedGraphWalkerImpl,
+            SINGLE_ACK: AsyncSingleAckGraphWalkerImpl,
+        }
+        self._impl = impl_classes[ack_type](self)
+
+
 class ReceivePackHandler(PackHandler):
     """Protocol handler for downloading a pack from the client."""
 
@@ -1028,6 +1468,29 @@ class ReceivePackHandler(PackHandler):
         write(None)
         flush()
 
+    async def _report_status_async(self, status: List[Tuple[bytes, bytes]]) -> None:
+        if self.has_capability(CAPABILITY_SIDE_BAND_64K):
+            async def write(d):
+                return await self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, d)
+
+            async def flush():
+                await self.proto.write_pkt_line(None)
+        else:
+            write = self.proto.write_pkt_line
+
+            async def flush():
+                pass
+
+        for name, msg in status:
+            if name == b"unpack":
+                await write(b"unpack " + msg + b"\n")
+            elif msg == b"ok":
+                await write(b"ok " + name + b"\n")
+            else:
+                await write(b"ng " + name + b" " + msg + b"\n")
+        await write(None)
+        await flush()
+
     def _on_post_receive(self, client_refs):
         hook = self.repo.hooks.get("post-receive", None)
         if not hook:
@@ -1039,6 +1502,17 @@ class ReceivePackHandler(PackHandler):
         except HookError as err:
             self.proto.write_sideband(SIDE_BAND_CHANNEL_FATAL, str(err).encode('utf-8'))
 
+    async def _on_post_receive_async(self, client_refs):
+        hook = self.repo.hooks.get("post-receive", None)
+        if not hook:
+            return
+        try:
+            output = hook.execute(client_refs)
+            if output:
+                await self.proto.write_sideband(SIDE_BAND_CHANNEL_PROGRESS, output)
+        except HookError as err:
+            await self.proto.write_sideband(SIDE_BAND_CHANNEL_FATAL, str(err).encode('utf-8'))
+
     def handle(self) -> None:
         if self.advertise_refs or not self.stateless_rpc:
             refs = sorted(self.repo.get_refs().items())
@@ -1085,6 +1559,52 @@ class ReceivePackHandler(PackHandler):
         if self.has_capability(CAPABILITY_REPORT_STATUS):
             self._report_status(status)
 
+    async def handle_async(self) -> None:
+        if self.advertise_refs or not self.stateless_rpc:
+            refs = sorted(self.repo.get_refs().items())
+            symrefs = sorted(self.repo.refs.get_symrefs().items())
+
+            if not refs:
+                refs = [(CAPABILITIES_REF, ZERO_SHA)]
+            logger.info(
+                "Sending capabilities: %s", self.capabilities())
+            await self.proto.write_pkt_line(
+                format_ref_line(
+                    refs[0][0], refs[0][1],
+                    self.capabilities() + symref_capabilities(symrefs)))
+            for i in range(1, len(refs)):
+                ref = refs[i]
+                await self.proto.write_pkt_line(format_ref_line(ref[0], ref[1]))
+
+            await self.proto.write_pkt_line(None)
+            if self.advertise_refs:
+                return
+
+        client_refs = []
+        ref = self.proto.read_pkt_line()
+
+        # if ref is none then client doesn't want to send us anything..
+        if ref is None:
+            return
+
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
+
+        # client will now send us a list of (oldsha, newsha, ref)
+        while ref:
+            client_refs.append(ref.split())
+            ref = await self.proto.read_pkt_line()
+
+        # backend can now deal with this refs and read a pack using self.read
+        status = self._apply_pack(client_refs)
+
+        await self._on_post_receive(client_refs)
+
+        # when we have read all the pack from the client, send a status report
+        # if the client asked for it
+        if self.has_capability(CAPABILITY_REPORT_STATUS):
+            await self._report_status_async(status)
+
 
 class UploadArchiveHandler(Handler):
     def __init__(self, backend, args, proto, stateless_rpc=False):
@@ -1125,6 +1645,40 @@ class UploadArchiveHandler(Handler):
             write(chunk)
         self.proto.write_pkt_line(None)
 
+    async def handle_async(self):
+        def write(x):
+            return self.proto.write_sideband(SIDE_BAND_CHANNEL_DATA, x)
+
+        arguments = []
+        for pkt in self.proto.read_pkt_seq():
+            (key, value) = pkt.split(b" ", 1)
+            if key != b"argument":
+                raise GitProtocolError("unknown command %s" % key)
+            arguments.append(value.rstrip(b"\n"))
+        prefix = b""
+        format = "tar"
+        i = 0
+        store = self.repo.object_store
+        while i < len(arguments):
+            argument = arguments[i]
+            if argument == b"--prefix":
+                i += 1
+                prefix = arguments[i]
+            elif argument == b"--format":
+                i += 1
+                format = arguments[i].decode("ascii")
+            else:
+                commit_sha = self.repo.refs[argument]
+                tree = store[store[commit_sha].tree]
+            i += 1
+        self.proto.write_pkt_line(b"ACK")
+        self.proto.write_pkt_line(None)
+        for chunk in tar_stream(
+            store, tree, mtime=time.time(), prefix=prefix, format=format
+        ):
+            write(chunk)
+        self.proto.write_pkt_line(None)
+
 
 # Default handler classes for git services.
 DEFAULT_HANDLERS = {

+ 1 - 0
setup.py

@@ -119,4 +119,5 @@ setup(name='dulwich',
           'https': ['urllib3>=1.24.1'],
           'pgp': ['gpg'],
           'paramiko': ['paramiko'],
+          'aiohttp': ['aiohttp'],
       })