Browse Source

Add more typing

Jelmer Vernooij 2 years ago
parent
commit
a451cbdd9f
1 changed file with 35 additions and 26 deletions
  1. 35 26
      dulwich/pack.py

+ 35 - 26
dulwich/pack.py

@@ -96,6 +96,11 @@ DELTA_TYPES = (OFS_DELTA, REF_DELTA)
 DEFAULT_PACK_DELTA_WINDOW_SIZE = 10
 DEFAULT_PACK_DELTA_WINDOW_SIZE = 10
 
 
 
 
+OldUnpackedObject = Union[Tuple[Union[bytes, int], List[bytes]], List[bytes]]
+ResolveExtRefFn = Callable[[bytes], Tuple[int, OldUnpackedObject]]
+ProgressFn = Callable[[int, str], None]
+
+
 class ObjectContainer(Protocol):
 class ObjectContainer(Protocol):
 
 
     def add_object(self, obj: ShaFile) -> None:
     def add_object(self, obj: ShaFile) -> None:
@@ -119,7 +124,7 @@ class ObjectContainer(Protocol):
 
 
 class PackedObjectContainer(ObjectContainer):
 class PackedObjectContainer(ObjectContainer):
 
 
-    def get_raw_unresolved(self, sha1: bytes) -> Tuple[int, Union[bytes, None], List[bytes]]:
+    def get_raw_unresolved(self, sha1: bytes) -> Tuple[int, Optional[bytes], List[bytes]]:
         """Get a raw unresolved object."""
         """Get a raw unresolved object."""
         raise NotImplementedError(self.get_raw_unresolved)
         raise NotImplementedError(self.get_raw_unresolved)
 
 
@@ -204,9 +209,10 @@ class UnpackedObject:
 
 
     # Only provided for backwards compatibility with code that expects either
     # Only provided for backwards compatibility with code that expects either
     # chunks or a delta tuple.
     # chunks or a delta tuple.
-    def _obj(self):
+    def _obj(self) -> OldUnpackedObject:
         """Return the decompressed chunks, or (delta base, delta chunks)."""
         """Return the decompressed chunks, or (delta base, delta chunks)."""
         if self.pack_type_num in DELTA_TYPES:
         if self.pack_type_num in DELTA_TYPES:
+            assert isinstance(self.delta_base, (bytes, int))
             return (self.delta_base, self.decomp_chunks)
             return (self.delta_base, self.decomp_chunks)
         else:
         else:
             return self.decomp_chunks
             return self.decomp_chunks
@@ -384,6 +390,9 @@ def bisect_find_sha(start, end, sha, unpack_name):
     return None
     return None
 
 
 
 
+PackIndexEntry = Tuple[bytes, int, Optional[int]]
+
+
 class PackIndex:
 class PackIndex:
     """An index in to a packfile.
     """An index in to a packfile.
 
 
@@ -405,15 +414,15 @@ class PackIndex:
     def __ne__(self, other):
     def __ne__(self, other):
         return not self.__eq__(other)
         return not self.__eq__(other)
 
 
-    def __len__(self):
+    def __len__(self) -> int:
         """Return the number of entries in this pack index."""
         """Return the number of entries in this pack index."""
         raise NotImplementedError(self.__len__)
         raise NotImplementedError(self.__len__)
 
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[bytes]:
         """Iterate over the SHAs in this pack."""
         """Iterate over the SHAs in this pack."""
         return map(sha_to_hex, self._itersha())
         return map(sha_to_hex, self._itersha())
 
 
-    def iterentries(self):
+    def iterentries(self) -> Iterator[PackIndexEntry]:
         """Iterate over the entries in this pack index.
         """Iterate over the entries in this pack index.
 
 
         Returns: iterator over tuples with object name, offset in packfile and
         Returns: iterator over tuples with object name, offset in packfile and
@@ -421,14 +430,14 @@ class PackIndex:
         """
         """
         raise NotImplementedError(self.iterentries)
         raise NotImplementedError(self.iterentries)
 
 
-    def get_pack_checksum(self):
+    def get_pack_checksum(self) -> bytes:
         """Return the SHA1 checksum stored for the corresponding packfile.
         """Return the SHA1 checksum stored for the corresponding packfile.
 
 
         Returns: 20-byte binary digest
         Returns: 20-byte binary digest
         """
         """
         raise NotImplementedError(self.get_pack_checksum)
         raise NotImplementedError(self.get_pack_checksum)
 
 
-    def object_index(self, sha):
+    def object_index(self, sha: bytes) -> int:
         """Return the index in to the corresponding packfile for the object.
         """Return the index in to the corresponding packfile for the object.
 
 
         Given the name of an object it will return the offset that object
         Given the name of an object it will return the offset that object
