Explorar o código

Convert protocol module to python3.

Jelmer Vernooij %!s(int64=10) %!d(string=hai) anos
pai
achega
7f6c773066
Modificáronse 2 ficheiros con 94 adicións e 97 borrados
  1. 15 15
      dulwich/protocol.py
  2. 79 82
      dulwich/tests/test_protocol.py

+ 15 - 15
dulwich/protocol.py

@@ -32,7 +32,7 @@ from dulwich.errors import (
 
 TCP_GIT_PORT = 9418
 
-ZERO_SHA = "0" * 40
+ZERO_SHA = b"0" * 40
 
 SINGLE_ACK = 0
 MULTI_ACK = 1
@@ -61,8 +61,8 @@ def pkt_line(data):
         None, returns the flush-pkt ('0000').
     """
     if data is None:
-        return '0000'
-    return '%04x%s' % (len(data) + 4, data)
+        return b'0000'
+    return ('%04x' % (len(data) + 4)).encode('ascii') + data
 
 
 class Protocol(object):
@@ -209,7 +209,7 @@ class Protocol(object):
         # 65520-5 = 65515
         # WTF: Why have the len in ASCII, but the channel in binary.
         while blob:
-            self.write_pkt_line("%s%s" % (chr(channel), blob[:65515]))
+            self.write_pkt_line(bytes(bytearray([channel])) + blob[:65515])
             blob = blob[65515:]
 
     def send_cmd(self, cmd, *args):
@@ -220,7 +220,7 @@ class Protocol(object):
         :param cmd: The remote service to access.
         :param args: List of arguments to send to remove service.
         """
-        self.write_pkt_line("%s %s" % (cmd, "".join(["%s\0" % a for a in args])))
+        self.write_pkt_line(cmd + b" " + b"".join([(a + b"\0") for a in args]))
 
     def read_cmd(self):
         """Read a command and some arguments from the git client
@@ -230,10 +230,10 @@ class Protocol(object):
         :return: A tuple of (command, [list of arguments]).
         """
         line = self.read_pkt_line()
-        splice_at = line.find(" ")
+        splice_at = line.find(b" ")
         cmd, args = line[:splice_at], line[splice_at+1:]
-        assert args[-1] == "\x00"
-        return cmd, args[:-1].split(chr(0))
+        assert args[-1:] == b"\x00"
+        return cmd, args[:-1].split(b"\0")
 
 
 _RBUFSIZE = 8192  # Default read buffer size.
@@ -348,10 +348,10 @@ def extract_capabilities(text):
     :param text: String to extract from
     :return: Tuple with text with capabilities removed and list of capabilities
     """
-    if not "\0" in text:
+    if not b"\0" in text:
         return text, []
-    text, capabilities = text.rstrip().split("\0")
-    return (text, capabilities.strip().split(" "))
+    text, capabilities = text.rstrip().split(b"\0")
+    return (text, capabilities.strip().split(b" "))
 
 
 def extract_want_line_capabilities(text):
@@ -365,17 +365,17 @@ def extract_want_line_capabilities(text):
     :param text: Want line to extract from
     :return: Tuple with text with capabilities removed and list of capabilities
     """
-    split_text = text.rstrip().split(" ")
+    split_text = text.rstrip().split(b" ")
     if len(split_text) < 3:
         return text, []
-    return (" ".join(split_text[:2]), split_text[2:])
+    return (b" ".join(split_text[:2]), split_text[2:])
 
 
 def ack_type(capabilities):
     """Extract the ack type from a capabilities list."""
-    if 'multi_ack_detailed' in capabilities:
+    if b'multi_ack_detailed' in capabilities:
         return MULTI_ACK_DETAILED
-    elif 'multi_ack' in capabilities:
+    elif b'multi_ack' in capabilities:
         return MULTI_ACK
     return SINGLE_ACK
 

+ 79 - 82
dulwich/tests/test_protocol.py

@@ -37,26 +37,25 @@ from dulwich.protocol import (
     BufferedPktLineWriter,
     )
 from dulwich.tests import TestCase
-from dulwich.tests.utils import skipIfPY3
 
 
 class BaseProtocolTests(object):
 
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
-        self.assertEqual(self.rout.getvalue(), '0000')
+        self.assertEqual(self.rout.getvalue(), b'0000')
 
     def test_write_pkt_line(self):
-        self.proto.write_pkt_line('bla')
-        self.assertEqual(self.rout.getvalue(), '0007bla')
+        self.proto.write_pkt_line(b'bla')
+        self.assertEqual(self.rout.getvalue(), b'0007bla')
 
     def test_read_pkt_line(self):
-        self.rin.write('0008cmd ')
+        self.rin.write(b'0008cmd ')
         self.rin.seek(0)
-        self.assertEqual('cmd ', self.proto.read_pkt_line())
+        self.assertEqual(b'cmd ', self.proto.read_pkt_line())
 
     def test_eof(self):
-        self.rin.write('0000')
+        self.rin.write(b'0000')
         self.rin.seek(0)
         self.assertFalse(self.proto.eof())
         self.assertEqual(None, self.proto.read_pkt_line())
@@ -64,50 +63,49 @@ class BaseProtocolTests(object):
         self.assertRaises(HangupException, self.proto.read_pkt_line)
 
     def test_unread_pkt_line(self):
-        self.rin.write('0007foo0000')
+        self.rin.write(b'0007foo0000')
         self.rin.seek(0)
-        self.assertEqual('foo', self.proto.read_pkt_line())
-        self.proto.unread_pkt_line('bar')
-        self.assertEqual('bar', self.proto.read_pkt_line())
+        self.assertEqual(b'foo', self.proto.read_pkt_line())
+        self.proto.unread_pkt_line(b'bar')
+        self.assertEqual(b'bar', self.proto.read_pkt_line())
         self.assertEqual(None, self.proto.read_pkt_line())
-        self.proto.unread_pkt_line('baz1')
-        self.assertRaises(ValueError, self.proto.unread_pkt_line, 'baz2')
+        self.proto.unread_pkt_line(b'baz1')
+        self.assertRaises(ValueError, self.proto.unread_pkt_line, b'baz2')
 
     def test_read_pkt_seq(self):
-        self.rin.write('0008cmd 0005l0000')
+        self.rin.write(b'0008cmd 0005l0000')
         self.rin.seek(0)
-        self.assertEqual(['cmd ', 'l'], list(self.proto.read_pkt_seq()))
+        self.assertEqual([b'cmd ', b'l'], list(self.proto.read_pkt_seq()))
 
     def test_read_pkt_line_none(self):
-        self.rin.write('0000')
+        self.rin.write(b'0000')
         self.rin.seek(0)
         self.assertEqual(None, self.proto.read_pkt_line())
 
     def test_read_pkt_line_wrong_size(self):
-        self.rin.write('0100too short')
+        self.rin.write(b'0100too short')
         self.rin.seek(0)
         self.assertRaises(AssertionError, self.proto.read_pkt_line)
 
     def test_write_sideband(self):
-        self.proto.write_sideband(3, 'bloe')
-        self.assertEqual(self.rout.getvalue(), '0009\x03bloe')
+        self.proto.write_sideband(3, b'bloe')
+        self.assertEqual(self.rout.getvalue(), b'0009\x03bloe')
 
     def test_send_cmd(self):
-        self.proto.send_cmd('fetch', 'a', 'b')
-        self.assertEqual(self.rout.getvalue(), '000efetch a\x00b\x00')
+        self.proto.send_cmd(b'fetch', b'a', b'b')
+        self.assertEqual(self.rout.getvalue(), b'000efetch a\x00b\x00')
 
     def test_read_cmd(self):
-        self.rin.write('0012cmd arg1\x00arg2\x00')
+        self.rin.write(b'0012cmd arg1\x00arg2\x00')
         self.rin.seek(0)
-        self.assertEqual(('cmd', ['arg1', 'arg2']), self.proto.read_cmd())
+        self.assertEqual((b'cmd', [b'arg1', b'arg2']), self.proto.read_cmd())
 
     def test_read_cmd_noend0(self):
-        self.rin.write('0011cmd arg1\x00arg2')
+        self.rin.write(b'0011cmd arg1\x00arg2')
         self.rin.seek(0)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
-@skipIfPY3
 class ProtocolTests(BaseProtocolTests, TestCase):
 
     def setUp(self):
@@ -135,7 +133,6 @@ class ReceivableBytesIO(BytesIO):
         return self.read(size - 1)
 
 
-@skipIfPY3
 class ReceivableProtocolTests(BaseProtocolTests, TestCase):
 
     def setUp(self):
@@ -153,10 +150,10 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         BaseProtocolTests.test_eof(self)
 
     def test_recv(self):
-        all_data = '1234567' * 10  # not a multiple of bufsize
+        all_data = b'1234567' * 10  # not a multiple of bufsize
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = ''
+        data = b''
         # We ask for 8 bytes each time and actually read 7, so it should take
         # exactly 10 iterations.
         for _ in range(10):
@@ -166,28 +163,28 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEqual(all_data, data)
 
     def test_recv_read(self):
-        all_data = '1234567'  # recv exactly in one call
+        all_data = b'1234567'  # recv exactly in one call
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEqual('1234', self.proto.recv(4))
-        self.assertEqual('567', self.proto.read(3))
+        self.assertEqual(b'1234', self.proto.recv(4))
+        self.assertEqual(b'567', self.proto.read(3))
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
     def test_read_recv(self):
-        all_data = '12345678abcdefg'
+        all_data = b'12345678abcdefg'
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEqual('1234', self.proto.read(4))
-        self.assertEqual('5678abc', self.proto.recv(8))
-        self.assertEqual('defg', self.proto.read(4))
+        self.assertEqual(b'1234', self.proto.read(4))
+        self.assertEqual(b'5678abc', self.proto.recv(8))
+        self.assertEqual(b'defg', self.proto.read(4))
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
     def test_mixed(self):
         # arbitrary non-repeating string
-        all_data = ','.join(str(i) for i in range(100))
+        all_data = b','.join(str(i).encode('ascii') for i in range(100))
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = ''
+        data = b''
 
         for i in range(1, 100):
             data += self.proto.recv(i)
@@ -207,37 +204,38 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEqual(all_data, data)
 
 
-@skipIfPY3
 class CapabilitiesTestCase(TestCase):
 
     def test_plain(self):
-        self.assertEqual(('bla', []), extract_capabilities('bla'))
+        self.assertEqual((b'bla', []), extract_capabilities(b'bla'))
 
     def test_caps(self):
-        self.assertEqual(('bla', ['la']), extract_capabilities('bla\0la'))
-        self.assertEqual(('bla', ['la']), extract_capabilities('bla\0la\n'))
-        self.assertEqual(('bla', ['la', 'la']), extract_capabilities('bla\0la la'))
+        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la'))
+        self.assertEqual((b'bla', [b'la']), extract_capabilities(b'bla\0la\n'))
+        self.assertEqual((b'bla', [b'la', b'la']), extract_capabilities(b'bla\0la la'))
 
     def test_plain_want_line(self):
-        self.assertEqual(('want bla', []), extract_want_line_capabilities('want bla'))
+        self.assertEqual((b'want bla', []), extract_want_line_capabilities(b'want bla'))
 
     def test_caps_want_line(self):
-        self.assertEqual(('want bla', ['la']), extract_want_line_capabilities('want bla la'))
-        self.assertEqual(('want bla', ['la']), extract_want_line_capabilities('want bla la\n'))
-        self.assertEqual(('want bla', ['la', 'la']), extract_want_line_capabilities('want bla la la'))
+        self.assertEqual((b'want bla', [b'la']),
+                extract_want_line_capabilities(b'want bla la'))
+        self.assertEqual((b'want bla', [b'la']),
+                extract_want_line_capabilities(b'want bla la\n'))
+        self.assertEqual((b'want bla', [b'la', b'la']),
+                extract_want_line_capabilities(b'want bla la la'))
 
     def test_ack_type(self):
-        self.assertEqual(SINGLE_ACK, ack_type(['foo', 'bar']))
-        self.assertEqual(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
+        self.assertEqual(SINGLE_ACK, ack_type([b'foo', b'bar']))
+        self.assertEqual(MULTI_ACK, ack_type([b'foo', b'bar', b'multi_ack']))
         self.assertEqual(MULTI_ACK_DETAILED,
-                          ack_type(['foo', 'bar', 'multi_ack_detailed']))
+                          ack_type([b'foo', b'bar', b'multi_ack_detailed']))
         # choose detailed when both present
         self.assertEqual(MULTI_ACK_DETAILED,
-                          ack_type(['foo', 'bar', 'multi_ack',
-                                    'multi_ack_detailed']))
+                          ack_type([b'foo', b'bar', b'multi_ack',
+                                    b'multi_ack_detailed']))
 
 
-@skipIfPY3
 class BufferedPktLineWriterTests(TestCase):
 
     def setUp(self):
