test_protocol.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. # test_protocol.py -- Tests for the git protocol
  2. # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
  3. #
  4. # This program is free software; you can redistribute it and/or
  5. # modify it under the terms of the GNU General Public License
  6. # as published by the Free Software Foundation; version 2
  7. # or (at your option) any later version of the License.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU General Public License
  15. # along with this program; if not, write to the Free Software
  16. # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
  17. # MA 02110-1301, USA.
  18. """Tests for the smart protocol utility functions."""
  19. from StringIO import StringIO
  20. from dulwich.protocol import (
  21. Protocol,
  22. ReceivableProtocol,
  23. extract_capabilities,
  24. extract_want_line_capabilities,
  25. ack_type,
  26. SINGLE_ACK,
  27. MULTI_ACK,
  28. MULTI_ACK_DETAILED,
  29. BufferedPktLineWriter,
  30. )
  31. from dulwich.tests import TestCase
  32. class BaseProtocolTests(object):
  33. def test_write_pkt_line_none(self):
  34. self.proto.write_pkt_line(None)
  35. self.assertEquals(self.rout.getvalue(), "0000")
  36. def test_write_pkt_line(self):
  37. self.proto.write_pkt_line("bla")
  38. self.assertEquals(self.rout.getvalue(), "0007bla")
  39. def test_read_pkt_line(self):
  40. self.rin.write("0008cmd ")
  41. self.rin.seek(0)
  42. self.assertEquals("cmd ", self.proto.read_pkt_line())
  43. def test_read_pkt_seq(self):
  44. self.rin.write("0008cmd 0005l0000")
  45. self.rin.seek(0)
  46. self.assertEquals(["cmd ", "l"], list(self.proto.read_pkt_seq()))
  47. def test_read_pkt_line_none(self):
  48. self.rin.write("0000")
  49. self.rin.seek(0)
  50. self.assertEquals(None, self.proto.read_pkt_line())
  51. def test_write_sideband(self):
  52. self.proto.write_sideband(3, "bloe")
  53. self.assertEquals(self.rout.getvalue(), "0009\x03bloe")
  54. def test_send_cmd(self):
  55. self.proto.send_cmd("fetch", "a", "b")
  56. self.assertEquals(self.rout.getvalue(), "000efetch a\x00b\x00")
  57. def test_read_cmd(self):
  58. self.rin.write("0012cmd arg1\x00arg2\x00")
  59. self.rin.seek(0)
  60. self.assertEquals(("cmd", ["arg1", "arg2"]), self.proto.read_cmd())
  61. def test_read_cmd_noend0(self):
  62. self.rin.write("0011cmd arg1\x00arg2")
  63. self.rin.seek(0)
  64. self.assertRaises(AssertionError, self.proto.read_cmd)
  65. class ProtocolTests(BaseProtocolTests, TestCase):
  66. def setUp(self):
  67. TestCase.setUp(self)
  68. self.rout = StringIO()
  69. self.rin = StringIO()
  70. self.proto = Protocol(self.rin.read, self.rout.write)
  71. class ReceivableStringIO(StringIO):
  72. """StringIO with socket-like recv semantics for testing."""
  73. def recv(self, size):
  74. # fail fast if no bytes are available; in a real socket, this would
  75. # block forever
  76. if self.tell() == len(self.getvalue()):
  77. raise AssertionError("Blocking read past end of socket")
  78. if size == 1:
  79. return self.read(1)
  80. # calls shouldn't return quite as much as asked for
  81. return self.read(size - 1)
  82. class ReceivableProtocolTests(BaseProtocolTests, TestCase):
  83. def setUp(self):
  84. TestCase.setUp(self)
  85. self.rout = StringIO()
  86. self.rin = ReceivableStringIO()
  87. self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
  88. self.proto._rbufsize = 8
  89. def test_recv(self):
  90. all_data = "1234567" * 10 # not a multiple of bufsize
  91. self.rin.write(all_data)
  92. self.rin.seek(0)
  93. data = ""
  94. # We ask for 8 bytes each time and actually read 7, so it should take
  95. # exactly 10 iterations.
  96. for _ in xrange(10):
  97. data += self.proto.recv(10)
  98. # any more reads would block
  99. self.assertRaises(AssertionError, self.proto.recv, 10)
  100. self.assertEquals(all_data, data)
  101. def test_recv_read(self):
  102. all_data = "1234567" # recv exactly in one call
  103. self.rin.write(all_data)
  104. self.rin.seek(0)
  105. self.assertEquals("1234", self.proto.recv(4))
  106. self.assertEquals("567", self.proto.read(3))
  107. self.assertRaises(AssertionError, self.proto.recv, 10)
  108. def test_read_recv(self):
  109. all_data = "12345678abcdefg"
  110. self.rin.write(all_data)
  111. self.rin.seek(0)
  112. self.assertEquals("1234", self.proto.read(4))
  113. self.assertEquals("5678abc", self.proto.recv(8))
  114. self.assertEquals("defg", self.proto.read(4))
  115. self.assertRaises(AssertionError, self.proto.recv, 10)
  116. def test_mixed(self):
  117. # arbitrary non-repeating string
  118. all_data = ",".join(str(i) for i in xrange(100))
  119. self.rin.write(all_data)
  120. self.rin.seek(0)
  121. data = ""
  122. for i in xrange(1, 100):
  123. data += self.proto.recv(i)
  124. # if we get to the end, do a non-blocking read instead of blocking
  125. if len(data) + i > len(all_data):
  126. data += self.proto.recv(i)
  127. # ReceivableStringIO leaves off the last byte unless we ask
  128. # nicely
  129. data += self.proto.recv(1)
  130. break
  131. else:
  132. data += self.proto.read(i)
  133. else:
  134. # didn't break, something must have gone wrong
  135. self.fail()
  136. self.assertEquals(all_data, data)
  137. class CapabilitiesTestCase(TestCase):
  138. def test_plain(self):
  139. self.assertEquals(("bla", []), extract_capabilities("bla"))
  140. def test_caps(self):
  141. self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la"))
  142. self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la\n"))
  143. self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la la"))
  144. def test_plain_want_line(self):
  145. self.assertEquals(("want bla", []), extract_want_line_capabilities("want bla"))
  146. def test_caps_want_line(self):
  147. self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la"))
  148. self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la\n"))
  149. self.assertEquals(("want bla", ["la", "la"]), extract_want_line_capabilities("want bla la la"))
  150. def test_ack_type(self):
  151. self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
  152. self.assertEquals(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
  153. self.assertEquals(MULTI_ACK_DETAILED,
  154. ack_type(['foo', 'bar', 'multi_ack_detailed']))
  155. # choose detailed when both present
  156. self.assertEquals(MULTI_ACK_DETAILED,
  157. ack_type(['foo', 'bar', 'multi_ack',
  158. 'multi_ack_detailed']))
  159. class BufferedPktLineWriterTests(TestCase):
  160. def setUp(self):
  161. self._output = StringIO()
  162. self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
  163. def assertOutputEquals(self, expected):
  164. self.assertEquals(expected, self._output.getvalue())
  165. def _truncate(self):
  166. self._output.seek(0)
  167. self._output.truncate()
  168. def test_write(self):
  169. self._writer.write('foo')
  170. self.assertOutputEquals('')
  171. self._writer.flush()
  172. self.assertOutputEquals('0007foo')
  173. def test_write_none(self):
  174. self._writer.write(None)
  175. self.assertOutputEquals('')
  176. self._writer.flush()
  177. self.assertOutputEquals('0000')
  178. def test_flush_empty(self):
  179. self._writer.flush()
  180. self.assertOutputEquals('')
  181. def test_write_multiple(self):
  182. self._writer.write('foo')
  183. self._writer.write('bar')
  184. self.assertOutputEquals('')
  185. self._writer.flush()
  186. self.assertOutputEquals('0007foo0007bar')
  187. def test_write_across_boundary(self):
  188. self._writer.write('foo')
  189. self._writer.write('barbaz')
  190. self.assertOutputEquals('0007foo000abarba')
  191. self._truncate()
  192. self._writer.flush()
  193. self.assertOutputEquals('z')
  194. def test_write_to_boundary(self):
  195. self._writer.write('foo')
  196. self._writer.write('barba')
  197. self.assertOutputEquals('0007foo0009barba')
  198. self._truncate()
  199. self._writer.write('z')
  200. self._writer.flush()
  201. self.assertOutputEquals('0005z')