# test_protocol.py -- Tests for the git protocol
# Copyright (C) 2009 Jelmer Vernooij
#
# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
# General Public License as published by the Free Software Foundation; version 2.0
# or (at your option) any later version. You can redistribute it and/or
# modify it under the terms of either of these two licenses.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# You should have received a copy of the licenses; if not, see
# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
# License, Version 2.0.
#

"""Tests for the smart protocol utility functions."""

from io import BytesIO

from dulwich.errors import HangupException
from dulwich.protocol import (
    MULTI_ACK,
    MULTI_ACK_DETAILED,
    SINGLE_ACK,
    BufferedPktLineWriter,
    GitProtocolError,
    PktLineParser,
    Protocol,
    ReceivableProtocol,
    ack_type,
    extract_capabilities,
    extract_want_line_capabilities,
    filter_ref_prefix,
    pkt_line,
    pkt_seq,
)

from . import TestCase


class PktLineTests(TestCase):
    """Tests for the module-level pkt_line() and pkt_seq() helpers."""

    def test_pkt_line(self) -> None:
        self.assertEqual(b"0007bla", pkt_line(b"bla"))
        # None is encoded as the flush packet.
        self.assertEqual(b"0000", pkt_line(None))

    def test_pkt_seq(self) -> None:
        # A sequence is terminated by a flush packet, even when empty.
        self.assertEqual(b"0007bla0007foo0000", pkt_seq(b"bla", b"foo"))
        self.assertEqual(b"0000", pkt_seq())


class FilterRefPrefixTests(TestCase):
    """Tests for filter_ref_prefix()."""

    def test_filter_ref_prefix(self) -> None:
        # Only refs below one of the given prefixes are expected back.
        self.assertEqual(
            {b"refs/heads/foo": b"0123456789", b"refs/heads/bar": b"0123456789"},
            filter_ref_prefix(
                {
                    b"refs/heads/foo": b"0123456789",
                    b"refs/heads/bar": b"0123456789",
                    b"refs/tags/bar": b"0123456789",
                },
                [b"refs/heads/"],
            ),
        )


class BaseProtocolTests:
    """Mixin holding the tests shared by all protocol implementations.

    Concrete subclasses must also derive from TestCase and, in ``setUp``,
    provide:

    * ``self.rin``: the stream the protocol reads from,
    * ``self.rout``: the BytesIO the protocol writes to,
    * ``self.proto``: the protocol object under test.
    """

    def test_write_pkt_line_none(self) -> None:
        self.proto.write_pkt_line(None)
        self.assertEqual(self.rout.getvalue(), b"0000")

    def test_write_pkt_line(self) -> None:
        self.proto.write_pkt_line(b"bla")
        self.assertEqual(self.rout.getvalue(), b"0007bla")

    def test_read_pkt_line(self) -> None:
        self.rin.write(b"0008cmd ")
        self.rin.seek(0)
        self.assertEqual(b"cmd ", self.proto.read_pkt_line())

    def test_eof(self) -> None:
        self.rin.write(b"0000")
        self.rin.seek(0)
        self.assertFalse(self.proto.eof())
        self.assertIsNone(self.proto.read_pkt_line())
        self.assertTrue(self.proto.eof())
        # Reading once the input is exhausted must be reported as a hangup.
        self.assertRaises(HangupException, self.proto.read_pkt_line)

    def test_unread_pkt_line(self) -> None:
        self.rin.write(b"0007foo0000")
        self.rin.seek(0)
        self.assertEqual(b"foo", self.proto.read_pkt_line())
        # An unread line is handed back before any further input is consumed.
        self.proto.unread_pkt_line(b"bar")
        self.assertEqual(b"bar", self.proto.read_pkt_line())
        self.assertIsNone(self.proto.read_pkt_line())
        # Only a single line may be pending at any time.
        self.proto.unread_pkt_line(b"baz1")
        self.assertRaises(ValueError, self.proto.unread_pkt_line, b"baz2")

    def test_read_pkt_seq(self) -> None:
        self.rin.write(b"0008cmd 0005l0000")
        self.rin.seek(0)
        self.assertEqual([b"cmd ", b"l"], list(self.proto.read_pkt_seq()))

    def test_read_pkt_line_none(self) -> None:
        self.rin.write(b"0000")
        self.rin.seek(0)
        self.assertIsNone(self.proto.read_pkt_line())

    def test_read_pkt_line_wrong_size(self) -> None:
        # The length prefix (0x0100) announces more data than is available.
        self.rin.write(b"0100too short")
        self.rin.seek(0)
        self.assertRaises(GitProtocolError, self.proto.read_pkt_line)

    def test_write_sideband(self) -> None:
        self.proto.write_sideband(3, b"bloe")
        self.assertEqual(self.rout.getvalue(), b"0009\x03bloe")

    def test_send_cmd(self) -> None:
        self.proto.send_cmd(b"fetch", b"a", b"b")
        self.assertEqual(self.rout.getvalue(), b"000efetch a\x00b\x00")

    def test_read_cmd(self) -> None:
        self.rin.write(b"0012cmd arg1\x00arg2\x00")
        self.rin.seek(0)
        self.assertEqual((b"cmd", [b"arg1", b"arg2"]), self.proto.read_cmd())

    def test_read_cmd_noend0(self) -> None:
        # Same command as in test_read_cmd, but without the trailing NUL.
        self.rin.write(b"0011cmd arg1\x00arg2")
        self.rin.seek(0)
        self.assertRaises(AssertionError, self.proto.read_cmd)


