|
@@ -20,11 +20,12 @@
|
|
|
"""Tests for the smart protocol utility functions."""
|
|
|
|
|
|
|
|
|
-from cStringIO import StringIO
|
|
|
+from StringIO import StringIO
|
|
|
from unittest import TestCase
|
|
|
|
|
|
from dulwich.protocol import (
|
|
|
Protocol,
|
|
|
+ ReceivableProtocol,
|
|
|
extract_capabilities,
|
|
|
extract_want_line_capabilities,
|
|
|
ack_type,
|
|
@@ -33,12 +34,7 @@ from dulwich.protocol import (
|
|
|
MULTI_ACK_DETAILED,
|
|
|
)
|
|
|
|
|
|
-class ProtocolTests(TestCase):
|
|
|
-
|
|
|
- def setUp(self):
|
|
|
- self.rout = StringIO()
|
|
|
- self.rin = StringIO()
|
|
|
- self.proto = Protocol(self.rin.read, self.rout.write)
|
|
|
+class BaseProtocolTests(object):
|
|
|
|
|
|
def test_write_pkt_line_none(self):
|
|
|
self.proto.write_pkt_line(None)
|
|
@@ -82,6 +78,93 @@ class ProtocolTests(TestCase):
|
|
|
self.assertRaises(AssertionError, self.proto.read_cmd)
|
|
|
|
|
|
|
|
|
+class ProtocolTests(BaseProtocolTests, TestCase):
|
|
|
+
|
|
|
+ def setUp(self):
|
|
|
+ TestCase.setUp(self)
|
|
|
+ self.rout = StringIO()
|
|
|
+ self.rin = StringIO()
|
|
|
+ self.proto = Protocol(self.rin.read, self.rout.write)
|
|
|
+
|
|
|
+
|
|
|
+class ReceivableStringIO(StringIO):
|
|
|
+ """StringIO with socket-like recv semantics for testing."""
|
|
|
+
|
|
|
+ def recv(self, size):
|
|
|
+ # fail fast if no bytes are available; in a real socket, this would
|
|
|
+ # block forever
|
|
|
+ if self.tell() == len(self.getvalue()):
|
|
|
+ raise AssertionError("Blocking read past end of socket")
|
|
|
+ if size == 1:
|
|
|
+ return self.read(1)
|
|
|
+ # calls shouldn't return quite as much as asked for
|
|
|
+ return self.read(size - 1)
|
|
|
+
|
|
|
+
|
|
|
+class ReceivableProtocolTests(BaseProtocolTests, TestCase):
|
|
|
+
|
|
|
+ def setUp(self):
|
|
|
+ TestCase.setUp(self)
|
|
|
+ self.rout = StringIO()
|
|
|
+ self.rin = ReceivableStringIO()
|
|
|
+ self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
|
|
|
+ self.proto._rbufsize = 8
|
|
|
+
|
|
|
+ def test_recv(self):
|
|
|
+ all_data = "1234567" * 10 # not a multiple of bufsize
|
|
|
+ self.rin.write(all_data)
|
|
|
+ self.rin.seek(0)
|
|
|
+ data = ""
|
|
|
+ # We ask for 8 bytes each time and actually read 7, so it should take
|
|
|
+ # exactly 10 iterations.
|
|
|
+ for _ in xrange(10):
|
|
|
+ data += self.proto.recv(10)
|
|
|
+ # any more reads would block
|
|
|
+ self.assertRaises(AssertionError, self.proto.recv, 10)
|
|
|
+ self.assertEquals(all_data, data)
|
|
|
+
|
|
|
+ def test_recv_read(self):
|
|
|
+ all_data = "1234567" # recv exactly in one call
|
|
|
+ self.rin.write(all_data)
|
|
|
+ self.rin.seek(0)
|
|
|
+ self.assertEquals("1234", self.proto.recv(4))
|
|
|
+ self.assertEquals("567", self.proto.read(3))
|
|
|
+ self.assertRaises(AssertionError, self.proto.recv, 10)
|
|
|
+
|
|
|
+ def test_read_recv(self):
|
|
|
+ all_data = "12345678abcdefg"
|
|
|
+ self.rin.write(all_data)
|
|
|
+ self.rin.seek(0)
|
|
|
+ self.assertEquals("1234", self.proto.read(4))
|
|
|
+ self.assertEquals("5678abc", self.proto.recv(8))
|
|
|
+ self.assertEquals("defg", self.proto.read(4))
|
|
|
+ self.assertRaises(AssertionError, self.proto.recv, 10)
|
|
|
+
|
|
|
+ def test_mixed(self):
|
|
|
+ # arbitrary non-repeating string
|
|
|
+ all_data = ",".join(str(i) for i in xrange(100))
|
|
|
+ self.rin.write(all_data)
|
|
|
+ self.rin.seek(0)
|
|
|
+ data = ""
|
|
|
+
|
|
|
+ for i in xrange(1, 100):
|
|
|
+ data += self.proto.recv(i)
|
|
|
+ # if we get to the end, do a non-blocking read instead of blocking
|
|
|
+ if len(data) + i > len(all_data):
|
|
|
+ data += self.proto.recv(i)
|
|
|
+ # ReceivableStringIO leaves off the last byte unless we ask
|
|
|
+ # nicely
|
|
|
+ data += self.proto.recv(1)
|
|
|
+ break
|
|
|
+ else:
|
|
|
+ data += self.proto.read(i)
|
|
|
+ else:
|
|
|
+ # didn't break, something must have gone wrong
|
|
|
+ self.fail()
|
|
|
+
|
|
|
+ self.assertEquals(all_data, data)
|
|
|
+
|
|
|
+
|
|
|
class CapabilitiesTestCase(TestCase):
|
|
|
|
|
|
def test_plain(self):
|