test_protocol.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320
  1. # test_protocol.py -- Tests for the git protocol
  2. # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
  3. #
  4. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  5. # General Public License as published by the Free Software Foundation; version 2.0
  6. # or (at your option) any later version. You can redistribute it and/or
  7. # modify it under the terms of either of these two licenses.
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # You should have received a copy of the licenses; if not, see
  16. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  17. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  18. # License, Version 2.0.
  19. #
  20. """Tests for the smart protocol utility functions."""
  21. from io import BytesIO
  22. from dulwich.errors import (
  23. HangupException,
  24. )
  25. from dulwich.protocol import (
  26. GitProtocolError,
  27. PktLineParser,
  28. Protocol,
  29. ReceivableProtocol,
  30. extract_capabilities,
  31. extract_want_line_capabilities,
  32. ack_type,
  33. SINGLE_ACK,
  34. MULTI_ACK,
  35. MULTI_ACK_DETAILED,
  36. BufferedPktLineWriter,
  37. )
  38. from dulwich.tests import TestCase
  39. class BaseProtocolTests(object):
  40. def test_write_pkt_line_none(self):
  41. self.proto.write_pkt_line(None)
  42. self.assertEqual(self.rout.getvalue(), b'0000')
  43. def test_write_pkt_line(self):
  44. self.proto.write_pkt_line(b'bla')
  45. self.assertEqual(self.rout.getvalue(), b'0007bla')
  46. def test_read_pkt_line(self):
  47. self.rin.write(b'0008cmd ')
  48. self.rin.seek(0)
  49. self.assertEqual(b'cmd ', self.proto.read_pkt_line())
  50. def test_eof(self):
  51. self.rin.write(b'0000')
  52. self.rin.seek(0)
  53. self.assertFalse(self.proto.eof())
  54. self.assertEqual(None, self.proto.read_pkt_line())
  55. self.assertTrue(self.proto.eof())
  56. self.assertRaises(HangupException, self.proto.read_pkt_line)
  57. def test_unread_pkt_line(self):
  58. self.rin.write(b'0007foo0000')
  59. self.rin.seek(0)
  60. self.assertEqual(b'foo', self.proto.read_pkt_line())
  61. self.proto.unread_pkt_line(b'bar')
  62. self.assertEqual(b'bar', self.proto.read_pkt_line())
  63. self.assertEqual(None, self.proto.read_pkt_line())
  64. self.proto.unread_pkt_line(b'baz1')
  65. self.assertRaises(ValueError, self.proto.unread_pkt_line, b'baz2')
  66. def test_read_pkt_seq(self):
  67. self.rin.write(b'0008cmd 0005l0000')
  68. self.rin.seek(0)
  69. self.assertEqual([b'cmd ', b'l'], list(self.proto.read_pkt_seq()))
  70. def test_read_pkt_line_none(self):
  71. self.rin.write(b'0000')
  72. self.rin.seek(0)
  73. self.assertEqual(None, self.proto.read_pkt_line())
  74. def test_read_pkt_line_wrong_size(self):
  75. self.rin.write(b'0100too short')
  76. self.rin.seek(0)
  77. self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
  78. def test_write_sideband(self):
  79. self.proto.write_sideband(3, b'bloe')
  80. self.assertEqual(self.rout.getvalue(), b'0009\x03bloe')
  81. def test_send_cmd(self):
  82. self.proto.send_cmd(b'fetch', b'a', b'b')
  83. self.assertEqual(self.rout.getvalue(), b'000efetch a\x00b\x00')
  84. def test_read_cmd(self):
  85. self.rin.write(b'0012cmd arg1\x00arg2\x00')
  86. self.rin.seek(0)
  87. self.assertEqual((b'cmd', [b'arg1', b'arg2']), self.proto.read_cmd())
  88. def test_read_cmd_noend0(self):
  89. self.rin.write(b'0011cmd arg1\x00arg2')
  90. self.rin.seek(0)
  91. self.assertRaises(AssertionError, self.proto.read_cmd)
  92. class ProtocolTests(BaseProtocolTests, TestCase):
  93. def setUp(self):
  94. TestCase.setUp(self)
  95. self.rout = BytesIO()
  96. self.rin = BytesIO()
  97. self.proto = Protocol(self.rin.read, self.rout.write)
  98. class ReceivableBytesIO(BytesIO):
  99. """BytesIO with socket-like recv semantics for testing."""
  100. def __init__(self):
  101. BytesIO.__init__(self)
  102. self.allow_read_past_eof = False
  103. def recv(self, size):
  104. # fail fast if no bytes are available; in a real socket, this would
  105. # block forever
  106. if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
  107. raise GitProtocolError('Blocking read past end of socket')
  108. if size == 1:
  109. return self.read(1)
  110. # calls shouldn't return quite as much as asked for
  111. return self.read(size - 1)
  112. class ReceivableProtocolTests(BaseProtocolTests, TestCase):
  113. def setUp(self):
  114. TestCase.setUp(self)
  115. self.rout = BytesIO()
  116. self.rin = ReceivableBytesIO()
  117. self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
  118. self.proto._rbufsize = 8
  119. def test_eof(self):
  120. # Allow blocking reads past EOF just for this test. The only parts of
  121. # the protocol that might check for EOF do not depend on the recv()
  122. # semantics anyway.
  123. self.rin.allow_read_past_eof = True
  124. BaseProtocolTests.test_eof(self)
  125. def test_recv(self):
  126. all_data = b'1234567' * 10 # not a multiple of bufsize
  127. self.rin.write(all_data)
  128. self.rin.seek(0)
  129. data = b''
  130. # We ask for 8 bytes each time and actually read 7, so it should take
  131. # exactly 10 iterations.
  132. for _ in range(10):
  133. data += self.proto.recv(10)
  134. # any more reads would block
  135. self.assertRaises(GitProtocolError, self.proto.recv, 10)
  136. self.assertEqual(all_data, data)
  137. def test_recv_read(self):
  138. all_data = b'1234567' # recv exactly in one call
  139. self.rin.write(all_data)
  140. self.rin.seek(0)
  141. self.assertEqual(b'1234', self.proto.recv(4))
  142. self.assertEqual(b'567', self.proto.read(3))
  143. self.assertRaises(GitProtocolError, self.proto.recv, 10)
  144. def test_read_recv(self):
  145. all_data = b'12345678abcdefg'
  146. self.rin.write(all_data)
  147. self.rin.seek(0)
  148. self.assertEqual(b'1234', self.proto.read(4))
  149. self.assertEqual(b'5678abc', self.proto.recv(8))
  150. self.assertEqual(b'defg', self.proto.read(4))
  151. self.assertRaises(GitProtocolError, self.proto.recv, 10)
  152. def test_mixed(self):
  153. # arbitrary non-repeating string
  154. all_data = b','.join(str(i).encode('ascii') for i in range(100))
  155. self.rin.write(all_data)
  156. self.rin.seek(0)
  157. data = b''
  158. for i in range(1, 100):
  159. data += self.proto.recv(i)
  160. # if we get to the end, do a non-blocking read instead of blocking
  161. if len(data) + i > len(all_data):
  162. data += self.proto.recv(i)
  163. # ReceivableBytesIO leaves off the last byte unless we ask
  164. # nicely
  165. data += self.proto.recv(1)
  166. break
  167. else:
  168. data += self.proto.read(i)
  169. else:
  170. # didn't break, something must have gone wrong
  171. self.fail()
  172. self.assertEqual(all_data, data)
  173. class CapabilitiesTestCase(TestCase):
  174. def test_plain(self):
  175. self.assertEqual((b'bla', []), extract_capabilities(b'bla'))
  176. def test_caps(self):
  177. self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la'))
  178. self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la\n'))
  179. self.assertEqual((b'bla', [b'la', b'la']), extract_capabilities(b'bla\0la la'))
  180. def test_plain_want_line(self):
  181. self.assertEqual((b'want bla', []), extract_want_line_capabilities(b'want bla'))
  182. def test_caps_want_line(self):
  183. self.assertEqual((b'want bla', [b'la']),
  184. extract_want_line_capabilities(b'want bla la'))
  185. self.assertEqual((b'want bla', [b'la']),
  186. extract_want_line_capabilities(b'want bla la\n'))
  187. self.assertEqual((b'want bla', [b'la', b'la']),
  188. extract_want_line_capabilities(b'want bla la la'))
  189. def test_ack_type(self):
  190. self.assertEqual(SINGLE_ACK, ack_type([b'foo', b'bar']))
  191. self.assertEqual(MULTI_ACK, ack_type([b'foo', b'bar', b'multi_ack']))
  192. self.assertEqual(MULTI_ACK_DETAILED,
  193. ack_type([b'foo', b'bar', b'multi_ack_detailed']))
  194. # choose detailed when both present
  195. self.assertEqual(MULTI_ACK_DETAILED,
  196. ack_type([b'foo', b'bar', b'multi_ack',
  197. b'multi_ack_detailed']))
  198. class BufferedPktLineWriterTests(TestCase):
  199. def setUp(self):
  200. TestCase.setUp(self)
  201. self._output = BytesIO()
  202. self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
  203. def assertOutputEquals(self, expected):
  204. self.assertEqual(expected, self._output.getvalue())
  205. def _truncate(self):
  206. self._output.seek(0)
  207. self._output.truncate()
  208. def test_write(self):
  209. self._writer.write(b'foo')
  210. self.assertOutputEquals(b'')
  211. self._writer.flush()
  212. self.assertOutputEquals(b'0007foo')
  213. def test_write_none(self):
  214. self._writer.write(None)
  215. self.assertOutputEquals(b'')
  216. self._writer.flush()
  217. self.assertOutputEquals(b'0000')
  218. def test_flush_empty(self):
  219. self._writer.flush()
  220. self.assertOutputEquals(b'')
  221. def test_write_multiple(self):
  222. self._writer.write(b'foo')
  223. self._writer.write(b'bar')
  224. self.assertOutputEquals(b'')
  225. self._writer.flush()
  226. self.assertOutputEquals(b'0007foo0007bar')
  227. def test_write_across_boundary(self):
  228. self._writer.write(b'foo')
  229. self._writer.write(b'barbaz')
  230. self.assertOutputEquals(b'0007foo000abarba')
  231. self._truncate()
  232. self._writer.flush()
  233. self.assertOutputEquals(b'z')
  234. def test_write_to_boundary(self):
  235. self._writer.write(b'foo')
  236. self._writer.write(b'barba')
  237. self.assertOutputEquals(b'0007foo0009barba')
  238. self._truncate()
  239. self._writer.write(b'z')
  240. self._writer.flush()
  241. self.assertOutputEquals(b'0005z')
  242. class PktLineParserTests(TestCase):
  243. def test_none(self):
  244. pktlines = []
  245. parser = PktLineParser(pktlines.append)
  246. parser.parse(b"0000")
  247. self.assertEqual(pktlines, [None])
  248. self.assertEqual(b"", parser.get_tail())
  249. def test_small_fragments(self):
  250. pktlines = []
  251. parser = PktLineParser(pktlines.append)
  252. parser.parse(b"00")
  253. parser.parse(b"05")
  254. parser.parse(b"z0000")
  255. self.assertEqual(pktlines, [b"z", None])
  256. self.assertEqual(b"", parser.get_tail())
  257. def test_multiple_packets(self):
  258. pktlines = []
  259. parser = PktLineParser(pktlines.append)
  260. parser.parse(b"0005z0006aba")
  261. self.assertEqual(pktlines, [b"z", b"ab"])
  262. self.assertEqual(b"a", parser.get_tail())