class ProtocolTests(BaseProtocolTests, TestCase):
    """Runs the shared protocol tests against Protocol over BytesIO streams."""

    def setUp(self) -> None:
        super().setUp()
        self.rout = BytesIO()
        self.rin = BytesIO()
        self.proto = Protocol(self.rin.read, self.rout.write)


class ReceivableBytesIO(BytesIO):
    """BytesIO with socket-like recv semantics for testing."""

    def __init__(self) -> None:
        super().__init__()
        # When False, recv() raises instead of returning data at end of
        # stream; tests flip this to permit reads past EOF.
        self.allow_read_past_eof = False

    def recv(self, size: int) -> bytes:
        """Return up to ``size`` bytes, deliberately fewer than requested.

        Args:
          size: Maximum number of bytes the caller asks for.

        Returns:
          One byte when ``size`` is 1, otherwise at most ``size - 1`` bytes.

        Raises:
          GitProtocolError: If the stream position is at the end of the
            buffered data and ``allow_read_past_eof`` is False.
        """
        # fail fast if no bytes are available; in a real socket, this would
        # block forever
        if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
            raise GitProtocolError("Blocking read past end of socket")
        if size == 1:
            return self.read(1)
        # calls shouldn't return quite as much as asked for
        return self.read(size - 1)


class ReceivableProtocolTests(BaseProtocolTests, TestCase):
    """Runs the shared tests, plus recv-specific ones, on ReceivableProtocol."""

    def setUp(self) -> None:
        super().setUp()
        self.rout = BytesIO()
        self.rin = ReceivableBytesIO()
        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
        # Use a small read buffer size; test_recv relies on it being 8.
        self.proto._rbufsize = 8

    def test_eof(self) -> None:
        # Allow blocking reads past EOF just for this test. The only parts of
        # the protocol that might check for EOF do not depend on the recv()
        # semantics anyway.
        self.rin.allow_read_past_eof = True
        BaseProtocolTests.test_eof(self)

    def test_recv(self) -> None:
        all_data = b"1234567" * 10  # not a multiple of bufsize
        self.rin.write(all_data)
        self.rin.seek(0)
        data = b""
        # We ask for 8 bytes each time and actually read 7, so it should take
        # exactly 10 iterations.
        for _ in range(10):
            data += self.proto.recv(10)
        # any more reads would block
        self.assertRaises(GitProtocolError, self.proto.recv, 10)
        self.assertEqual(all_data, data)

    def test_recv_read(self) -> None:
        all_data = b"1234567"  # recv exactly in one call
        self.rin.write(all_data)
        self.rin.seek(0)
        self.assertEqual(b"1234", self.proto.recv(4))
        self.assertEqual(b"567", self.proto.read(3))
        self.assertRaises(GitProtocolError, self.proto.recv, 10)

    def test_read_recv(self) -> None:
        all_data = b"12345678abcdefg"
        self.rin.write(all_data)
        self.rin.seek(0)
        self.assertEqual(b"1234", self.proto.read(4))
        self.assertEqual(b"5678abc", self.proto.recv(8))
        self.assertEqual(b"defg", self.proto.read(4))
        self.assertRaises(GitProtocolError, self.proto.recv, 10)

    def test_mixed(self) -> None:
        # arbitrary non-repeating string
        all_data = b",".join(str(i).encode("ascii") for i in range(100))
        self.rin.write(all_data)
        self.rin.seek(0)
        data = b""

        for i in range(1, 100):
            data += self.proto.recv(i)
            # if we get to the end, do a non-blocking read instead of blocking
            if len(data) + i > len(all_data):
                data += self.proto.recv(i)
                # ReceivableBytesIO leaves off the last byte unless we ask
                # nicely
                data += self.proto.recv(1)
                break
            else:
                data += self.proto.read(i)
        else:
            # didn't break, something must have gone wrong
            self.fail("read loop finished without reaching the end of the data")

        self.assertEqual(all_data, data)


