test_protocol.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333
  1. # test_protocol.py -- Tests for the git protocol
  2. # Copyright (C) 2009 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  5. # General Public License as published by the Free Software Foundation; version 2.0
  6. # or (at your option) any later version. You can redistribute it and/or
  7. # modify it under the terms of either of these two licenses.
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. #
  15. # You should have received a copy of the licenses; if not, see
  16. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  17. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  18. # License, Version 2.0.
  19. #
  20. """Tests for the smart protocol utility functions."""
  21. from io import BytesIO
  22. from dulwich.errors import HangupException
  23. from dulwich.protocol import (
  24. MULTI_ACK,
  25. MULTI_ACK_DETAILED,
  26. SINGLE_ACK,
  27. BufferedPktLineWriter,
  28. GitProtocolError,
  29. PktLineParser,
  30. Protocol,
  31. ReceivableProtocol,
  32. ack_type,
  33. extract_capabilities,
  34. extract_want_line_capabilities,
  35. pkt_line,
  36. pkt_seq,
  37. )
  38. from . import TestCase
  class PktLinetests(TestCase):
      """Tests for the module-level pkt-line encoding helpers.

      This class must derive from ``TestCase``; without it the test runner
      never collects these tests, which is how an impossible expectation in
      ``test_pkt_seq`` (a ``0003`` length prefix for a 3-byte payload) went
      unnoticed.
      """

      def test_pkt_line(self):
          # The 4-hex-digit length prefix counts itself: b"bla" -> 0x0007.
          self.assertEqual(b"0007bla", pkt_line(b"bla"))
          # None encodes the flush-pkt.
          self.assertEqual(b"0000", pkt_line(None))

      def test_pkt_seq(self):
          # Each element is wrapped exactly like pkt_line() does (so b"foo"
          # is 0007foo, not 0003foo), followed by a terminating flush-pkt.
          # NOTE(review): assumes pkt_seq takes the lines as positional
          # varargs (``def pkt_seq(*seq)``) rather than one iterable --
          # confirm against dulwich.protocol.
          self.assertEqual(b"0007bla0007foo0000", pkt_seq(b"bla", b"foo"))
          self.assertEqual(b"0000", pkt_seq())
  class BaseProtocolTests:
      """Protocol tests shared by the concrete protocol test cases.

      Subclasses are expected to set ``self.rin`` (input stream),
      ``self.rout`` (output stream) and ``self.proto`` in ``setUp``.
      """

      def _feed(self, payload):
          """Write *payload* to the input stream and rewind it for reading."""
          self.rin.write(payload)
          self.rin.seek(0)

      def test_write_pkt_line_none(self):
          # None is written out as the flush-pkt.
          self.proto.write_pkt_line(None)
          self.assertEqual(self.rout.getvalue(), b"0000")

      def test_write_pkt_line(self):
          self.proto.write_pkt_line(b"bla")
          self.assertEqual(self.rout.getvalue(), b"0007bla")

      def test_read_pkt_line(self):
          self._feed(b"0008cmd ")
          self.assertEqual(b"cmd ", self.proto.read_pkt_line())

      def test_eof(self):
          self._feed(b"0000")
          proto = self.proto
          # Before the flush-pkt is consumed the stream is not at EOF ...
          self.assertFalse(proto.eof())
          self.assertEqual(None, proto.read_pkt_line())
          # ... afterwards it is, and a further read counts as a hang-up.
          self.assertTrue(proto.eof())
          self.assertRaises(HangupException, proto.read_pkt_line)

      def test_unread_pkt_line(self):
          self._feed(b"0007foo0000")
          proto = self.proto
          self.assertEqual(b"foo", proto.read_pkt_line())
          # A pushed-back line is returned by the next read ...
          proto.unread_pkt_line(b"bar")
          self.assertEqual(b"bar", proto.read_pkt_line())
          self.assertEqual(None, proto.read_pkt_line())
          # ... but only one line may be pushed back at a time.
          proto.unread_pkt_line(b"baz1")
          self.assertRaises(ValueError, proto.unread_pkt_line, b"baz2")

      def test_read_pkt_seq(self):
          self._feed(b"0008cmd 0005l0000")
          packets = list(self.proto.read_pkt_seq())
          self.assertEqual([b"cmd ", b"l"], packets)

      def test_read_pkt_line_none(self):
          self._feed(b"0000")
          self.assertEqual(None, self.proto.read_pkt_line())

      def test_read_pkt_line_wrong_size(self):
          # The header announces 0x100 bytes but far fewer follow.
          self._feed(b"0100too short")
          self.assertRaises(GitProtocolError, self.proto.read_pkt_line)

      def test_write_sideband(self):
          self.proto.write_sideband(3, b"bloe")
          # The channel number is sent as the first payload byte.
          self.assertEqual(self.rout.getvalue(), b"0009\x03bloe")

      def test_send_cmd(self):
          self.proto.send_cmd(b"fetch", b"a", b"b")
          # Each argument is NUL-terminated after the command and a space.
          self.assertEqual(self.rout.getvalue(), b"000efetch a\x00b\x00")

      def test_read_cmd(self):
          self._feed(b"0012cmd arg1\x00arg2\x00")
          self.assertEqual((b"cmd", [b"arg1", b"arg2"]), self.proto.read_cmd())

      def test_read_cmd_noend0(self):
          # The last argument lacks its terminating NUL.
          self._feed(b"0011cmd arg1\x00arg2")
          self.assertRaises(AssertionError, self.proto.read_cmd)
  class ProtocolTests(BaseProtocolTests, TestCase):
      """Run the shared protocol tests against the plain ``Protocol`` class."""

      def setUp(self):
          # BaseProtocolTests defines no setUp, so this reaches TestCase.setUp.
          super().setUp()
          self.rin = BytesIO()
          self.rout = BytesIO()
          self.proto = Protocol(self.rin.read, self.rout.write)
  class ReceivableBytesIO(BytesIO):
      """In-memory stream that also offers a socket-style ``recv`` for tests."""

      def __init__(self) -> None:
          super().__init__()
          # When False, recv() at the end of the data raises instead of
          # returning b"".
          self.allow_read_past_eof = False

      def recv(self, size):
          """Read from the buffer, returning short reads like a socket might.

          A request for exactly one byte is honoured as-is; any other size
          reads ``size - 1`` bytes. Unless ``allow_read_past_eof`` is set,
          calling this at the end of the data raises ``GitProtocolError``
          instead of blocking forever as a real socket would.
          """
          exhausted = self.tell() == len(self.getvalue())
          if exhausted and not self.allow_read_past_eof:
              raise GitProtocolError("Blocking read past end of socket")
          return self.read(1 if size == 1 else size - 1)
  class ReceivableProtocolTests(BaseProtocolTests, TestCase):
      """Run the shared protocol tests against ``ReceivableProtocol``.

      Input comes from a ``ReceivableBytesIO`` whose ``recv`` returns short
      reads, and the protocol's read buffer is shrunk to 8 bytes.
      """

      def setUp(self):
          # BaseProtocolTests defines no setUp, so this reaches TestCase.setUp.
          super().setUp()
          self.rout = BytesIO()
          self.rin = ReceivableBytesIO()
          self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
          self.proto._rbufsize = 8

      def test_eof(self):
          # Only this test may read past the end of the fake socket. The
          # protocol's EOF checks do not depend on recv() semantics anyway.
          self.rin.allow_read_past_eof = True
          super().test_eof()

      def test_recv(self):
          # 70 bytes: deliberately not a multiple of the buffer size.
          expected = b"1234567" * 10
          self.rin.write(expected)
          self.rin.seek(0)
          # Each refill asks for 8 bytes and the fake socket hands back 7,
          # so exactly ten calls drain the input.
          received = b"".join(self.proto.recv(10) for _ in range(10))
          # Any further read would block on a real socket.
          self.assertRaises(GitProtocolError, self.proto.recv, 10)
          self.assertEqual(expected, received)

      def test_recv_read(self):
          # Seven bytes fit in a single underlying recv call.
          payload = b"1234567"
          self.rin.write(payload)
          self.rin.seek(0)
          self.assertEqual(b"1234", self.proto.recv(4))
          self.assertEqual(b"567", self.proto.read(3))
          self.assertRaises(GitProtocolError, self.proto.recv, 10)

      def test_read_recv(self):
          payload = b"12345678abcdefg"
          self.rin.write(payload)
          self.rin.seek(0)
          self.assertEqual(b"1234", self.proto.read(4))
          self.assertEqual(b"5678abc", self.proto.recv(8))
          self.assertEqual(b"defg", self.proto.read(4))
          self.assertRaises(GitProtocolError, self.proto.recv, 10)

      def test_mixed(self):
          # Arbitrary non-repeating payload: b"0,1,2,...,99".
          expected = b",".join(str(n).encode("ascii") for n in range(100))
          self.rin.write(expected)
          self.rin.seek(0)
          received = b""
          for size in range(1, 100):
              received += self.proto.recv(size)
              if len(received) + size <= len(expected):
                  received += self.proto.read(size)
                  continue
              # Near the end: drain with recv() instead of a blocking read().
              received += self.proto.recv(size)
              # ReceivableBytesIO withholds the final byte unless exactly
              # one byte is requested.
              received += self.proto.recv(1)
              break
          else:
              # The loop never reached the end of the data.
              self.fail()
          self.assertEqual(expected, received)
  class CapabilitiesTestCase(TestCase):
      """Tests for capability extraction and ack-type selection."""

      def test_plain(self):
          self.assertEqual((b"bla", []), extract_capabilities(b"bla"))

      def test_caps(self):
          # Capabilities follow a NUL; a trailing newline is ignored.
          cases = [
              (b"bla\0la", (b"bla", [b"la"])),
              (b"bla\0la\n", (b"bla", [b"la"])),
              (b"bla\0la la", (b"bla", [b"la", b"la"])),
          ]
          for line, expected in cases:
              self.assertEqual(expected, extract_capabilities(line))

      def test_plain_want_line(self):
          self.assertEqual((b"want bla", []), extract_want_line_capabilities(b"want bla"))

      def test_caps_want_line(self):
          # On a want line, capabilities follow the second space-separated word.
          cases = [
              (b"want bla la", (b"want bla", [b"la"])),
              (b"want bla la\n", (b"want bla", [b"la"])),
              (b"want bla la la", (b"want bla", [b"la", b"la"])),
          ]
          for line, expected in cases:
              self.assertEqual(expected, extract_want_line_capabilities(line))

      def test_ack_type(self):
          cases = [
              ([b"foo", b"bar"], SINGLE_ACK),
              ([b"foo", b"bar", b"multi_ack"], MULTI_ACK),
              ([b"foo", b"bar", b"multi_ack_detailed"], MULTI_ACK_DETAILED),
              # The detailed variant wins when both are advertised.
              (
                  [b"foo", b"bar", b"multi_ack", b"multi_ack_detailed"],
                  MULTI_ACK_DETAILED,
              ),
          ]
          for capabilities, expected in cases:
              self.assertEqual(expected, ack_type(capabilities))
  class BufferedPktLineWriterTests(TestCase):
      """Tests for ``BufferedPktLineWriter`` using a small 16-byte buffer."""

      def setUp(self):
          super().setUp()
          self._output = BytesIO()
          self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)

      def assertOutputEquals(self, expected):
          """Assert that everything written downstream so far equals *expected*."""
          actual = self._output.getvalue()
          self.assertEqual(expected, actual)

      def _truncate(self):
          """Discard all output captured so far."""
          sink = self._output
          sink.seek(0)
          sink.truncate()

      def test_write(self):
          writer = self._writer
          writer.write(b"foo")
          # Nothing is emitted until the buffer fills or is flushed.
          self.assertOutputEquals(b"")
          writer.flush()
          self.assertOutputEquals(b"0007foo")

      def test_write_none(self):
          writer = self._writer
          writer.write(None)
          self.assertOutputEquals(b"")
          writer.flush()
          self.assertOutputEquals(b"0000")

      def test_flush_empty(self):
          self._writer.flush()
          self.assertOutputEquals(b"")

      def test_write_multiple(self):
          writer = self._writer
          for chunk in (b"foo", b"bar"):
              writer.write(chunk)
          self.assertOutputEquals(b"")
          writer.flush()
          self.assertOutputEquals(b"0007foo0007bar")

      def test_write_across_boundary(self):
          writer = self._writer
          writer.write(b"foo")
          writer.write(b"barbaz")
          # 7 + 10 = 17 buffered bytes: the first 16 go out immediately ...
          self.assertOutputEquals(b"0007foo000abarba")
          self._truncate()
          # ... and the remaining byte waits for flush().
          writer.flush()
          self.assertOutputEquals(b"z")

      def test_write_to_boundary(self):
          writer = self._writer
          writer.write(b"foo")
          writer.write(b"barba")
          # 7 + 9 = exactly 16 bytes, so everything is emitted.
          self.assertOutputEquals(b"0007foo0009barba")
          self._truncate()
          writer.write(b"z")
          writer.flush()
          self.assertOutputEquals(b"0005z")
  class PktLineParserTests(TestCase):
      """Tests for the incremental ``PktLineParser``."""

      def _make_parser(self):
          """Return a parser plus the list its callback appends packets to."""
          received = []
          return PktLineParser(received.append), received

      def test_none(self):
          parser, received = self._make_parser()
          parser.parse(b"0000")
          # A flush-pkt is reported as None.
          self.assertEqual(received, [None])
          self.assertEqual(b"", parser.get_tail())

      def test_small_fragments(self):
          parser, received = self._make_parser()
          # The header is split across calls; the parser must reassemble it.
          for fragment in (b"00", b"05", b"z0000"):
              parser.parse(fragment)
          self.assertEqual(received, [b"z", None])
          self.assertEqual(b"", parser.get_tail())

      def test_multiple_packets(self):
          parser, received = self._make_parser()
          parser.parse(b"0005z0006aba")
          # The trailing byte past the last complete packet stays in the tail.
          self.assertEqual(received, [b"z", b"ab"])
          self.assertEqual(b"a", parser.get_tail())