2
0

test_protocol.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # test_protocol.py -- Tests for the git protocol
  2. # Copyright (C) 2009 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
  5. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  6. # General Public License as published by the Free Software Foundation; version 2.0
  7. # or (at your option) any later version. You can redistribute it and/or
  8. # modify it under the terms of either of these two licenses.
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # You should have received a copy of the licenses; if not, see
  17. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  18. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  19. # License, Version 2.0.
  20. #
  21. """Tests for the smart protocol utility functions."""
  22. from io import BytesIO
  23. from dulwich.errors import HangupException
  24. from dulwich.protocol import (
  25. CAPABILITY_FILTER,
  26. KNOWN_UPLOAD_CAPABILITIES,
  27. MULTI_ACK,
  28. MULTI_ACK_DETAILED,
  29. SINGLE_ACK,
  30. BufferedPktLineWriter,
  31. GitProtocolError,
  32. PktLineParser,
  33. Protocol,
  34. ReceivableProtocol,
  35. ack_type,
  36. extract_capabilities,
  37. extract_want_line_capabilities,
  38. find_capability,
  39. pkt_line,
  40. pkt_seq,
  41. )
  42. from dulwich.refs import filter_ref_prefix
  43. from . import TestCase
  44. class PktLineTests(TestCase):
  45. def test_pkt_line(self) -> None:
  46. self.assertEqual(b"0007bla", pkt_line(b"bla"))
  47. self.assertEqual(b"0000", pkt_line(None))
  48. def test_pkt_seq(self) -> None:
  49. self.assertEqual(b"0007bla0007foo0000", pkt_seq(b"bla", b"foo"))
  50. self.assertEqual(b"0000", pkt_seq())
  51. class FilterRefPrefixTests(TestCase):
  52. def test_filter_ref_prefix(self) -> None:
  53. self.assertEqual(
  54. {b"refs/heads/foo": b"0123456789", b"refs/heads/bar": b"0123456789"},
  55. filter_ref_prefix(
  56. {
  57. b"refs/heads/foo": b"0123456789",
  58. b"refs/heads/bar": b"0123456789",
  59. b"refs/tags/bar": b"0123456789",
  60. },
  61. [b"refs/heads/"],
  62. ),
  63. )
  64. class BaseProtocolTests:
  65. def test_write_pkt_line_none(self) -> None:
  66. self.proto.write_pkt_line(None)
  67. self.assertEqual(self.rout.getvalue(), b"0000")
  68. def test_write_pkt_line(self) -> None:
  69. self.proto.write_pkt_line(b"bla")
  70. self.assertEqual(self.rout.getvalue(), b"0007bla")
  71. def test_read_pkt_line(self) -> None:
  72. self.rin.write(b"0008cmd ")
  73. self.rin.seek(0)
  74. self.assertEqual(b"cmd ", self.proto.read_pkt_line())
  75. def test_eof(self) -> None:
  76. self.rin.write(b"0000")
  77. self.rin.seek(0)
  78. self.assertFalse(self.proto.eof())
  79. self.assertEqual(None, self.proto.read_pkt_line())
  80. self.assertTrue(self.proto.eof())
  81. self.assertRaises(HangupException, self.proto.read_pkt_line)
  82. def test_unread_pkt_line(self) -> None:
  83. self.rin.write(b"0007foo0000")
  84. self.rin.seek(0)
  85. self.assertEqual(b"foo", self.proto.read_pkt_line())
  86. self.proto.unread_pkt_line(b"bar")
  87. self.assertEqual(b"bar", self.proto.read_pkt_line())
  88. self.assertEqual(None, self.proto.read_pkt_line())
  89. self.proto.unread_pkt_line(b"baz1")
  90. self.assertRaises(ValueError, self.proto.unread_pkt_line, b"baz2")
  91. def test_read_pkt_seq(self) -> None:
  92. self.rin.write(b"0008cmd 0005l0000")
  93. self.rin.seek(0)
  94. self.assertEqual([b"cmd ", b"l"], list(self.proto.read_pkt_seq()))
  95. def test_read_pkt_line_none(self) -> None:
  96. self.rin.write(b"0000")
  97. self.rin.seek(0)
  98. self.assertEqual(None, self.proto.read_pkt_line())
  99. def test_read_pkt_line_wrong_size(self) -> None:
  100. self.rin.write(b"0100too short")
  101. self.rin.seek(0)
  102. self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
  103. def test_write_sideband(self) -> None:
  104. self.proto.write_sideband(3, b"bloe")
  105. self.assertEqual(self.rout.getvalue(), b"0009\x03bloe")
  106. def test_send_cmd(self) -> None:
  107. self.proto.send_cmd(b"fetch", b"a", b"b")
  108. self.assertEqual(self.rout.getvalue(), b"000efetch a\x00b\x00")
  109. def test_read_cmd(self) -> None:
  110. self.rin.write(b"0012cmd arg1\x00arg2\x00")
  111. self.rin.seek(0)
  112. self.assertEqual((b"cmd", [b"arg1", b"arg2"]), self.proto.read_cmd())
  113. def test_read_cmd_noend0(self) -> None:
  114. self.rin.write(b"0011cmd arg1\x00arg2")
  115. self.rin.seek(0)
  116. self.assertRaises(AssertionError, self.proto.read_cmd)
  117. class ProtocolTests(BaseProtocolTests, TestCase):
  118. def setUp(self) -> None:
  119. TestCase.setUp(self)
  120. self.rout = BytesIO()
  121. self.rin = BytesIO()
  122. self.proto = Protocol(self.rin.read, self.rout.write)
  123. class ReceivableBytesIO(BytesIO):
  124. """BytesIO with socket-like recv semantics for testing."""
  125. def __init__(self) -> None:
  126. BytesIO.__init__(self)
  127. self.allow_read_past_eof = False
  128. def recv(self, size):
  129. # fail fast if no bytes are available; in a real socket, this would
  130. # block forever
  131. if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
  132. raise GitProtocolError("Blocking read past end of socket")
  133. if size == 1:
  134. return self.read(1)
  135. # calls shouldn't return quite as much as asked for
  136. return self.read(size - 1)
  137. class ReceivableProtocolTests(BaseProtocolTests, TestCase):
  138. def setUp(self) -> None:
  139. TestCase.setUp(self)
  140. self.rout = BytesIO()
  141. self.rin = ReceivableBytesIO()
  142. self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
  143. self.proto._rbufsize = 8
  144. def test_eof(self) -> None:
  145. # Allow blocking reads past EOF just for this test. The only parts of
  146. # the protocol that might check for EOF do not depend on the recv()
  147. # semantics anyway.
  148. self.rin.allow_read_past_eof = True
  149. BaseProtocolTests.test_eof(self)
  150. def test_recv(self) -> None:
  151. all_data = b"1234567" * 10 # not a multiple of bufsize
  152. self.rin.write(all_data)
  153. self.rin.seek(0)
  154. data = b""
  155. # We ask for 8 bytes each time and actually read 7, so it should take
  156. # exactly 10 iterations.
  157. for _ in range(10):
  158. data += self.proto.recv(10)
  159. # any more reads would block
  160. self.assertRaises(GitProtocolError, self.proto.recv, 10)
  161. self.assertEqual(all_data, data)
  162. def test_recv_read(self) -> None:
  163. all_data = b"1234567" # recv exactly in one call
  164. self.rin.write(all_data)
  165. self.rin.seek(0)
  166. self.assertEqual(b"1234", self.proto.recv(4))
  167. self.assertEqual(b"567", self.proto.read(3))
  168. self.assertRaises(GitProtocolError, self.proto.recv, 10)
  169. def test_read_recv(self) -> None:
  170. all_data = b"12345678abcdefg"
  171. self.rin.write(all_data)
  172. self.rin.seek(0)
  173. self.assertEqual(b"1234", self.proto.read(4))
  174. self.assertEqual(b"5678abc", self.proto.recv(8))
  175. self.assertEqual(b"defg", self.proto.read(4))
  176. self.assertRaises(GitProtocolError, self.proto.recv, 10)
  177. def test_mixed(self) -> None:
  178. # arbitrary non-repeating string
  179. all_data = b",".join(str(i).encode("ascii") for i in range(100))
  180. self.rin.write(all_data)
  181. self.rin.seek(0)
  182. data = b""
  183. for i in range(1, 100):
  184. data += self.proto.recv(i)
  185. # if we get to the end, do a non-blocking read instead of blocking
  186. if len(data) + i > len(all_data):
  187. data += self.proto.recv(i)
  188. # ReceivableBytesIO leaves off the last byte unless we ask
  189. # nicely
  190. data += self.proto.recv(1)
  191. break
  192. else:
  193. data += self.proto.read(i)
  194. else:
  195. # didn't break, something must have gone wrong
  196. self.fail()
  197. self.assertEqual(all_data, data)
  198. class CapabilitiesTestCase(TestCase):
  199. def test_plain(self) -> None:
  200. self.assertEqual((b"bla", []), extract_capabilities(b"bla"))
  201. def test_caps(self) -> None:
  202. self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la"))
  203. self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la\n"))
  204. self.assertEqual((b"bla", [b"la", b"la"]), extract_capabilities(b"bla\0la la"))
  205. def test_plain_want_line(self) -> None:
  206. self.assertEqual((b"want bla", []), extract_want_line_capabilities(b"want bla"))
  207. def test_caps_want_line(self) -> None:
  208. self.assertEqual(
  209. (b"want bla", [b"la"]),
  210. extract_want_line_capabilities(b"want bla la"),
  211. )
  212. self.assertEqual(
  213. (b"want bla", [b"la"]),
  214. extract_want_line_capabilities(b"want bla la\n"),
  215. )
  216. self.assertEqual(
  217. (b"want bla", [b"la", b"la"]),
  218. extract_want_line_capabilities(b"want bla la la"),
  219. )
  220. def test_ack_type(self) -> None:
  221. self.assertEqual(SINGLE_ACK, ack_type([b"foo", b"bar"]))
  222. self.assertEqual(MULTI_ACK, ack_type([b"foo", b"bar", b"multi_ack"]))
  223. self.assertEqual(
  224. MULTI_ACK_DETAILED,
  225. ack_type([b"foo", b"bar", b"multi_ack_detailed"]),
  226. )
  227. # choose detailed when both present
  228. self.assertEqual(
  229. MULTI_ACK_DETAILED,
  230. ack_type([b"foo", b"bar", b"multi_ack", b"multi_ack_detailed"]),
  231. )
  232. class BufferedPktLineWriterTests(TestCase):
  233. def setUp(self) -> None:
  234. TestCase.setUp(self)
  235. self._output = BytesIO()
  236. self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
  237. def assertOutputEquals(self, expected) -> None:
  238. self.assertEqual(expected, self._output.getvalue())
  239. def _truncate(self) -> None:
  240. self._output.seek(0)
  241. self._output.truncate()
  242. def test_write(self) -> None:
  243. self._writer.write(b"foo")
  244. self.assertOutputEquals(b"")
  245. self._writer.flush()
  246. self.assertOutputEquals(b"0007foo")
  247. def test_write_none(self) -> None:
  248. self._writer.write(None)
  249. self.assertOutputEquals(b"")
  250. self._writer.flush()
  251. self.assertOutputEquals(b"0000")
  252. def test_flush_empty(self) -> None:
  253. self._writer.flush()
  254. self.assertOutputEquals(b"")
  255. def test_write_multiple(self) -> None:
  256. self._writer.write(b"foo")
  257. self._writer.write(b"bar")
  258. self.assertOutputEquals(b"")
  259. self._writer.flush()
  260. self.assertOutputEquals(b"0007foo0007bar")
  261. def test_write_across_boundary(self) -> None:
  262. self._writer.write(b"foo")
  263. self._writer.write(b"barbaz")
  264. self.assertOutputEquals(b"0007foo000abarba")
  265. self._truncate()
  266. self._writer.flush()
  267. self.assertOutputEquals(b"z")
  268. def test_write_to_boundary(self) -> None:
  269. self._writer.write(b"foo")
  270. self._writer.write(b"barba")
  271. self.assertOutputEquals(b"0007foo0009barba")
  272. self._truncate()
  273. self._writer.write(b"z")
  274. self._writer.flush()
  275. self.assertOutputEquals(b"0005z")
  276. class PktLineParserTests(TestCase):
  277. def test_none(self) -> None:
  278. pktlines = []
  279. parser = PktLineParser(pktlines.append)
  280. parser.parse(b"0000")
  281. self.assertEqual(pktlines, [None])
  282. self.assertEqual(b"", parser.get_tail())
  283. def test_small_fragments(self) -> None:
  284. pktlines = []
  285. parser = PktLineParser(pktlines.append)
  286. parser.parse(b"00")
  287. parser.parse(b"05")
  288. parser.parse(b"z0000")
  289. self.assertEqual(pktlines, [b"z", None])
  290. self.assertEqual(b"", parser.get_tail())
  291. def test_multiple_packets(self) -> None:
  292. pktlines = []
  293. parser = PktLineParser(pktlines.append)
  294. parser.parse(b"0005z0006aba")
  295. self.assertEqual(pktlines, [b"z", b"ab"])
  296. self.assertEqual(b"a", parser.get_tail())
  297. class CapabilitiesTests(TestCase):
  298. """Tests for protocol capabilities."""
  299. def test_filter_capability_in_known_upload_capabilities(self) -> None:
  300. """Test that CAPABILITY_FILTER is in KNOWN_UPLOAD_CAPABILITIES."""
  301. self.assertIn(CAPABILITY_FILTER, KNOWN_UPLOAD_CAPABILITIES)
  302. class FindCapabilityTests(TestCase):
  303. """Tests for find_capability function."""
  304. def test_find_capability_with_value(self) -> None:
  305. """Test finding a capability with a value."""
  306. caps = [b"filter=blob:none", b"agent=git/2.0"]
  307. self.assertEqual(b"blob:none", find_capability(caps, b"filter"))
  308. def test_find_capability_without_value(self) -> None:
  309. """Test finding a capability without a value."""
  310. caps = [b"thin-pack", b"ofs-delta"]
  311. self.assertEqual(b"thin-pack", find_capability(caps, b"thin-pack"))
  312. def test_find_capability_not_found(self) -> None:
  313. """Test finding a capability that doesn't exist."""
  314. caps = [b"thin-pack", b"ofs-delta"]
  315. self.assertIsNone(find_capability(caps, b"missing"))
  316. def test_find_capability_multiple_names(self) -> None:
  317. """Test finding with multiple capability names."""
  318. caps = [b"filter=blob:none", b"agent=git/2.0"]
  319. # Should find first match
  320. self.assertEqual(b"git/2.0", find_capability(caps, b"missing", b"agent"))
  321. def test_find_capability_complex_value(self) -> None:
  322. """Test finding capability with complex value."""
  323. caps = [b"filter=combine:blob:none+tree:0"]
  324. self.assertEqual(b"combine:blob:none+tree:0", find_capability(caps, b"filter"))