test_protocol.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. # test_protocol.py -- Tests for the git protocol
  2. # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
  3. #
  4. # This program is free software; you can redistribute it and/or
  5. # modify it under the terms of the GNU General Public License
  6. # as published by the Free Software Foundation; version 2
  7. # or (at your option) any later version of the License.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU General Public License
  15. # along with this program; if not, write to the Free Software
  16. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
  17. # MA 02110-1301, USA.
  18. """Tests for the smart protocol utility functions."""
  19. from StringIO import StringIO
  20. from dulwich.errors import (
  21. HangupException,
  22. )
  23. from dulwich.protocol import (
  24. PktLineParser,
  25. Protocol,
  26. ReceivableProtocol,
  27. extract_capabilities,
  28. extract_want_line_capabilities,
  29. ack_type,
  30. SINGLE_ACK,
  31. MULTI_ACK,
  32. MULTI_ACK_DETAILED,
  33. BufferedPktLineWriter,
  34. )
  35. from dulwich.tests import TestCase
  36. class BaseProtocolTests(object):
  37. def test_write_pkt_line_none(self):
  38. self.proto.write_pkt_line(None)
  39. self.assertEquals(self.rout.getvalue(), '0000')
  40. def test_write_pkt_line(self):
  41. self.proto.write_pkt_line('bla')
  42. self.assertEquals(self.rout.getvalue(), '0007bla')
  43. def test_read_pkt_line(self):
  44. self.rin.write('0008cmd ')
  45. self.rin.seek(0)
  46. self.assertEquals('cmd ', self.proto.read_pkt_line())
  47. def test_eof(self):
  48. self.rin.write('0000')
  49. self.rin.seek(0)
  50. self.assertFalse(self.proto.eof())
  51. self.assertEquals(None, self.proto.read_pkt_line())
  52. self.assertTrue(self.proto.eof())
  53. self.assertRaises(HangupException, self.proto.read_pkt_line)
  54. def test_unread_pkt_line(self):
  55. self.rin.write('0007foo0000')
  56. self.rin.seek(0)
  57. self.assertEquals('foo', self.proto.read_pkt_line())
  58. self.proto.unread_pkt_line('bar')
  59. self.assertEquals('bar', self.proto.read_pkt_line())
  60. self.assertEquals(None, self.proto.read_pkt_line())
  61. self.proto.unread_pkt_line('baz1')
  62. self.assertRaises(ValueError, self.proto.unread_pkt_line, 'baz2')
  63. def test_read_pkt_seq(self):
  64. self.rin.write('0008cmd 0005l0000')
  65. self.rin.seek(0)
  66. self.assertEquals(['cmd ', 'l'], list(self.proto.read_pkt_seq()))
  67. def test_read_pkt_line_none(self):
  68. self.rin.write('0000')
  69. self.rin.seek(0)
  70. self.assertEquals(None, self.proto.read_pkt_line())
  71. def test_write_sideband(self):
  72. self.proto.write_sideband(3, 'bloe')
  73. self.assertEquals(self.rout.getvalue(), '0009\x03bloe')
  74. def test_send_cmd(self):
  75. self.proto.send_cmd('fetch', 'a', 'b')
  76. self.assertEquals(self.rout.getvalue(), '000efetch a\x00b\x00')
  77. def test_read_cmd(self):
  78. self.rin.write('0012cmd arg1\x00arg2\x00')
  79. self.rin.seek(0)
  80. self.assertEquals(('cmd', ['arg1', 'arg2']), self.proto.read_cmd())
  81. def test_read_cmd_noend0(self):
  82. self.rin.write('0011cmd arg1\x00arg2')
  83. self.rin.seek(0)
  84. self.assertRaises(AssertionError, self.proto.read_cmd)
  85. class ProtocolTests(BaseProtocolTests, TestCase):
  86. def setUp(self):
  87. TestCase.setUp(self)
  88. self.rout = StringIO()
  89. self.rin = StringIO()
  90. self.proto = Protocol(self.rin.read, self.rout.write)
  91. class ReceivableStringIO(StringIO):
  92. """StringIO with socket-like recv semantics for testing."""
  93. def __init__(self):
  94. StringIO.__init__(self)
  95. self.allow_read_past_eof = False
  96. def recv(self, size):
  97. # fail fast if no bytes are available; in a real socket, this would
  98. # block forever
  99. if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
  100. raise AssertionError('Blocking read past end of socket')
  101. if size == 1:
  102. return self.read(1)
  103. # calls shouldn't return quite as much as asked for
  104. return self.read(size - 1)
  105. class ReceivableProtocolTests(BaseProtocolTests, TestCase):
  106. def setUp(self):
  107. TestCase.setUp(self)
  108. self.rout = StringIO()
  109. self.rin = ReceivableStringIO()
  110. self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
  111. self.proto._rbufsize = 8
  112. def test_eof(self):
  113. # Allow blocking reads past EOF just for this test. The only parts of
  114. # the protocol that might check for EOF do not depend on the recv()
  115. # semantics anyway.
  116. self.rin.allow_read_past_eof = True
  117. BaseProtocolTests.test_eof(self)
  118. def test_recv(self):
  119. all_data = '1234567' * 10 # not a multiple of bufsize
  120. self.rin.write(all_data)
  121. self.rin.seek(0)
  122. data = ''
  123. # We ask for 8 bytes each time and actually read 7, so it should take
  124. # exactly 10 iterations.
  125. for _ in xrange(10):
  126. data += self.proto.recv(10)
  127. # any more reads would block
  128. self.assertRaises(AssertionError, self.proto.recv, 10)
  129. self.assertEquals(all_data, data)
  130. def test_recv_read(self):
  131. all_data = '1234567' # recv exactly in one call
  132. self.rin.write(all_data)
  133. self.rin.seek(0)
  134. self.assertEquals('1234', self.proto.recv(4))
  135. self.assertEquals('567', self.proto.read(3))
  136. self.assertRaises(AssertionError, self.proto.recv, 10)
  137. def test_read_recv(self):
  138. all_data = '12345678abcdefg'
  139. self.rin.write(all_data)
  140. self.rin.seek(0)
  141. self.assertEquals('1234', self.proto.read(4))
  142. self.assertEquals('5678abc', self.proto.recv(8))
  143. self.assertEquals('defg', self.proto.read(4))
  144. self.assertRaises(AssertionError, self.proto.recv, 10)
  145. def test_mixed(self):
  146. # arbitrary non-repeating string
  147. all_data = ','.join(str(i) for i in xrange(100))
  148. self.rin.write(all_data)
  149. self.rin.seek(0)
  150. data = ''
  151. for i in xrange(1, 100):
  152. data += self.proto.recv(i)
  153. # if we get to the end, do a non-blocking read instead of blocking
  154. if len(data) + i > len(all_data):
  155. data += self.proto.recv(i)
  156. # ReceivableStringIO leaves off the last byte unless we ask
  157. # nicely
  158. data += self.proto.recv(1)
  159. break
  160. else:
  161. data += self.proto.read(i)
  162. else:
  163. # didn't break, something must have gone wrong
  164. self.fail()
  165. self.assertEquals(all_data, data)
  166. class CapabilitiesTestCase(TestCase):
  167. def test_plain(self):
  168. self.assertEquals(('bla', []), extract_capabilities('bla'))
  169. def test_caps(self):
  170. self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la'))
  171. self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la\n'))
  172. self.assertEquals(('bla', ['la', 'la']), extract_capabilities('bla\0la la'))
  173. def test_plain_want_line(self):
  174. self.assertEquals(('want bla', []), extract_want_line_capabilities('want bla'))
  175. def test_caps_want_line(self):
  176. self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la'))
  177. self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la\n'))
  178. self.assertEquals(('want bla', ['la', 'la']), extract_want_line_capabilities('want bla la la'))
  179. def test_ack_type(self):
  180. self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
  181. self.assertEquals(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
  182. self.assertEquals(MULTI_ACK_DETAILED,
  183. ack_type(['foo', 'bar', 'multi_ack_detailed']))
  184. # choose detailed when both present
  185. self.assertEquals(MULTI_ACK_DETAILED,
  186. ack_type(['foo', 'bar', 'multi_ack',
  187. 'multi_ack_detailed']))
  188. class BufferedPktLineWriterTests(TestCase):
  189. def setUp(self):
  190. TestCase.setUp(self)
  191. self._output = StringIO()
  192. self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
  193. def assertOutputEquals(self, expected):
  194. self.assertEquals(expected, self._output.getvalue())
  195. def _truncate(self):
  196. self._output.seek(0)
  197. self._output.truncate()
  198. def test_write(self):
  199. self._writer.write('foo')
  200. self.assertOutputEquals('')
  201. self._writer.flush()
  202. self.assertOutputEquals('0007foo')
  203. def test_write_none(self):
  204. self._writer.write(None)
  205. self.assertOutputEquals('')
  206. self._writer.flush()
  207. self.assertOutputEquals('0000')
  208. def test_flush_empty(self):
  209. self._writer.flush()
  210. self.assertOutputEquals('')
  211. def test_write_multiple(self):
  212. self._writer.write('foo')
  213. self._writer.write('bar')
  214. self.assertOutputEquals('')
  215. self._writer.flush()
  216. self.assertOutputEquals('0007foo0007bar')
  217. def test_write_across_boundary(self):
  218. self._writer.write('foo')
  219. self._writer.write('barbaz')
  220. self.assertOutputEquals('0007foo000abarba')
  221. self._truncate()
  222. self._writer.flush()
  223. self.assertOutputEquals('z')
  224. def test_write_to_boundary(self):
  225. self._writer.write('foo')
  226. self._writer.write('barba')
  227. self.assertOutputEquals('0007foo0009barba')
  228. self._truncate()
  229. self._writer.write('z')
  230. self._writer.flush()
  231. self.assertOutputEquals('0005z')
  232. class PktLineParserTests(TestCase):
  233. def test_none(self):
  234. pktlines = []
  235. parser = PktLineParser(pktlines.append)
  236. parser.parse("0000")
  237. self.assertEquals(pktlines, [None])
  238. self.assertEquals("", parser.get_tail())
  239. def test_small_fragments(self):
  240. pktlines = []
  241. parser = PktLineParser(pktlines.append)
  242. parser.parse("00")
  243. parser.parse("05")
  244. parser.parse("z0000")
  245. self.assertEquals(pktlines, ["z", None])
  246. self.assertEquals("", parser.get_tail())
  247. def test_multiple_packets(self):
  248. pktlines = []
  249. parser = PktLineParser(pktlines.append)
  250. parser.parse("0005z0006aba")
  251. self.assertEquals(pktlines, ["z", "ab"])
  252. self.assertEquals("a", parser.get_tail())