@@ -437,7 +446,7 @@ class PackIndex:
         """
         """
         raise NotImplementedError(self.object_index)
         raise NotImplementedError(self.object_index)
 
 
-    def object_sha1(self, index):
+    def object_sha1(self, index: int) -> bytes:
         """Return the SHA1 corresponding to the index in the pack file."""
         """Return the SHA1 corresponding to the index in the pack file."""
         # PERFORMANCE/TODO(jelmer): Avoid scanning entire index
         # PERFORMANCE/TODO(jelmer): Avoid scanning entire index
         for (name, offset, crc32) in self.iterentries():
         for (name, offset, crc32) in self.iterentries():
@@ -446,7 +455,7 @@ class PackIndex:
         else:
         else:
             raise KeyError(index)
             raise KeyError(index)
 
 
-    def _object_index(self, sha):
+    def _object_index(self, sha: bytes) -> int:
         """See object_index.
         """See object_index.
 
 
         Args:
         Args:
@@ -454,14 +463,14 @@ class PackIndex:
         """
         """
         raise NotImplementedError(self._object_index)
         raise NotImplementedError(self._object_index)
 
 
-    def objects_sha1(self):
+    def objects_sha1(self) -> bytes:
         """Return the hex SHA1 over all the shas of all objects in this pack.
         """Return the hex SHA1 over all the shas of all objects in this pack.
 
 
         Note: This is used for the filename of the pack.
         Note: This is used for the filename of the pack.
         """
         """
         return iter_sha1(self._itersha())
         return iter_sha1(self._itersha())
 
 
-    def _itersha(self):
+    def _itersha(self) -> Iterator[bytes]:
         """Yield all the SHA1's of the objects in the index, sorted."""
         """Yield all the SHA1's of the objects in the index, sorted."""
         raise NotImplementedError(self._itersha)
         raise NotImplementedError(self._itersha)
 
 
@@ -566,7 +575,7 @@ class FilePackIndex(PackIndex):
         """Return the number of entries in this pack index."""
         """Return the number of entries in this pack index."""
         return self._fan_out_table[-1]
         return self._fan_out_table[-1]
 
 
-    def _unpack_entry(self, i: int) -> Tuple[bytes, int, Optional[int]]:
+    def _unpack_entry(self, i: int) -> PackIndexEntry:
         """Unpack the i-th entry in the index file.
         """Unpack the i-th entry in the index file.
 
 
         Returns: Tuple with object name (SHA), offset in pack file and CRC32
         Returns: Tuple with object name (SHA), offset in pack file and CRC32
@@ -590,7 +599,7 @@ class FilePackIndex(PackIndex):
         for i in range(len(self)):
         for i in range(len(self)):
             yield self._unpack_name(i)
             yield self._unpack_name(i)
 
 
-    def iterentries(self) -> Iterator[Tuple[bytes, int, Optional[int]]]:
+    def iterentries(self) -> Iterator[PackIndexEntry]:
         """Iterate over the entries in this pack index.
         """Iterate over the entries in this pack index.
 
 
         Returns: iterator over tuples with object name, offset in packfile and
         Returns: iterator over tuples with object name, offset in packfile and
@@ -653,7 +662,7 @@ class FilePackIndex(PackIndex):
                 raise PackFileDisappeared(self) from exc
                 raise PackFileDisappeared(self) from exc
             raise
             raise
 
 
-    def _object_index(self, sha):
+    def _object_index(self, sha: bytes) -> int:
         """See object_index.
         """See object_index.
 
 
         Args:
         Args:
@@ -1181,7 +1190,7 @@ class PackData:
         """
         """
         return compute_file_sha(self._file, end_ofs=-20).digest()
         return compute_file_sha(self._file, end_ofs=-20).digest()
 
 
-    def iterobjects(self, progress=None, compute_crc32=True):
+    def iterobjects(self, progress: Optional[ProgressFn] = None, compute_crc32: bool = True):
         self._file.seek(self._header_size)
         self._file.seek(self._header_size)
         for i in range(1, self._num_objects + 1):
         for i in range(1, self._num_objects + 1):
             offset = self._file.tell()
             offset = self._file.tell()
@@ -1215,7 +1224,7 @@ class PackData:
             # Back up over unused data.
             # Back up over unused data.
             self._file.seek(-len(unused), SEEK_CUR)
             self._file.seek(-len(unused), SEEK_CUR)
 
 
-    def iterentries(self, progress=None, resolve_ext_ref=None):
+    def iterentries(self, progress: Optional[ProgressFn] = None, resolve_ext_ref: Optional[ResolveExtRefFn] = None):
         """Yield entries summarizing the contents of this pack.
         """Yield entries summarizing the contents of this pack.
 
 
         Args:
         Args:
@@ -1230,7 +1239,7 @@ class PackData:
                 progress(i, num_objects)
                 progress(i, num_objects)
             yield result
             yield result
 
 