class CapabilitiesTestCase(TestCase):
    """Tests for capability extraction and ack_type()."""

    def test_plain(self) -> None:
        self.assertEqual((b"bla", []), extract_capabilities(b"bla"))

    def test_caps(self) -> None:
        self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la"))
        self.assertEqual((b"bla", [b"la"]), extract_capabilities(b"bla\0la\n"))
        self.assertEqual((b"bla", [b"la", b"la"]), extract_capabilities(b"bla\0la la"))

    def test_plain_want_line(self) -> None:
        self.assertEqual((b"want bla", []), extract_want_line_capabilities(b"want bla"))

    def test_caps_want_line(self) -> None:
        self.assertEqual(
            (b"want bla", [b"la"]),
            extract_want_line_capabilities(b"want bla la"),
        )
        self.assertEqual(
            (b"want bla", [b"la"]),
            extract_want_line_capabilities(b"want bla la\n"),
        )
        self.assertEqual(
            (b"want bla", [b"la", b"la"]),
            extract_want_line_capabilities(b"want bla la la"),
        )

    def test_ack_type(self) -> None:
        self.assertEqual(SINGLE_ACK, ack_type([b"foo", b"bar"]))
        self.assertEqual(MULTI_ACK, ack_type([b"foo", b"bar", b"multi_ack"]))
        self.assertEqual(
            MULTI_ACK_DETAILED,
            ack_type([b"foo", b"bar", b"multi_ack_detailed"]),
        )
        # choose detailed when both present
        self.assertEqual(
            MULTI_ACK_DETAILED,
            ack_type([b"foo", b"bar", b"multi_ack", b"multi_ack_detailed"]),
        )


class BufferedPktLineWriterTests(TestCase):
    """Tests for BufferedPktLineWriter, using a 16-byte buffer."""

    def setUp(self) -> None:
        super().setUp()
        self._output = BytesIO()
        self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)

    def assertOutputEquals(self, expected: bytes) -> None:
        """Assert that exactly ``expected`` has reached the output stream."""
        self.assertEqual(expected, self._output.getvalue())

    def _truncate(self) -> None:
        """Discard everything written to the output stream so far."""
        self._output.seek(0)
        self._output.truncate()

    def test_write(self) -> None:
        self._writer.write(b"foo")
        # Nothing reaches the output until the writer is flushed.
        self.assertOutputEquals(b"")
        self._writer.flush()
        self.assertOutputEquals(b"0007foo")

    def test_write_none(self) -> None:
        self._writer.write(None)
        self.assertOutputEquals(b"")
        self._writer.flush()
        self.assertOutputEquals(b"0000")

    def test_flush_empty(self) -> None:
        self._writer.flush()
        self.assertOutputEquals(b"")

    def test_write_multiple(self) -> None:
        self._writer.write(b"foo")
        self._writer.write(b"bar")
        self.assertOutputEquals(b"")
        self._writer.flush()
        self.assertOutputEquals(b"0007foo0007bar")

    def test_write_across_boundary(self) -> None:
        self._writer.write(b"foo")
        self._writer.write(b"barbaz")
        # 7 + 10 bytes exceed the 16-byte buffer: the first 16 bytes are
        # emitted immediately and the remaining byte stays buffered.
        self.assertOutputEquals(b"0007foo000abarba")
        self._truncate()
        self._writer.flush()
        self.assertOutputEquals(b"z")

    def test_write_to_boundary(self) -> None:
        self._writer.write(b"foo")
        self._writer.write(b"barba")
        # 7 + 9 bytes fill the 16-byte buffer exactly and are emitted at once.
        self.assertOutputEquals(b"0007foo0009barba")
        self._truncate()
        self._writer.write(b"z")
        self._writer.flush()
        self.assertOutputEquals(b"0005z")


class PktLineParserTests(TestCase):
    """Tests for the incremental PktLineParser."""

    def test_none(self) -> None:
        pktlines = []
        parser = PktLineParser(pktlines.append)
        parser.parse(b"0000")
        self.assertEqual(pktlines, [None])
        self.assertEqual(b"", parser.get_tail())

    def test_small_fragments(self) -> None:
        pktlines = []
        parser = PktLineParser(pktlines.append)
        # Feed the length prefix and the payload in separate fragments.
        parser.parse(b"00")
        parser.parse(b"05")
        parser.parse(b"z0000")
        self.assertEqual(pktlines, [b"z", None])
        self.assertEqual(b"", parser.get_tail())

    def test_multiple_packets(self) -> None:
        pktlines = []
        parser = PktLineParser(pktlines.append)
        parser.parse(b"0005z0006aba")
        self.assertEqual(pktlines, [b"z", b"ab"])
        # The trailing byte is not a complete packet and stays in the tail.
        self.assertEqual(b"a", parser.get_tail())