@@ -253,68 +251,67 @@ class BufferedPktLineWriterTests(TestCase):
         self._output.truncate()
 
     def test_write(self):
-        self._writer.write('foo')
-        self.assertOutputEquals('')
+        self._writer.write(b'foo')
+        self.assertOutputEquals(b'')
         self._writer.flush()
-        self.assertOutputEquals('0007foo')
+        self.assertOutputEquals(b'0007foo')
 
     def test_write_none(self):
         self._writer.write(None)
-        self.assertOutputEquals('')
+        self.assertOutputEquals(b'')
         self._writer.flush()
-        self.assertOutputEquals('0000')
+        self.assertOutputEquals(b'0000')
 
     def test_flush_empty(self):
         self._writer.flush()
-        self.assertOutputEquals('')
+        self.assertOutputEquals(b'')
 
     def test_write_multiple(self):
-        self._writer.write('foo')
-        self._writer.write('bar')
-        self.assertOutputEquals('')
+        self._writer.write(b'foo')
+        self._writer.write(b'bar')
+        self.assertOutputEquals(b'')
         self._writer.flush()
-        self.assertOutputEquals('0007foo0007bar')
+        self.assertOutputEquals(b'0007foo0007bar')
 
     def test_write_across_boundary(self):
-        self._writer.write('foo')
-        self._writer.write('barbaz')
-        self.assertOutputEquals('0007foo000abarba')
+        self._writer.write(b'foo')
+        self._writer.write(b'barbaz')
+        self.assertOutputEquals(b'0007foo000abarba')
         self._truncate()
         self._writer.flush()
-        self.assertOutputEquals('z')
+        self.assertOutputEquals(b'z')
 
     def test_write_to_boundary(self):
-        self._writer.write('foo')
-        self._writer.write('barba')
-        self.assertOutputEquals('0007foo0009barba')
+        self._writer.write(b'foo')
+        self._writer.write(b'barba')
+        self.assertOutputEquals(b'0007foo0009barba')
         self._truncate()
-        self._writer.write('z')
+        self._writer.write(b'z')
         self._writer.flush()
-        self.assertOutputEquals('0005z')
+        self.assertOutputEquals(b'0005z')
 
 
-@skipIfPY3
 class PktLineParserTests(TestCase):
 
     def test_none(self):
         pktlines = []
         parser = PktLineParser(pktlines.append)
-        parser.parse("0000")
+        parser.parse(b"0000")
         self.assertEqual(pktlines, [None])
-        self.assertEqual("", parser.get_tail())
+        self.assertEqual(b"", parser.get_tail())
 
     def test_small_fragments(self):
         pktlines = []
         parser = PktLineParser(pktlines.append)
-        parser.parse("00")
-        parser.parse("05")
-        parser.parse("z0000")
-        self.assertEqual(pktlines, ["z", None])
-        self.assertEqual("", parser.get_tail())
+        parser.parse(b"00")
+        parser.parse(b"05")
+        parser.parse(b"z0000")
+        self.assertEqual(pktlines, [b"z", None])
+        self.assertEqual(b"", parser.get_tail())
 
     def test_multiple_packets(self):
         pktlines = []
         parser = PktLineParser(pktlines.append)
-        parser.parse("0005z0006aba")
-        self.assertEqual(pktlines, ["z", "ab"])
-        self.assertEqual("a", parser.get_tail())
+        parser.parse(b"0005z0006aba")
+        self.assertEqual(pktlines, [b"z", b"ab"])
+        self.assertEqual(b"a", parser.get_tail())