-    def sorted_entries(self, progress=None, resolve_ext_ref=None):
+    def sorted_entries(self, progress: Optional[ProgressFn] = None, resolve_ext_ref: Optional[ResolveExtRefFn] = None):
         """Return entries in this pack, sorted by SHA.
         """Return entries in this pack, sorted by SHA.
 
 
         Args:
         Args:
@@ -1332,7 +1341,7 @@ class PackData:
             unpacked.decomp_chunks,
             unpacked.decomp_chunks,
         )
         )
 
 
-    def get_object_at(self, offset: int):
+    def get_object_at(self, offset: int) -> Tuple[int, OldUnpackedObject]:
         """Given an offset in to the packfile return the object that is there.
         """Given an offset in to the packfile return the object that is there.
 
 
         Using the associated index the location of an object can be looked up,
         Using the associated index the location of an object can be looked up,
@@ -2080,7 +2089,8 @@ def apply_delta(src_buf, delta):
     return out
     return out
 
 
 
 
-def write_pack_index_v2(f, entries, pack_checksum):
+def write_pack_index_v2(
+        f, entries: Iterable[PackIndexEntry], pack_checksum: bytes) -> bytes:
     """Write a new pack index file.
     """Write a new pack index file.
 
 
     Args:
     Args:
@@ -2093,7 +2103,7 @@ def write_pack_index_v2(f, entries, pack_checksum):
     f = SHA1Writer(f)
     f = SHA1Writer(f)
     f.write(b"\377tOc")  # Magic!
     f.write(b"\377tOc")  # Magic!
     f.write(struct.pack(">L", 2))
     f.write(struct.pack(">L", 2))
-    fan_out_table = defaultdict(lambda: 0)
+    fan_out_table: Dict[int, int] = defaultdict(lambda: 0)
     for (name, offset, entry_checksum) in entries:
     for (name, offset, entry_checksum) in entries:
         fan_out_table[ord(name[:1])] += 1
         fan_out_table[ord(name[:1])] += 1
     # Fan-out table
     # Fan-out table
@@ -2144,8 +2154,7 @@ class Pack:
     _data: Optional[PackData]
     _data: Optional[PackData]
     _idx: Optional[PackIndex]
     _idx: Optional[PackIndex]
 
 
-    def __init__(self, basename, resolve_ext_ref: Optional[
-            Callable[[bytes], Tuple[int, UnpackedObject]]] = None):
+    def __init__(self, basename, resolve_ext_ref: Optional[ResolveExtRefFn] = None):
         self._basename = basename
         self._basename = basename
         self._data = None
         self._data = None
         self._idx = None
         self._idx = None
@@ -2300,7 +2309,7 @@ class Pack:
 
 
         return _PackTupleIterable(self.iterobjects, len(self))
         return _PackTupleIterable(self.iterobjects, len(self))
 
 
-    def keep(self, msg=None):
+    def keep(self, msg: Optional[bytes] = None) -> str:
         """Add a .keep file for the pack, preventing git from garbage collecting it.
         """Add a .keep file for the pack, preventing git from garbage collecting it.
 
 
         Args:
         Args:
@@ -2315,7 +2324,7 @@ class Pack:
                 keepfile.write(b"\n")
                 keepfile.write(b"\n")
         return keepfile_name
         return keepfile_name
 
 
-    def get_ref(self, sha: bytes) -> Tuple[int, int, UnpackedObject]:
+    def get_ref(self, sha: bytes) -> Tuple[Optional[int], int, OldUnpackedObject]:
         """Get the object for a ref SHA, only looking in this pack."""
         """Get the object for a ref SHA, only looking in this pack."""
         # TODO: cache these results
         # TODO: cache these results
         try:
         try:
@@ -2372,7 +2381,7 @@ class Pack:
                 self.data._offset_cache[prev_offset] = base_type, chunks
                 self.data._offset_cache[prev_offset] = base_type, chunks
         return base_type, chunks
         return base_type, chunks
 
 
-    def entries(self, progress=None):
+    def entries(self, progress: Optional[ProgressFn] = None) -> Iterator[PackIndexEntry]:
         """Yield entries summarizing the contents of this pack.
         """Yield entries summarizing the contents of this pack.
 
 
         Args:
         Args:
@@ -2383,7 +2392,7 @@ class Pack:
         return self.data.iterentries(
         return self.data.iterentries(
             progress=progress, resolve_ext_ref=self.resolve_ext_ref)
             progress=progress, resolve_ext_ref=self.resolve_ext_ref)
 
 
-    def sorted_entries(self, progress=None):
+    def sorted_entries(self, progress: Optional[ProgressFn] = None) -> Iterator[PackIndexEntry]:
         """Return entries in this pack, sorted by SHA.
         """Return entries in this pack, sorted by SHA.
 
 
         Args:
         Args: