Browse code

Add some more typing, define ObjectID.

Jelmer Vernooij 2 years ago
parent
commit
2ea0609607
3 changed files with 63 additions and 48 deletions
  1. NEWS (+4, -0)
  2. dulwich/objects.py (+50, -42)
  3. dulwich/walk.py (+9, -6)

+ 4 - 0
NEWS

@@ -4,6 +4,10 @@
    when creating symlinks fails due to a permission
    error. (Jelmer Vernooij, #1005)

+ * Add new ``ObjectID`` type in ``dulwich.objects``,
+   currently just an alias for ``bytes``.
+   (Jelmer Vernooij)
+
  * Support repository format version 1.
    (Jelmer Vernooij, #1056)


+ 50 - 42
dulwich/objects.py

@@ -33,6 +33,8 @@ from typing import (
     Iterable,
     Union,
     Type,
+    Iterator,
+    List,
 )
 import zlib
 from hashlib import sha1
@@ -75,6 +77,9 @@ MAX_TIME = 9223372036854775807  # (2**63) - 1 - signed long int max
 BEGIN_PGP_SIGNATURE = b"-----BEGIN PGP SIGNATURE-----"


+ObjectID = bytes
+
+
 class EmptyFileException(FileFormatException):
     """An unexpectedly empty file was encountered."""

@@ -153,7 +158,10 @@ def filename_to_hex(filename):

 def object_header(num_type: int, length: int) -> bytes:
     """Return an object header for the given numeric type and text length."""
-    return object_class(num_type).type_name + b" " + str(length).encode("ascii") + b"\0"
+    cls = object_class(num_type)
+    if cls is None:
+        raise AssertionError("unsupported class type num: %d" % num_type)
+    return cls.type_name + b" " + str(length).encode("ascii") + b"\0"


 def serializable_property(name: str, docstring: Optional[str] = None):
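The ``cls is None`` check introduced above is what lets a type checker narrow the ``Optional[Type[ShaFile]]`` that ``object_class`` returns after this change (its new signature appears in the next hunk). A minimal sketch of the same pattern, outside dulwich's own code:

from typing import Optional, Type
from dulwich.objects import ShaFile, object_class

cls: Optional[Type[ShaFile]] = object_class(3)  # 3 is the blob type number
if cls is None:
    raise AssertionError("unsupported class type num: 3")
print(cls.type_name)  # after the check, cls is known to be Type[ShaFile]; prints b'blob'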
@@ -169,7 +177,7 @@ def serializable_property(name: str, docstring: Optional[str] = None):
     return property(get, set, doc=docstring)


-def object_class(type):
+def object_class(type: Union[bytes, int]) -> Optional[Type["ShaFile"]]:
     """Get the object class corresponding to the given type.

     Args:
@@ -193,7 +201,7 @@ def check_hexsha(hex, error_msg):
         raise ObjectFormatException("%s %s" % (error_msg, hex))


-def check_identity(identity, error_msg):
+def check_identity(identity: bytes, error_msg: str) -> None:
     """Check if the specified identity is valid.

     This will raise an exception if the identity is not valid.
@@ -261,11 +269,13 @@ class ShaFile(object):

     __slots__ = ("_chunked_text", "_sha", "_needs_serialization")

-    type_name = None  # type: bytes
-    type_num = None  # type: int
+    _needs_serialization: bool
+    type_name: bytes
+    type_num: int
+    _chunked_text: Optional[List[bytes]]

     @staticmethod
-    def _parse_legacy_object_header(magic, f):
+    def _parse_legacy_object_header(magic, f) -> "ShaFile":
         """Parse a legacy object, creating it but not reading the file."""
         bufsize = 1024
         decomp = zlib.decompressobj()
@@ -287,10 +297,10 @@ class ShaFile(object):
                 "Object size not an integer: %s" % exc) from exc
         obj_class = object_class(type_name)
         if not obj_class:
-            raise ObjectFormatException("Not a known type: %s" % type_name)
+            raise ObjectFormatException("Not a known type: %s" % type_name.decode('ascii'))
         return obj_class()

-    def _parse_legacy_object(self, map):
+    def _parse_legacy_object(self, map) -> None:
         """Parse a legacy object, setting the raw string."""
         text = _decompress(map)
         header_end = text.find(b"\0")
@@ -298,7 +308,8 @@ class ShaFile(object):
             raise ObjectFormatException("Invalid object header, no \\0")
         self.set_raw_string(text[header_end + 1 :])

-    def as_legacy_object_chunks(self, compression_level=-1):
+    def as_legacy_object_chunks(
+            self, compression_level: int = -1) -> Iterator[bytes]:
         """Return chunks representing the object in the experimental format.

         Returns: List of strings
@@ -309,13 +320,13 @@ class ShaFile(object):
             yield compobj.compress(chunk)
         yield compobj.flush()

-    def as_legacy_object(self, compression_level=-1):
+    def as_legacy_object(self, compression_level: int = -1) -> bytes:
         """Return string representing the object in the experimental format."""
         return b"".join(
             self.as_legacy_object_chunks(compression_level=compression_level)
         )

-    def as_raw_chunks(self):
+    def as_raw_chunks(self) -> List[bytes]:
         """Return chunks with serialization of the object.

         Returns: List of strings, not necessarily one per line
@@ -324,16 +335,16 @@ class ShaFile(object):
             self._sha = None
             self._chunked_text = self._serialize()
             self._needs_serialization = False
-        return self._chunked_text
+        return self._chunked_text  # type: ignore

-    def as_raw_string(self):
+    def as_raw_string(self) -> bytes:
         """Return raw string with serialization of the object.

         Returns: String object
         """
         return b"".join(self.as_raw_chunks())

-    def __bytes__(self):
+    def __bytes__(self) -> bytes:
         """Return raw string serialization of this object."""
         return self.as_raw_string()

@@ -341,24 +352,27 @@ class ShaFile(object):
         """Return unique hash for this object."""
         return hash(self.id)

-    def as_pretty_string(self):
+    def as_pretty_string(self) -> bytes:
         """Return a string representing this object, fit for display."""
         return self.as_raw_string()

-    def set_raw_string(self, text, sha=None):
+    def set_raw_string(
+            self, text: bytes, sha: Optional[ObjectID] = None) -> None:
         """Set the contents of this object from a serialized string."""
         if not isinstance(text, bytes):
             raise TypeError("Expected bytes for text, got %r" % text)
         self.set_raw_chunks([text], sha)

-    def set_raw_chunks(self, chunks, sha=None):
+    def set_raw_chunks(
+            self, chunks: List[bytes],
+            sha: Optional[ObjectID] = None) -> None:
         """Set the contents of this object from a list of chunks."""
         self._chunked_text = chunks
         self._deserialize(chunks)
         if sha is None:
             self._sha = None
         else:
-            self._sha = FixedSha(sha)
+            self._sha = FixedSha(sha)  # type: ignore
         self._needs_serialization = False

     @staticmethod
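For context on the two setters above, a small usage sketch assuming only what the new annotations state (``text`` is bytes, ``sha`` an optional known hex id):

from dulwich.objects import Blob

blob = Blob()
blob.set_raw_string(b"hello world\n")  # passing str instead of bytes raises TypeError
print(blob.as_raw_chunks())            # [b'hello world\n']
print(blob.id)                         # hex SHA-1 digest as bytes, i.e. an ObjectID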
@@ -370,7 +384,7 @@ class ShaFile(object):
             raise ObjectFormatException("Not a known type %d" % num_type)
         return obj_class()

-    def _parse_object(self, map):
+    def _parse_object(self, map) -> None:
         """Parse a new style object, setting self._text."""
         # skip type and size; type must have already been determined, and
         # we trust zlib to fail if it's otherwise corrupted
@@ -383,7 +397,7 @@ class ShaFile(object):
         self.set_raw_string(_decompress(raw))

     @classmethod
-    def _is_legacy_object(cls, magic):
+    def _is_legacy_object(cls, magic: bytes) -> bool:
         b0 = ord(magic[0:1])
         b1 = ord(magic[1:2])
         word = (b0 << 8) + b1
@@ -445,7 +459,9 @@ class ShaFile(object):
         return obj

     @staticmethod
-    def from_raw_chunks(type_num, chunks, sha=None):
+    def from_raw_chunks(
+            type_num: int, chunks: List[bytes],
+            sha: Optional[ObjectID] = None):
         """Creates an object of the indicated type from the raw chunks given.

         Args:
@@ -453,7 +469,10 @@ class ShaFile(object):
           chunks: An iterable of the raw uncompressed contents.
           sha: Optional known sha for the object
         """
-        obj = object_class(type_num)()
+        cls = object_class(type_num)
+        if cls is None:
+            raise AssertionError("unsupported class type num: %d" % type_num)
+        obj = cls()
         obj.set_raw_chunks(chunks, sha)
         return obj

@@ -477,7 +496,7 @@ class ShaFile(object):
         if getattr(self, member, None) is None:
             raise ObjectFormatException(error_msg)

-    def check(self):
+    def check(self) -> None:
         """Check this object for internal consistency.

         Raises:
@@ -500,9 +519,9 @@ class ShaFile(object):
             raise ChecksumMismatch(new_sha, old_sha)

     def _header(self):
-        return object_header(self.type, self.raw_length())
+        return object_header(self.type_num, self.raw_length())

-    def raw_length(self):
+    def raw_length(self) -> int:
         """Returns the length of the raw string of this object."""
         ret = 0
         for chunk in self.as_raw_chunks():
@@ -522,25 +541,14 @@ class ShaFile(object):

     def copy(self):
         """Create a new copy of this SHA1 object from its raw string"""
-        obj_class = object_class(self.get_type())
-        return obj_class.from_raw_string(self.get_type(), self.as_raw_string(), self.id)
+        obj_class = object_class(self.type_num)
+        return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)

     @property
     def id(self):
         """The hex SHA of this object."""
         return self.sha().hexdigest().encode("ascii")

-    def get_type(self):
-        """Return the type number for this object class."""
-        return self.type_num
-
-    def set_type(self, type):
-        """Set the type number for this object class."""
-        self.type_num = type
-
-    # DEPRECATED: use type_num or type_name as needed.
-    type = property(get_type, set_type)
-
     def __repr__(self):
         return "<%s %s>" % (self.__class__.__name__, self.id)

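The hunk above also removes the long-deprecated ``type`` property, so callers are expected to read ``type_num`` or ``type_name`` directly, e.g.:

from dulwich.objects import Blob

blob = Blob.from_string(b"some data")
print(blob.type_num)   # 3, previously also reachable as the deprecated blob.type
print(blob.type_name)  # b'blob'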
@@ -621,7 +629,7 @@ class Blob(ShaFile):
         """
         super(Blob, self).check()

-    def splitlines(self):
+    def splitlines(self) -> List[bytes]:
         """Return list of lines in this blob.

         This preserves the original line endings.
@@ -649,7 +657,7 @@ class Blob(ShaFile):
         return ret


-def _parse_message(chunks):
+def _parse_message(chunks: Iterable[bytes]):
     """Parse a message with a list of fields and a body.

     Args:
@@ -660,7 +668,7 @@
     """
     f = BytesIO(b"".join(chunks))
     k = None
-    v = ""
+    v = b""
     eof = False

     def _strip_last_newline(value):
@@ -1596,7 +1604,7 @@ OBJECT_CLASSES = (
     Tag,
 )

-_TYPE_MAP = {}  # type: Dict[Union[bytes, int], Type[ShaFile]]
+_TYPE_MAP: Dict[Union[bytes, int], Type[ShaFile]] = {}

 for cls in OBJECT_CLASSES:
     _TYPE_MAP[cls.type_name] = cls
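``_TYPE_MAP`` is keyed by both type names and type numbers, which is why ``object_class`` is now annotated as taking ``Union[bytes, int]``; unknown keys simply yield ``None``. For instance:

from dulwich.objects import Blob, Commit, object_class

assert object_class(b"blob") is Blob   # lookup by type name
assert object_class(1) is Commit       # lookup by numeric type (1 == commit)
assert object_class(b"unknown") is None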

+ 9 - 6
dulwich/walk.py

@@ -24,6 +24,7 @@
 import collections
 import heapq
 from itertools import chain
+from typing import List, Tuple, Set

 from dulwich.diff_tree import (
     RENAME_CHANGE_TYPES,
@@ -35,7 +36,9 @@ from dulwich.errors import (
     MissingCommitError,
 )
 from dulwich.objects import (
+    Commit,
     Tag,
+    ObjectID,
 )

 ORDER_DATE = "date"
@@ -128,15 +131,15 @@ class WalkEntry(object):
 class _CommitTimeQueue(object):
     """Priority queue of WalkEntry objects by commit time."""

-    def __init__(self, walker):
+    def __init__(self, walker: "Walker"):
         self._walker = walker
         self._store = walker.store
         self._get_parents = walker.get_parents
         self._excluded = walker.excluded
-        self._pq = []
-        self._pq_set = set()
-        self._seen = set()
-        self._done = set()
+        self._pq: List[Tuple[int, Commit]] = []
+        self._pq_set: Set[ObjectID] = set()
+        self._seen: Set[ObjectID] = set()
+        self._done: Set[ObjectID] = set()
         self._min_time = walker.since
         self._last = None
         self._extra_commits_left = _MAX_EXTRA_COMMITS
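The new ``_pq`` annotation documents the queue invariant: entries are ``(commit_time, Commit)`` tuples, so ``heapq`` orders them by commit time first. A standalone sketch of that pattern (illustrative only, not the walker's actual code):

import heapq
from typing import List, Tuple

pq: List[Tuple[int, str]] = []  # (commit_time, payload); smallest time pops first
heapq.heappush(pq, (1700000000, "newer"))
heapq.heappush(pq, (1600000000, "older"))
print(heapq.heappop(pq))        # (1600000000, 'older')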
@@ -145,7 +148,7 @@ class _CommitTimeQueue(object):
         for commit_id in chain(walker.include, walker.excluded):
             self._push(commit_id)

-    def _push(self, object_id):
+    def _push(self, object_id: bytes):
         try:
             obj = self._store[object_id]
         except KeyError as exc: