Răsfoiți Sursa

Add smart HTTP support to dul-web.

Change-Id: I6ddfe213707f9ccb2d377ac9813856b1f4fe90dc
Dave Borowitz 15 ani în urmă
părinte
comite
ea1d5562fd
5 a modificat fișierele cu 308 adăugiri și 35 ștergeri
  1. 4 1
      dulwich/protocol.py
  2. 140 31
      dulwich/server.py
  3. 15 1
      dulwich/tests/test_protocol.py
  4. 122 2
      dulwich/tests/test_server.py
  5. 27 0
      dulwich/web.py

+ 4 - 1
dulwich/protocol.py

@@ -30,6 +30,7 @@ TCP_GIT_PORT = 9418
 
 SINGLE_ACK = 0
 MULTI_ACK = 1
+MULTI_ACK_DETAILED = 2
 
 class ProtocolFile(object):
     """
@@ -190,6 +191,8 @@ def extract_want_line_capabilities(text):
 
 def ack_type(capabilities):
     """Extract the ack type from a capabilities list."""
-    if 'multi_ack' in capabilities:
+    if 'multi_ack_detailed' in capabilities:
+      return MULTI_ACK_DETAILED
+    elif 'multi_ack' in capabilities:
         return MULTI_ACK
     return SINGLE_ACK

+ 140 - 31
dulwich/server.py

@@ -31,6 +31,8 @@ import SocketServer
 import tempfile
 
 from dulwich.errors import (
+    ApplyDeltaError,
+    ChecksumMismatch,
     GitProtocolError,
     )
 from dulwich.objects import (
@@ -44,6 +46,7 @@ from dulwich.protocol import (
     extract_want_line_capabilities,
     SINGLE_ACK,
     MULTI_ACK,
+    MULTI_ACK_DETAILED,
     ack_type,
     )
 from dulwich.repo import (
@@ -92,18 +95,55 @@ class GitBackend(Backend):
 
     def apply_pack(self, refs, read):
         f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
+        status = []
+        unpack_error = None
+        # TODO: more informative error messages than just the exception string
+        try:
+            # TODO: decode the pack as we stream to avoid blocking reads beyond
+            # the end of data (when using HTTP/1.1 chunked encoding)
+            while True:
+                data = read(10240)
+                if not data:
+                    break
+                f.write(data)
+        except all_exceptions, e:
+            unpack_error = str(e).replace('\n', '')
         try:
-            f.write(read())
-        finally:
             commit()
+        except all_exceptions, e:
+            if not unpack_error:
+                unpack_error = str(e).replace('\n', '')
+
+        if unpack_error:
+            status.append(('unpack', unpack_error))
+        else:
+            status.append(('unpack', 'ok'))
 
         for oldsha, sha, ref in refs:
-            if ref == "0" * 40:
-                del self.repo.refs[ref]
+            # TODO: check refname
+            ref_error = None
+            try:
+                if ref == "0" * 40:
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_error = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_error = 'failed to write'
+            except KeyError, e:
+                ref_error = 'bad ref'
+            if ref_error:
+                status.append((ref, ref_error))
             else:
-                self.repo.refs[ref] = sha
+                status.append((ref, 'ok'))
+
 
         print "pack applied"
+        return status
 
 
 class Handler(object):
@@ -125,11 +165,12 @@ class UploadPackHandler(Handler):
         Handler.__init__(self, backend, read, write)
         self._client_capabilities = None
         self._graph_walker = None
-        self._stateless_rpc = stateless_rpc
-        self._advertise_refs = advertise_refs
+        self.stateless_rpc = stateless_rpc
+        self.advertise_refs = advertise_refs
 
     def default_capabilities(self):
-        return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
+        return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
+                "ofs-delta")
 
     def set_client_capabilities(self, caps):
         my_caps = self.default_capabilities()
@@ -184,6 +225,8 @@ class ProtocolGraphWalker(object):
         self.handler = handler
         self.store = handler.backend.object_store
         self.proto = handler.proto
+        self.stateless_rpc = handler.stateless_rpc
+        self.advertise_refs = handler.advertise_refs
         self._wants = []
         self._cached = False
         self._cache = []
@@ -204,15 +247,19 @@ class ProtocolGraphWalker(object):
         if not heads:
             raise GitProtocolError('No heads found')
         values = set(heads.itervalues())
-        for i, (ref, sha) in enumerate(heads.iteritems()):
-            line = "%s %s" % (sha, ref)
-            if not i:
-                line = "%s\x00%s" % (line, self.handler.capabilities())
-            self.proto.write_pkt_line("%s\n" % line)
-            # TODO: include peeled value of any tags
+        if self.advertise_refs or not self.stateless_rpc:
+            for i, (ref, sha) in enumerate(heads.iteritems()):
+                line = "%s %s" % (sha, ref)
+                if not i:
+                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                self.proto.write_pkt_line("%s\n" % line)
+                # TODO: include peeled value of any tags
+
+            # i'm done..
+            self.proto.write_pkt_line(None)
 
-        # i'm done..
-        self.proto.write_pkt_line(None)
+            if self.advertise_refs:
+                return []
 
         # Now client will sending want want want commands
         want = self.proto.read_pkt_line()
@@ -246,7 +293,7 @@ class ProtocolGraphWalker(object):
 
     def next(self):
         if not self._cached:
-            if not self._impl:
+            if not self._impl and self.stateless_rpc:
                 return None
             return self._impl.next()
         self._cache_index += 1
@@ -274,6 +321,9 @@ class ProtocolGraphWalker(object):
             ('have', obj_id)
             ('done', None)
             (None, None)  (for a flush-pkt)
+
+        :raise GitProtocolError: if the line cannot be parsed into one of the
+            possible return values.
         """
         line = self.proto.read_pkt_line()
         if not line:
@@ -336,6 +386,7 @@ class ProtocolGraphWalker(object):
     def set_ack_type(self, ack_type):
         impl_classes = {
             MULTI_ACK: MultiAckGraphWalkerImpl,
+            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
             SINGLE_ACK: SingleAckGraphWalkerImpl,
             }
         self._impl = impl_classes[ack_type](self)
@@ -402,8 +453,55 @@ class MultiAckGraphWalkerImpl(object):
                 return sha
 
 
+class MultiAckDetailedGraphWalkerImpl(object):
+    """Graph walker implementation speaking the multi-ack-detailed protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._found_base = False
+        self._common = []
+
+    def ack(self, have_ref):
+        self._common.append(have_ref)
+        if not self._found_base:
+            self.walker.send_ack(have_ref, 'common')
+            if self.walker.all_wants_satisfied(self._common):
+                self._found_base = True
+                self.walker.send_ack(have_ref, 'ready')
+        # else we blind ack within next
+
+    def next(self):
+        while True:
+            command, sha = self.walker.read_proto_line()
+            if command is None:
+                self.walker.send_nak()
+                if self.walker.stateless_rpc:
+                    return None
+                continue
+            elif command == 'done':
+                # don't nak unless no common commits were found, even if not
+                # everything is satisfied
+                if self._common:
+                    self.walker.send_ack(self._common[-1])
+                else:
+                    self.walker.send_nak()
+                return None
+            elif command == 'have':
+                if self._found_base:
+                    # blind ack; can happen if the client has more requests
+                    # inflight
+                    self.walker.send_ack(sha, 'ready')
+                return sha
+
+
 class ReceivePackHandler(Handler):
-    """Protocol handler for downloading a pack to the client."""
+    """Protocol handler for downloading a pack from the client."""
+
+    def __init__(self, backend, read, write,
+                 stateless_rpc=False, advertise_refs=False):
+        Handler.__init__(self, backend, read, write)
+        self.stateless_rpc = stateless_rpc
+        self.advertise_refs = advertise_refs
 
     def __init__(self, backend, read, write,
                  stateless_rpc=False, advertise_refs=False):
@@ -417,15 +515,18 @@ class ReceivePackHandler(Handler):
     def handle(self):
         refs = self.backend.get_refs().items()
 
-        if refs:
-            self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
-            for i in range(1, len(refs)):
-                ref = refs[i]
-                self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
-        else:
-            self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+        if self.advertise_refs or not self.stateless_rpc:
+            if refs:
+                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                for i in range(1, len(refs)):
+                    ref = refs[i]
+                    self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
+            else:
+                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
 
-        self.proto.write("0000")
+            self.proto.write("0000")
+            if self.advertise_refs:
+                return
 
         client_refs = []
         ref = self.proto.read_pkt_line()
@@ -442,11 +543,19 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
 
         # backend can now deal with this refs and read a pack using self.read
-        self.backend.apply_pack(client_refs, self.proto.read)
-
-        # when we have read all the pack from the client, it assumes 
-        # everything worked OK.
-        # there is NO ack from the server before it reports victory.
+        status = self.backend.apply_pack(client_refs, self.proto.read)
+
+        # when we have read all the pack from the client, send a status report
+        # if the client asked for it
+        if 'report-status' in client_capabilities:
+            for name, msg in status:
+                if name == 'unpack':
+                    self.proto.write_pkt_line('unpack %s\n' % msg)
+                elif msg == 'ok':
+                    self.proto.write_pkt_line('ok %s\n' % name)
+                else:
+                    self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
+            self.proto.write_pkt_line(None)
 
 
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):

+ 15 - 1
dulwich/tests/test_protocol.py

@@ -27,6 +27,10 @@ from dulwich.protocol import (
     Protocol,
     extract_capabilities,
     extract_want_line_capabilities,
+    ack_type,
+    SINGLE_ACK,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
     )
 
 class ProtocolTests(TestCase):
@@ -78,7 +82,7 @@ class ProtocolTests(TestCase):
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
-class ExtractCapabilitiesTestCase(TestCase):
+class CapabilitiesTestCase(TestCase):
 
     def test_plain(self):
         self.assertEquals(("bla", []), extract_capabilities("bla"))
@@ -95,3 +99,13 @@ class ExtractCapabilitiesTestCase(TestCase):
         self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la"))
         self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la\n"))
         self.assertEquals(("want bla", ["la", "la"]), extract_want_line_capabilities("want bla la la"))
+
+    def test_ack_type(self):
+        self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
+        self.assertEquals(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
+        self.assertEquals(MULTI_ACK_DETAILED,
+                          ack_type(['foo', 'bar', 'multi_ack_detailed']))
+        # choose detailed when both present
+        self.assertEquals(MULTI_ACK_DETAILED,
+                          ack_type(['foo', 'bar', 'multi_ack',
+                                    'multi_ack_detailed']))

+ 122 - 2
dulwich/tests/test_server.py

@@ -31,6 +31,7 @@ from dulwich.server import (
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
+    MultiAckDetailedGraphWalkerImpl,
     )
 
 from dulwich.protocol import (
@@ -122,6 +123,8 @@ class TestHandler(object):
     def __init__(self, objects, proto):
         self.backend = TestBackend(objects)
         self.proto = proto
+        self.stateless_rpc = False
+        self.advertise_refs = False
 
     def capabilities(self):
         return 'multi_ack'
@@ -215,6 +218,8 @@ class TestProtocolGraphWalker(object):
         self.acks = []
         self.lines = []
         self.done = False
+        self.stateless_rpc = False
+        self.advertise_refs = False
 
     def read_proto_line(self):
         return self.lines.pop(0)
@@ -249,10 +254,14 @@ class AckGraphWalkerImplTestCase(TestCase):
     def assertNoAck(self):
         self.assertEquals(None, self._walker.pop_ack())
 
-    def assertAck(self, sha, ack_type=''):
-        self.assertEquals((sha, ack_type), self._walker.pop_ack())
+    def assertAcks(self, acks):
+        for sha, ack_type in acks:
+            self.assertEquals((sha, ack_type), self._walker.pop_ack())
         self.assertNoAck()
 
+    def assertAck(self, sha, ack_type=''):
+        self.assertAcks([(sha, ack_type)])
+
     def assertNak(self):
         self.assertAck(None, 'nak')
 
@@ -397,3 +406,114 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
 
         self.assertNextEquals(None)
         self.assertNak()
+
+class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+    impl_cls = MultiAckDetailedGraphWalkerImpl
+
+    def test_multi_ack(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'ready')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    def test_multi_ack_partial(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'common')
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        # done, re-send ack of last common
+        self.assertAck(ONE)
+
+    def test_multi_ack_flush(self):
+        # same as ack test but contains a flush-pkt in the middle
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNak() # nak the flush-pkt
+
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'ready')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    def test_multi_ack_nak(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_multi_ack_nak_flush(self):
+        # same as nak test but contains a flush-pkt in the middle
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNak()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_multi_ack_stateless(self):
+        # transmission ends with a flush-pkt
+        self._walker.lines[-1] = (None, None)
+        self._walker.stateless_rpc = True
+
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()

+ 27 - 0
dulwich/web.py

@@ -163,6 +163,27 @@ def get_info_packs(req, backend, mat):
         yield 'P pack-%s.pack\n' % pack.name()
 
 
+class _LengthLimitedFile(object):
+    """Wrapper class to limit the length of reads from a file-like object.
+
+    This is used to ensure EOF is read from the wsgi.input object once
+    Content-Length bytes are read. This behavior is required by the WSGI spec
+    but not implemented in wsgiref as of 2.5.
+    """
+    def __init__(self, input, max_bytes):
+        self._input = input
+        self._bytes_avail = max_bytes
+
+    def read(self, size=-1):
+        if self._bytes_avail <= 0:
+            return ''
+        if size == -1 or size > self._bytes_avail:
+            size = self._bytes_avail
+        self._bytes_avail -= size
+        return self._input.read(size)
+
+    # TODO: support more methods as necessary
+
 def handle_service_request(req, backend, mat):
     service = mat.group().lstrip('/')
     handler_cls = services.get(service, None)
@@ -173,6 +194,12 @@ def handle_service_request(req, backend, mat):
 
     output = StringIO()
     input = req.environ['wsgi.input']
+    # This is not necessary if this app is run from a conforming WSGI server.
+    # Unfortunately, there's no way to tell that at this point.
+    # TODO: git may used HTTP/1.1 chunked encoding instead of specifying
+    # content-length
+    if 'CONTENT_LENGTH' in req.environ:
+        input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
     handler = handler_cls(backend, input.read, output.write, stateless_rpc=True)
     handler.handle()
     yield output.getvalue()