Browse Source

Add convenience function for checking validity of hex sha, rather than catching things like AssertionError.

Jelmer Vernooij 10 years ago
parent
commit
417e4be13c
5 changed files with 39 additions and 33 deletions
  1. 13 4
      dulwich/objects.py
  2. 5 4
      dulwich/protocol.py
  3. 5 8
      dulwich/refs.py
  4. 10 12
      dulwich/server.py
  5. 6 5
      dulwich/tests/test_protocol.py

+ 13 - 4
dulwich/objects.py

@@ -92,7 +92,7 @@ def sha_to_hex(sha):
 
 def hex_to_sha(hex):
     """Takes a hex sha and returns a binary sha"""
-    assert len(hex) == 40, "Incorrent length of hexsha: %s" % hex
+    assert len(hex) == 40, "Incorrect length of hexsha: %s" % hex
     try:
         return binascii.unhexlify(hex)
     except TypeError as exc:
@@ -101,6 +101,17 @@ def hex_to_sha(hex):
         raise ValueError(exc.args[0])
 
 
+def valid_hexsha(hex):
+    if len(hex) != 40:
+        return False
+    try:
+        binascii.unhexlify(hex)
+    except (TypeError, binascii.Error):
+        return False
+    else:
+        return True
+
+
 def hex_to_filename(path, hex):
     """Takes a hex sha and returns its filename relative to the given path."""
     # os.path.join accepts bytes or unicode, but all args must be of the same
@@ -162,9 +173,7 @@ def check_hexsha(hex, error_msg):
     :param error_msg: Error message to use in exception
     :raise ObjectFormatException: Raised when the string is not valid
     """
-    try:
-        hex_to_sha(hex)
-    except (TypeError, AssertionError, ValueError):
+    if not valid_hexsha(hex):
         raise ObjectFormatException("%s %s" % (error_msg, hex))
 
 

+ 5 - 4
dulwich/protocol.py

@@ -120,12 +120,13 @@ class Protocol(object):
             if self.report_activity:
                 self.report_activity(size, 'read')
             pkt_contents = read(size-4)
-            if len(pkt_contents) + 4 != size:
-                raise AssertionError('Length of pkt read %04x does not match length prefix %04x.'
-                                     % (len(pkt_contents) + 4, size))
-            return pkt_contents
         except socket.error as e:
             raise GitProtocolError(e)
+        else:
+            if len(pkt_contents) + 4 != size:
+                raise GitProtocolError(
+                    'Length of pkt read %04x does not match length prefix %04x' % (len(pkt_contents) + 4, size))
+            return pkt_contents
 
     def eof(self):
         """Test whether the protocol stream has reached EOF.

+ 5 - 8
dulwich/refs.py

@@ -31,6 +31,7 @@ from dulwich.errors import (
 from dulwich.objects import (
     hex_to_sha,
     git_line,
+    valid_hexsha,
     )
 from dulwich.file import (
     GitFile,
@@ -659,10 +660,8 @@ def _split_ref_line(line):
     if len(fields) != 2:
         raise PackedRefsException("invalid ref line %r" % line)
     sha, name = fields
-    try:
-        hex_to_sha(sha)
-    except (AssertionError, TypeError) as e:
-        raise PackedRefsException(e)
+    if not valid_hexsha(sha):
+        raise PackedRefsException("Invalid hex sha %r" % sha)
     if not check_ref_format(name):
         raise PackedRefsException("invalid ref name %r" % name)
     return (sha, name)
@@ -700,10 +699,8 @@ def read_packed_refs_with_peeled(f):
         if l.startswith(b'^'):
             if not last:
                 raise PackedRefsException("unexpected peeled ref line")
-            try:
-                hex_to_sha(l[1:])
-            except (AssertionError, TypeError) as e:
-                raise PackedRefsException(e)
+            if not valid_hexsha(l[1:]):
+                raise PackedRefsException("Invalid hex sha %r" % l[1:])
             sha, name = _split_ref_line(last)
             last = None
             yield (sha, name, l[1:])

+ 10 - 12
dulwich/server.py

@@ -60,8 +60,8 @@ from dulwich.errors import (
     )
 from dulwich import log_utils
 from dulwich.objects import (
-    hex_to_sha,
     Commit,
+    valid_hexsha,
     )
 from dulwich.pack import (
     write_pack_objects,
@@ -322,17 +322,15 @@ def _split_proto_line(line, allowed):
     command = fields[0]
     if allowed is not None and command not in allowed:
         raise UnexpectedCommandError(command)
-    try:
-        if len(fields) == 1 and command in (b'done', None):
-            return (command, None)
-        elif len(fields) == 2:
-            if command in (b'want', b'have', b'shallow', b'unshallow'):
-                hex_to_sha(fields[1])
-                return tuple(fields)
-            elif command == b'deepen':
-                return command, int(fields[1])
-    except (TypeError, AssertionError) as e:
-        raise GitProtocolError(e)
+    if len(fields) == 1 and command in (b'done', None):
+        return (command, None)
+    elif len(fields) == 2:
+        if command in (b'want', b'have', b'shallow', b'unshallow'):
+            if not valid_hexsha(fields[1]):
+                raise GitProtocolError("Invalid sha")
+            return tuple(fields)
+        elif command == b'deepen':
+            return command, int(fields[1])
     raise GitProtocolError('Received invalid line from client: %r' % line)
 
 

+ 6 - 5
dulwich/tests/test_protocol.py

@@ -25,6 +25,7 @@ from dulwich.errors import (
     HangupException,
     )
 from dulwich.protocol import (
+    GitProtocolError,
     PktLineParser,
     Protocol,
     ReceivableProtocol,
@@ -85,7 +86,7 @@ class BaseProtocolTests(object):
     def test_read_pkt_line_wrong_size(self):
         self.rin.write(b'0100too short')
         self.rin.seek(0)
-        self.assertRaises(AssertionError, self.proto.read_pkt_line)
+        self.assertRaises(GitProtocolError, self.proto.read_pkt_line)
 
     def test_write_sideband(self):
         self.proto.write_sideband(3, b'bloe')
@@ -126,7 +127,7 @@ class ReceivableBytesIO(BytesIO):
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
         if self.tell() == len(self.getvalue()) and not self.allow_read_past_eof:
-            raise AssertionError('Blocking read past end of socket')
+            raise GitProtocolError('Blocking read past end of socket')
         if size == 1:
             return self.read(1)
         # calls shouldn't return quite as much as asked for
@@ -159,7 +160,7 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         for _ in range(10):
             data += self.proto.recv(10)
         # any more reads would block
-        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertRaises(GitProtocolError, self.proto.recv, 10)
         self.assertEqual(all_data, data)
 
     def test_recv_read(self):
@@ -168,7 +169,7 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.rin.seek(0)
         self.assertEqual(b'1234', self.proto.recv(4))
         self.assertEqual(b'567', self.proto.read(3))
-        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
     def test_read_recv(self):
         all_data = b'12345678abcdefg'
@@ -177,7 +178,7 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEqual(b'1234', self.proto.read(4))
         self.assertEqual(b'5678abc', self.proto.recv(8))
         self.assertEqual(b'defg', self.proto.read(4))
-        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertRaises(GitProtocolError, self.proto.recv, 10)
 
     def test_mixed(self):
         # arbitrary non-repeating string