فهرست منبع

Add convenience function for checking validity of hex sha, rather than catching things like AssertionError.

Jelmer Vernooij 10 سال پیش
والد
کامیت
417e4be13c
5فایلهای تغییر یافته به همراه39 افزوده شده و 33 حذف شده
  1. 13 4
      dulwich/objects.py
  2. 5 4
      dulwich/protocol.py
  3. 5 8
      dulwich/refs.py
  4. 10 12
      dulwich/server.py
  5. 6 5
      dulwich/tests/test_protocol.py

+ 13 - 4
dulwich/objects.py

@@ -92,7 +92,7 @@ def sha_to_hex(sha):
 
 def hex_to_sha(hex):
     """Takes a hex sha and returns a binary sha"""
-    assert len(hex) == 40, "Incorrent length of hexsha: %s" % hex
+    assert len(hex) == 40, "Incorrect length of hexsha: %s" % hex
     try:
         return binascii.unhexlify(hex)
     except TypeError as exc:
@@ -101,6 +101,17 @@ def hex_to_sha(hex):
         raise ValueError(exc.args[0])
 
 
+def valid_hexsha(hex):
+    if len(hex) != 40:
+        return False
+    try:
+        binascii.unhexlify(hex)
+    except (TypeError, binascii.Error):
+        return False
+    else:
+        return True
+
+
 def hex_to_filename(path, hex):
     """Takes a hex sha and returns its filename relative to the given path."""
     # os.path.join accepts bytes or unicode, but all args must be of the same
@@ -162,9 +173,7 @@ def check_hexsha(hex, error_msg):
     :param error_msg: Error message to use in exception
     :raise ObjectFormatException: Raised when the string is not valid
     """
-    try:
-        hex_to_sha(hex)
-    except (TypeError, AssertionError, ValueError):
+    if not valid_hexsha(hex):
         raise ObjectFormatException("%s %s" % (error_msg, hex))
 
 

+ 5 - 4
dulwich/protocol.py

@@ -120,12 +120,13 @@ class Protocol(object):
             if self.report_activity:
                 self.report_activity(size, 'read')
             pkt_contents = read(size-4)
-            if len(pkt_contents) + 4 != size:
-                raise AssertionError('Length of pkt read %04x does not match length prefix %04x.'
-                                     % (len(pkt_contents) + 4, size))
-            return pkt_contents
         except socket.error as e:
             raise GitProtocolError(e)
+        else:
+            if len(pkt_contents) + 4 != size:
+                raise GitProtocolError(
+                    'Length of pkt read %04x does not match length prefix %04x' % (len(pkt_contents) + 4, size))
+            return pkt_contents
 
     def eof(self):
         """Test whether the protocol stream has reached EOF.

+ 5 - 8
dulwich/refs.py

@@ -31,6 +31,7 @@ from dulwich.errors import (
 from dulwich.objects import (
     hex_to_sha,
     git_line,
+    valid_hexsha,
     )
 from dulwich.file import (
     GitFile,
@@ -659,10 +660,8 @@ def _split_ref_line(line):
     if len(fields) != 2:
         raise PackedRefsException("invalid ref line %r" % line)
     sha, name = fields
-    try:
-        hex_to_sha(sha)
-    except (AssertionError, TypeError) as e:
-        raise PackedRefsException(e)
+    if not valid_hexsha(sha):
+        raise PackedRefsException("Invalid hex sha %r" % sha)
     if not check_ref_format(name):
         raise PackedRefsException("invalid ref name %r" % name)
     return (sha, name)
@@ -700,10 +699,8 @@ def read_packed_refs_with_peeled(f):
         if l.startswith(b'^'):
             if not last:
                 raise PackedRefsException("unexpected peeled ref line")
-            try:
-                hex_to_sha(l[1:])
-            except (AssertionError, TypeError) as e:
-                raise PackedRefsException(e)
+            if not valid_hexsha(l[1:]):
+                raise PackedRefsException("Invalid hex sha %r" % l[1:])
             sha, name = _split_ref_line(last)
             last = None
             yield (sha, name, l[1:])

+ 10 - 12
dulwich/server.py

@@ -60,8 +60,8 @@ from dulwich.errors import (
     )
 from dulwich import log_utils
 from dulwich.objects import (
-    hex_to_sha,
     Commit,
+    valid_hexsha,
     )
 from dulwich.pack import (
     write_pack_objects,
@@ -322,17 +322,15 @@ def _split_proto_line(line, allowed):
     command = fields[0]
     if allowed is not None and command not in allowed:
         raise UnexpectedCommandError(command)
-    try:
-        if len(fields) == 1 and command in (b'done', None):
-            return (command, None)
-        elif len(fields) == 2:
-            if command in (b'want', b'have', b'shallow', b'unshallow'):
-                hex_to_sha(fields[1])
-                return tuple(fields)
-            elif command == b'deepen':
-                return command, int(fields[1])
-    except (TypeError, AssertionError) as e:
-        raise GitProtocolError(e)
+    if len(fields) == 1 and command in (b'done', None):
+        return (command, None)
+    elif len(fields) == 2:
+        if command in (b'want', b'have', b'shallow', b'unshallow'):
+            if not valid_hexsha(fields[1]):
+                raise GitProtocolError("Invalid sha")
+            return tuple(fields)
+        elif command == b'deepen':
+            return command, int(fields[1])
     raise GitProtocolError('Received invalid line from client: %r' % line)
 
 

+ 6 - 5
dulwich/tests/test_protocol.py

@@ -25,6 +25,7 @@ from dulwich.errors import (
     HangupException,
     )
 from dulwich.protocol import (
+    GitProtocolError,
     PktLineParser,
     Protocol,
     ReceivableProtocol,
@@ -85,7 +86,7 @@ class BaseProtocolTests(object):
     def test_read_pkt_line_wrong_size(self):
         self.rin.write(b'0100too short')
         self.rin.seek(0)
-        self.assertRaises(AssertionError, self.proto.read_pkt_line)
+        self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
 
     def test_write_sideband(self):
         self.proto.write_sideband(3, b'bloe')
@@ -126,7 +127,7 @@ class ReceivableBytesIO(BytesIO):
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
         if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
-            raise AssertionError('Blocking read past end of socket')
+            raise GitProtocolError('Blocking read past end of socket')
         if size == 1:
             return self.read(1)
         # calls shouldn't return quite as much as asked for
@@ -159,7 +160,7 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         for _ in range(10):
             data += self.proto.recv(10)
         # any more reads would block
-        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertRaises(GitProtocolError, self.proto.recv, 10)
         self.assertEqual(all_data, data)
 
     def test_recv_read(self):
@@ -168,7 +169,7 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.rin.seek(0)
         self.assertEqual(b'1234', self.proto.recv(4))
         self.assertEqual(b'567', self.proto.read(3))
-        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
     def test_read_recv(self):
         all_data = b'12345678abcdefg'
@@ -177,7 +178,7 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEqual(b'1234', self.proto.read(4))
         self.assertEqual(b'5678abc', self.proto.recv(8))
         self.assertEqual(b'defg', self.proto.read(4))
-        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
     def test_mixed(self):
         # arbitrary non-repeating string