Ver código fonte

Add ReceivableProtocol that supports recv() as well as read().

There are two different ways of dealing with reads beyond the end of a
TCP stream buffer. The usual rfile.read method blocks until exactly the
requested number of bytes (or EOF) are read. This causes deadlocks when
we try to do buffered reads from the stream. When reading a packfile,
the cgit client leaves the connection open, so EOF is never read.

The other method is to use the socket.recv method, which blocks until
at least one byte is read, then returns whatever is available (up to
bufsize).

ReceivableProtocol supports both ways of dealing with reads past the
end of a stream, which should be used in different contexts. Note that
currently the ReceivePackHandler code still deadlocks, since it doesn't
know when to stop calling recv.

Change-Id: I77bfea72bb09326d9946f64426637296d0389714
Dave Borowitz 15 anos atrás
pai
commit
ddacb6e886
3 arquivos alterados com 201 adições e 9 exclusões
  1. 108 0
      dulwich/protocol.py
  2. 3 2
      dulwich/server.py
  3. 90 7
      dulwich/tests/test_protocol.py

+ 108 - 0
dulwich/protocol.py

@@ -19,6 +19,8 @@
 
 """Generic functions for talking the git smart server protocol."""
 
+from cStringIO import StringIO
+import os
 import socket
 
 from dulwich.errors import (
@@ -162,6 +164,112 @@ class Protocol(object):
         return cmd, args[:-1].split(chr(0))
 
 
+_RBUFSIZE = 8192  # Default read buffer size.
+
+
+class ReceivableProtocol(Protocol):
+    """Variant of Protocol that allows reading up to a size without blocking.
+
+    This class has a recv() method that behaves like socket.recv() in addition
+    to a read() method.
+
+    If you want to read n bytes from the wire and block until exactly n bytes
+    (or EOF) are read, use read(n). If you want to read at most n bytes from the
+    wire but don't care if you get less, use recv(n). Note that recv(n) will
+    still block until at least one byte is read.
+    """
+
+    def __init__(self, recv, write, report_activity=None, rbufsize=_RBUFSIZE):
+        super(ReceivableProtocol, self).__init__(self.read, write,
+                                                report_activity)
+        self._recv = recv
+        self._rbuf = StringIO()
+        self._rbufsize = rbufsize
+
+    def read(self, size):
+        # From _fileobj.read in socket.py in the Python 2.6.5 standard library,
+        # with the following modifications:
+        #  - omit the size <= 0 branch
+        #  - seek back to start rather than 0 in case some buffer has been
+        #    consumed.
+        #  - use os.SEEK_END instead of the magic number.
+        # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+        # Licensed under the Python Software Foundation License.
+        # TODO: see if buffer is more efficient than cStringIO.
+        assert size > 0
+
+        # Our use of StringIO rather than lists of string objects returned by
+        # recv() minimizes memory usage and fragmentation that occurs when
+        # rbufsize is large compared to the typical return value of recv().
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, os.SEEK_END)
+        # buffer may have been partially consumed by recv()
+        buf_len = buf.tell() - start
+        if buf_len >= size:
+            # Already have size bytes in our buffer?  Extract and return.
+            buf.seek(start)
+            rv = buf.read(size)
+            self._rbuf = StringIO()
+            self._rbuf.write(buf.read())
+            self._rbuf.seek(0)
+            return rv
+
+        self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+        while True:
+            left = size - buf_len
+            # recv() will malloc the amount of memory given as its
+            # parameter even though it often returns much less data
+            # than that.  The returned data string is short lived
+            # as we copy it into a StringIO and free it.  This avoids
+            # fragmentation issues on many platforms.
+            data = self._recv(left)
+            if not data:
+                break
+            n = len(data)
+            if n == size and not buf_len:
+                # Shortcut.  Avoid buffer data copies when:
+                # - We have no data in our buffer.
+                # AND
+                # - Our call to recv returned exactly the
+                #   number of bytes we were asked to read.
+                return data
+            if n == left:
+                buf.write(data)
+                del data  # explicit free
+                break
+            assert n <= left, "_recv(%d) returned %d bytes" % (left, n)
+            buf.write(data)
+            buf_len += n
+            del data  # explicit free
+            #assert buf_len == buf.tell()
+        buf.seek(start)
+        return buf.read()
+
+    def recv(self, size):
+        assert size > 0
+
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, os.SEEK_END)
+        buf_len = buf.tell()
+        buf.seek(start)
+
+        left = buf_len - start
+        if not left:
+            # only read from the wire if our read buffer is exhausted
+            data = self._recv(self._rbufsize)
+            if len(data) == size:
+                # shortcut: skip the buffer if we read exactly size bytes
+                return data
+            buf = StringIO()
+            buf.write(data)
+            buf.seek(0)
+            del data  # explicit free
+            self._rbuf = buf
+        return buf.read(size)
+
+
 def extract_capabilities(text):
     """Extract a capabilities list from a string, if present.
 

+ 3 - 2
dulwich/server.py

@@ -38,8 +38,9 @@ from dulwich.objects import (
     hex_to_sha,
     )
 from dulwich.protocol import (
-    Protocol,
     ProtocolFile,
+    Protocol,
+    ReceivableProtocol,
     TCP_GIT_PORT,
     ZERO_SHA,
     extract_capabilities,
@@ -637,7 +638,7 @@ class ReceivePackHandler(Handler):
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 
     def handle(self):
-        proto = Protocol(self.rfile.read, self.wfile.write)
+        proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         command, args = proto.read_cmd()
 
         # switch case to handle the specific git command

+ 90 - 7
dulwich/tests/test_protocol.py

@@ -20,11 +20,12 @@
 """Tests for the smart protocol utility functions."""
 
 
-from cStringIO import StringIO
+from StringIO import StringIO
 from unittest import TestCase
 
 from dulwich.protocol import (
     Protocol,
+    ReceivableProtocol,
     extract_capabilities,
     extract_want_line_capabilities,
     ack_type,
@@ -33,12 +34,7 @@ from dulwich.protocol import (
     MULTI_ACK_DETAILED,
     )
 
-class ProtocolTests(TestCase):
-
-    def setUp(self):
-        self.rout = StringIO()
-        self.rin = StringIO()
-        self.proto = Protocol(self.rin.read, self.rout.write)
+class BaseProtocolTests(object):
 
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
@@ -82,6 +78,93 @@ class ProtocolTests(TestCase):
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
+class ProtocolTests(BaseProtocolTests, TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.rout = StringIO()
+        self.rin = StringIO()
+        self.proto = Protocol(self.rin.read, self.rout.write)
+
+
+class ReceivableStringIO(StringIO):
+    """StringIO with socket-like recv semantics for testing."""
+
+    def recv(self, size):
+        # fail fast if no bytes are available; in a real socket, this would
+        # block forever
+        if self.tell() == len(self.getvalue()):
+            raise AssertionError("Blocking read past end of socket")
+        if size == 1:
+            return self.read(1)
+        # calls shouldn't return quite as much as asked for
+        return self.read(size - 1)
+
+
+class ReceivableProtocolTests(BaseProtocolTests, TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.rout = StringIO()
+        self.rin = ReceivableStringIO()
+        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
+        self.proto._rbufsize = 8
+
+    def test_recv(self):
+        all_data = "1234567" * 10  # not a multiple of bufsize
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        data = ""
+        # We ask for 8 bytes each time and actually read 7, so it should take
+        # exactly 10 iterations.
+        for _ in xrange(10):
+            data += self.proto.recv(10)
+        # any more reads would block
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertEquals(all_data, data)
+
+    def test_recv_read(self):
+        all_data = "1234567"  # recv exactly in one call
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        self.assertEquals("1234", self.proto.recv(4))
+        self.assertEquals("567", self.proto.read(3))
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+
+    def test_read_recv(self):
+        all_data = "12345678abcdefg"
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        self.assertEquals("1234", self.proto.read(4))
+        self.assertEquals("5678abc", self.proto.recv(8))
+        self.assertEquals("defg", self.proto.read(4))
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+
+    def test_mixed(self):
+        # arbitrary non-repeating string
+        all_data = ",".join(str(i) for i in xrange(100))
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        data = ""
+
+        for i in xrange(1, 100):
+            data += self.proto.recv(i)
+            # if we get to the end, do a non-blocking read instead of blocking
+            if len(data) + i > len(all_data):
+                data += self.proto.recv(i)
+                # ReceivableStringIO leaves off the last byte unless we ask
+                # nicely
+                data += self.proto.recv(1)
+                break
+            else:
+                data += self.proto.read(i)
+        else:
+            # didn't break, something must have gone wrong
+            self.fail()
+
+        self.assertEquals(all_data, data)
+
+
 class CapabilitiesTestCase(TestCase):
 
     def test_plain(self):