瀏覽代碼

Make the server decode a pack as it streams.

This, in combination with using recv() instead of read(), makes it so we
never do blocking reads past the end of the pack stream, even when the
client doesn't close the connection.

This is done via a PackStreamVerifier class that reads from a Protocol,
unpacks and counts objects, writes through to a file, and computes the
SHA-1 checksum on the fly. Decoding the stream is necessary because the
only way to know where the pack ends is to parse the header and read
exactly the number of objects it declares; any read beyond that point
would hang waiting on the client.

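Changed the Handler constructors to take a Protocol instead of taking
read and write callbacks separately. Modified the interfaces of some
utility functions in pack.py so the server-side code can use them:
read_pack_header() now takes a read function rather than a file-like
object, and unpack_object() takes a blocking read_all function plus an
optional read_some function.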

Change-Id: Id4d11e34658d1f00ad06e45330d0d128b367d8e5
Dave Borowitz 15 年之前
父節點
當前提交
abf29867f8
共有 6 個文件被更改,包括 150 次插入52 次删除
  1. 18 13
      dulwich/pack.py
  2. 118 16
      dulwich/server.py
  3. 1 0
      dulwich/tests/compat/test_server.py
  4. 1 1
      dulwich/tests/test_server.py
  5. 5 17
      dulwich/tests/test_web.py
  6. 7 5
      dulwich/web.py

+ 18 - 13
dulwich/pack.py

@@ -417,12 +417,12 @@ class PackIndex2(PackIndex):
   
   
 
 
 
 
-def read_pack_header(f):
+def read_pack_header(read):
     """Read the header of a pack file.
     """Read the header of a pack file.
 
 
-    :param f: File-like object to read from
+    :param read: Read function
     """
     """
-    header = f.read(12)
+    header = read(12)
     assert header[:4] == "PACK"
     assert header[:4] == "PACK"
     (version,) = unpack_from(">L", header, 4)
     (version,) = unpack_from(">L", header, 4)
     assert version in (2, 3), "Version was %d" % version
     assert version in (2, 3), "Version was %d" % version
@@ -434,20 +434,25 @@ def chunks_length(chunks):
     return sum(imap(len, chunks))
     return sum(imap(len, chunks))
 
 
 
 
-def unpack_object(read):
+def unpack_object(read_all, read_some=None):
     """Unpack a Git object.
     """Unpack a Git object.
 
 
-    :return: tuple with type, uncompressed data as chunks, compressed size and 
-        tail data
+    :param read_all: Read function that blocks until the number of requested
+        bytes are read.
+    :param read_some: Read function that returns at least one byte, but may not
+        return the number of bytes requested.
+    :return: tuple with type, uncompressed data, compressed size and tail data.
     """
     """
-    bytes = take_msb_bytes(read)
+    if read_some is None:
+        read_some = read_all
+    bytes = take_msb_bytes(read_all)
     type = (bytes[0] >> 4) & 0x07
     type = (bytes[0] >> 4) & 0x07
     size = bytes[0] & 0x0f
     size = bytes[0] & 0x0f
     for i, byte in enumerate(bytes[1:]):
     for i, byte in enumerate(bytes[1:]):
         size += (byte & 0x7f) << ((i * 7) + 4)
         size += (byte & 0x7f) << ((i * 7) + 4)
     raw_base = len(bytes)
     raw_base = len(bytes)
     if type == 6: # offset delta
     if type == 6: # offset delta
-        bytes = take_msb_bytes(read)
+        bytes = take_msb_bytes(read_all)
         raw_base += len(bytes)
         raw_base += len(bytes)
         assert not (bytes[-1] & 0x80)
         assert not (bytes[-1] & 0x80)
         delta_base_offset = bytes[0] & 0x7f
         delta_base_offset = bytes[0] & 0x7f
@@ -455,17 +460,17 @@ def unpack_object(read):
             delta_base_offset += 1
             delta_base_offset += 1
             delta_base_offset <<= 7
             delta_base_offset <<= 7
             delta_base_offset += (byte & 0x7f)
             delta_base_offset += (byte & 0x7f)
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read_some, size)
         assert size == chunks_length(uncomp)
         assert size == chunks_length(uncomp)
         return type, (delta_base_offset, uncomp), comp_len+raw_base, unused
         return type, (delta_base_offset, uncomp), comp_len+raw_base, unused
     elif type == 7: # ref delta
     elif type == 7: # ref delta
-        basename = read(20)
+        basename = read_all(20)
         raw_base += 20
         raw_base += 20
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read_some, size)
         assert size == chunks_length(uncomp)
         assert size == chunks_length(uncomp)
         return type, (basename, uncomp), comp_len+raw_base, unused
         return type, (basename, uncomp), comp_len+raw_base, unused
     else:
     else:
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read_some, size)
         assert chunks_length(uncomp) == size
         assert chunks_length(uncomp) == size
         return type, uncomp, comp_len+raw_base, unused
         return type, uncomp, comp_len+raw_base, unused
 
 
@@ -522,7 +527,7 @@ class PackData(object):
             self._file = GitFile(self._filename, 'rb')
             self._file = GitFile(self._filename, 'rb')
         else:
         else:
             self._file = file
             self._file = file
-        (version, self._num_objects) = read_pack_header(self._file)
+        (version, self._num_objects) = read_pack_header(self._file.read)
         self._offset_cache = LRUSizeCache(1024*1024*20, 
         self._offset_cache = LRUSizeCache(1024*1024*20, 
             compute_size=_compute_object_size)
             compute_size=_compute_object_size)
 
 

+ 118 - 16
dulwich/server.py

@@ -27,15 +27,22 @@ Documentation/technical directory in the cgit distribution, and in particular:
 
 
 
 
 import collections
 import collections
+from cStringIO import StringIO
+import socket
 import SocketServer
 import SocketServer
+import zlib
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     ApplyDeltaError,
     ApplyDeltaError,
     ChecksumMismatch,
     ChecksumMismatch,
     GitProtocolError,
     GitProtocolError,
     )
     )
+from dulwich.misc import (
+    make_sha,
+    )
 from dulwich.objects import (
 from dulwich.objects import (
     hex_to_sha,
     hex_to_sha,
+    sha_to_hex,
     )
     )
 from dulwich.protocol import (
 from dulwich.protocol import (
     ProtocolFile,
     ProtocolFile,
@@ -51,6 +58,8 @@ from dulwich.protocol import (
     ack_type,
     ack_type,
     )
     )
 from dulwich.pack import (
 from dulwich.pack import (
+    read_pack_header,
+    unpack_object,
     write_pack_data,
     write_pack_data,
     )
     )
 
 
@@ -103,6 +112,105 @@ class BackendRepo(object):
         raise NotImplementedError
         raise NotImplementedError
 
 
 
 
+class PackStreamVerifier(object):
+    """Class to verify a pack stream as it is being read.
+
+    The pack is read from a ReceivableProtocol using read() or recv() as
+    appropriate and written out to the given file-like object.
+    """
+
+    def __init__(self, proto, outfile):
+        self.proto = proto
+        self.outfile = outfile
+        self.sha = make_sha()
+        self._rbuf = StringIO()
+        # trailer is a deque to avoid memory allocation on small reads
+        self._trailer = collections.deque()
+
+    def _read(self, read, size):
+        """Read up to size bytes using the given callback.
+
+        As a side effect, update the verifier's hash (excluding the last 20
+        bytes read) and write through to the output file.
+
+        :param read: The read callback to read from.
+        :param size: The maximum number of bytes to read; the particular
+            behavior is callback-specific.
+        """
+        data = read(size)
+
+        # maintain a trailer of the last 20 bytes we've read
+        n = len(data)
+        tn = len(self._trailer)
+        if n >= 20:
+            to_pop = tn
+            to_add = 20
+        else:
+            to_pop = max(n + tn - 20, 0)
+            to_add = n
+        for _ in xrange(to_pop):
+            self.sha.update(self._trailer.popleft())
+        self._trailer.extend(data[-to_add:])
+
+        # hash everything but the trailer
+        self.sha.update(data[:-to_add])
+        self.outfile.write(data)
+        return data
+
+    def _buf_len(self):
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, 2)
+        end = buf.tell()
+        buf.seek(start)
+        return end - start
+
+    def read(self, size):
+        """Read, blocking until size bytes are read."""
+        buf_len = self._buf_len()
+        if buf_len >= size:
+            return self._rbuf.read(size)
+        buf_data = self._rbuf.read()
+        self._rbuf = StringIO()
+        return buf_data + self._read(self.proto.read, size - buf_len)
+
+    def recv(self, size):
+        """Read up to size bytes, blocking until one byte is read."""
+        buf_len = self._buf_len()
+        if buf_len:
+            data = self._rbuf.read(size)
+            if size >= buf_len:
+                self._rbuf = StringIO()
+            return data
+        return self._read(self.proto.recv, size)
+
+    def verify(self):
+        """Verify a pack stream and write it to the output file.
+
+        :raise AssertionError: if there is an error in the pack format.
+        :raise ChecksumMismatch: if the checksum of the pack contents does not
+            match the checksum in the pack trailer.
+        :raise socket.error: if an error occurred reading from the socket.
+        :raise zlib.error: if an error occurred during zlib decompression.
+        :raise IOError: if an error occurred writing to the output file.
+        """
+        _, num_objects = read_pack_header(self.read)
+        for i in xrange(num_objects):
+            type, _, _, unused = unpack_object(self.read, self.recv)
+
+            # prepend any unused data to current read buffer
+            buf = StringIO()
+            buf.write(unused)
+            buf.write(self._rbuf.read())
+            buf.seek(0)
+            self._rbuf = buf
+
+        pack_sha = sha_to_hex(''.join([c for c in self._trailer]))
+        calculated_sha = self.sha.hexdigest()
+        if pack_sha != calculated_sha:
+            raise ChecksumMismatch(pack_sha, calculated_sha)
+
+
 class DictBackend(Backend):
 class DictBackend(Backend):
     """Trivial backend that looks up Git repositories in a dictionary."""
     """Trivial backend that looks up Git repositories in a dictionary."""
 
 
@@ -117,9 +225,9 @@ class DictBackend(Backend):
 class Handler(object):
 class Handler(object):
     """Smart protocol command handler base class."""
     """Smart protocol command handler base class."""
 
 
-    def __init__(self, backend, read, write):
+    def __init__(self, backend, proto):
         self.backend = backend
         self.backend = backend
-        self.proto = Protocol(read, write)
+        self.proto = proto
         self._client_capabilities = None
         self._client_capabilities = None
 
 
     def capability_line(self):
     def capability_line(self):
@@ -158,9 +266,9 @@ class Handler(object):
 class UploadPackHandler(Handler):
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
     """Protocol handler for uploading a pack to the server."""
 
 
-    def __init__(self, backend, args, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
         self.repo = backend.open_repository(args[0])
         self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
@@ -522,9 +630,9 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(Handler):
 class ReceivePackHandler(Handler):
     """Protocol handler for downloading a pack from the client."""
     """Protocol handler for downloading a pack from the client."""
 
 
-    def __init__(self, backend, args, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
         self.repo = backend.open_repository(args[0])
         self.repo = backend.open_repository(args[0])
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
@@ -532,20 +640,14 @@ class ReceivePackHandler(Handler):
     def capabilities(self):
     def capabilities(self):
         return ("report-status", "delete-refs")
         return ("report-status", "delete-refs")
 
 
-    def _apply_pack(self, refs, read):
+    def _apply_pack(self, refs):
         f, commit = self.repo.object_store.add_thin_pack()
         f, commit = self.repo.object_store.add_thin_pack()
         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
         all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
         status = []
         status = []
         unpack_error = None
         unpack_error = None
         # TODO: more informative error messages than just the exception string
         # TODO: more informative error messages than just the exception string
         try:
         try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
+            PackStreamVerifier(self.proto, f).verify()
         except all_exceptions, e:
         except all_exceptions, e:
             unpack_error = str(e).replace('\n', '')
             unpack_error = str(e).replace('\n', '')
         try:
         try:
@@ -620,7 +722,7 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
             ref = self.proto.read_pkt_line()
 
 
         # backend can now deal with this refs and read a pack using self.read
         # backend can now deal with this refs and read a pack using self.read
-        status = self.repo._apply_pack(client_refs, self.proto.read)
+        status = self._apply_pack(client_refs)
 
 
         # when we have read all the pack from the client, send a status report
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
         # if the client asked for it
@@ -649,7 +751,7 @@ class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
         else:
         else:
             return
             return
 
 
-        h = cls(self.server.backend, args, self.rfile.read, self.wfile.write)
+        h = cls(self.server.backend, args, proto)
         h.handle()
         h.handle()
 
 
 
 

+ 1 - 0
dulwich/tests/compat/test_server.py

@@ -78,4 +78,5 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         return port
         return port
 
 
     def test_push_to_dulwich(self):
     def test_push_to_dulwich(self):
+        # TODO(dborowitz): enable after merging thin pack fixes.
         raise TestSkipped('Skipping push test due to known deadlock bug.')
         raise TestSkipped('Skipping push test due to known deadlock bug.')

+ 1 - 1
dulwich/tests/test_server.py

@@ -79,7 +79,7 @@ class TestProto(object):
 class HandlerTestCase(TestCase):
 class HandlerTestCase(TestCase):
 
 
     def setUp(self):
     def setUp(self):
-        self._handler = Handler(Backend(), None, None)
+        self._handler = Handler(Backend(), None)
         self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
         self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
         self._handler.required_capabilities = lambda: ('cap2',)
         self._handler.required_capabilities = lambda: ('cap2',)
 
 

+ 5 - 17
dulwich/tests/test_web.py

@@ -153,28 +153,16 @@ class DumbHandlersTestCase(WebTestCase):
 
 
 class SmartHandlersTestCase(WebTestCase):
 class SmartHandlersTestCase(WebTestCase):
 
 
-    class TestProtocol(object):
-        def __init__(self, handler):
-            self._handler = handler
-
-        def write_pkt_line(self, line):
-            if line is None:
-                self._handler.write('flush-pkt\n')
-            else:
-                self._handler.write('pkt-line: %s' % line)
-
     class _TestUploadPackHandler(object):
     class _TestUploadPackHandler(object):
-        def __init__(self, backend, args, read, write, stateless_rpc=False,
+        def __init__(self, backend, args, proto, stateless_rpc=False,
                      advertise_refs=False):
                      advertise_refs=False):
             self.args = args
             self.args = args
-            self.read = read
-            self.write = write
-            self.proto = SmartHandlersTestCase.TestProtocol(self)
+            self.proto = proto
             self.stateless_rpc = stateless_rpc
             self.stateless_rpc = stateless_rpc
             self.advertise_refs = advertise_refs
             self.advertise_refs = advertise_refs
 
 
         def handle(self):
         def handle(self):
-            self.write('handled input: %s' % self.read())
+            self.proto.write('handled input: %s' % self.proto.recv(1024))
 
 
     def _MakeHandler(self, *args, **kwargs):
     def _MakeHandler(self, *args, **kwargs):
         self._handler = self._TestUploadPackHandler(*args, **kwargs)
         self._handler = self._TestUploadPackHandler(*args, **kwargs)
@@ -222,8 +210,8 @@ class SmartHandlersTestCase(WebTestCase):
         mat = re.search('.*', '/git-upload-pack')
         mat = re.search('.*', '/git-upload-pack')
         output = ''.join(get_info_refs(self._req, 'backend', mat,
         output = ''.join(get_info_refs(self._req, 'backend', mat,
                                        services=self.services()))
                                        services=self.services()))
-        self.assertEquals(('pkt-line: # service=git-upload-pack\n'
-                           'flush-pkt\n'
+        self.assertEquals(('001e# service=git-upload-pack\n'
+                           '0000'
                            # input is ignored by the handler
                            # input is ignored by the handler
                            'handled input: '), output)
                            'handled input: '), output)
         self.assertTrue(self._handler.advertise_refs)
         self.assertTrue(self._handler.advertise_refs)

+ 7 - 5
dulwich/web.py

@@ -26,6 +26,9 @@ try:
     from urlparse import parse_qs
     from urlparse import parse_qs
 except ImportError:
 except ImportError:
     from dulwich.misc import parse_qs
     from dulwich.misc import parse_qs
+from dulwich.protocol import (
+    ReceivableProtocol,
+    )
 from dulwich.server import (
 from dulwich.server import (
     ReceivePackHandler,
     ReceivePackHandler,
     UploadPackHandler,
     UploadPackHandler,
@@ -138,9 +141,8 @@ def get_info_refs(req, backend, mat, services=None):
         req.nocache()
         req.nocache()
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         output = StringIO()
         output = StringIO()
-        dummy_input = StringIO()  # GET request, handler doesn't need to read
-        handler = handler_cls(backend, [url_prefix(mat)],
-                              dummy_input.read, output.write,
+        proto = ReceivableProtocol(StringIO().read, output.write)
+        handler = handler_cls(backend, [url_prefix(mat)], proto,
                               stateless_rpc=True, advertise_refs=True)
                               stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
         handler.proto.write_pkt_line(None)
@@ -216,8 +218,8 @@ def handle_service_request(req, backend, mat, services=None):
     # content-length
     # content-length
     if 'CONTENT_LENGTH' in req.environ:
     if 'CONTENT_LENGTH' in req.environ:
         input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
         input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    handler = handler_cls(backend, [url_prefix(mat)], input.read, output.write,
-                          stateless_rpc=True)
+    proto = ReceivableProtocol(input.read, output.write)
+    handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler.handle()
     handler.handle()
     yield output.getvalue()
     yield output.getvalue()