Browse Source

Add more typing

Jelmer Vernooij 1 year ago
parent
commit
e6ccc98835

+ 7 - 5
dulwich/client.py

@@ -190,7 +190,7 @@ class ReportStatusParser:
     def __init__(self) -> None:
         self._done = False
         self._pack_status = None
-        self._ref_statuses = []
+        self._ref_statuses: List[bytes] = []
 
     def check(self):
         """Check if there were any errors and, if so, raise exceptions.
@@ -427,8 +427,8 @@ def _read_shallow_updates(pkt_seq):
 class _v1ReceivePackHeader:
 
     def __init__(self, capabilities, old_refs, new_refs) -> None:
-        self.want = []
-        self.have = []
+        self.want: List[bytes] = []
+        self.have: List[bytes] = []
         self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
         self.sent_capabilities = False
 
@@ -646,7 +646,7 @@ class GitClient:
             to
         """
         self._report_activity = report_activity
-        self._report_status_parser = None
+        self._report_status_parser: Optional[ReportStatusParser] = None
         self._fetch_capabilities = set(UPLOAD_CAPABILITIES)
         self._fetch_capabilities.add(capability_agent())
         self._send_capabilities = set(RECEIVE_CAPABILITIES)
@@ -915,6 +915,7 @@ class GitClient:
                     pass
 
             if CAPABILITY_REPORT_STATUS in capabilities:
+                assert self._report_status_parser is not None
                 pktline_parser = PktLineParser(self._report_status_parser.handle_packet)
             for chan, data in _read_side_band64k_data(proto.read_pkt_seq()):
                 if chan == SIDE_BAND_CHANNEL_DATA:
@@ -927,6 +928,7 @@ class GitClient:
                         "Invalid sideband channel %d" % chan)
         else:
             if CAPABILITY_REPORT_STATUS in capabilities:
+                assert self._report_status_parser
                 for pkt in proto.read_pkt_seq():
                     self._report_status_parser.handle_packet(pkt)
         if self._report_status_parser is not None:
@@ -1729,7 +1731,7 @@ class SSHGitClient(TraditionalGitClient):
             "GIT_SSH_COMMAND", os.environ.get("GIT_SSH")
         )
         super().__init__(**kwargs)
-        self.alternative_paths = {}
+        self.alternative_paths: Dict[bytes, bytes] = {}
         if vendor is not None:
             self.ssh_vendor = vendor
         else:

+ 4 - 2
dulwich/config.py

@@ -40,6 +40,8 @@ from typing import (
     Tuple,
     Union,
     overload,
+    Any,
+    Dict,
 )
 
 from .file import GitFile
@@ -60,8 +62,8 @@ def lower_key(key):
 class CaseInsensitiveOrderedMultiDict(MutableMapping):
 
     def __init__(self) -> None:
-        self._real = []
-        self._keyed = {}
+        self._real: List[Any] = []
+        self._keyed: Dict[Any, Any] = {}
 
     @classmethod
     def make(cls, dict_in=None):

+ 4 - 3
dulwich/fastexport.py

@@ -22,6 +22,7 @@
 """Fast export/import functionality."""
 
 import stat
+from typing import Dict, Tuple
 
 from fastimport import commands, parser, processor
 from fastimport import errors as fastimport_errors
@@ -42,7 +43,7 @@ class GitFastExporter:
     def __init__(self, outf, store) -> None:
         self.outf = outf
         self.store = store
-        self.markers = {}
+        self.markers: Dict[bytes, bytes] = {}
         self._marker_idx = 0
 
     def print_cmd(self, cmd):
@@ -125,8 +126,8 @@ class GitImportProcessor(processor.ImportProcessor):
         processor.ImportProcessor.__init__(self, params, verbose)
         self.repo = repo
         self.last_commit = ZERO_SHA
-        self.markers = {}
-        self._contents = {}
+        self.markers: Dict[bytes, bytes] = {}
+        self._contents: Dict[bytes, Tuple[int, bytes]] = {}
 
     def lookup_object(self, objectish):
         if objectish.startswith(b":"):

+ 5 - 3
dulwich/greenthreads.py

@@ -25,12 +25,14 @@
 import gevent
 from gevent import pool
 
+from typing import Set, Tuple, Optional, FrozenSet
+
 from .object_store import (
     MissingObjectFinder,
     _collect_ancestors,
     _collect_filetree_revs,
 )
-from .objects import Commit, Tag
+from .objects import Commit, Tag, ObjectID
 
 
 def _split_commits_and_tags(obj_store, lst, *, ignore_unknown=False, pool=None):
@@ -89,7 +91,7 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
 
         have_commits, have_tags = _split_commits_and_tags(object_store, haves, ignore_unknown=True, pool=p)
         want_commits, want_tags = _split_commits_and_tags(object_store, wants, ignore_unknown=False, pool=p)
-        all_ancestors = _collect_ancestors(object_store, have_commits)[0]
+        all_ancestors: FrozenSet[ObjectID] = frozenset(_collect_ancestors(object_store, have_commits)[0])
         missing_commits, common_commits = _collect_ancestors(
             object_store, want_commits, all_ancestors
         )
@@ -101,7 +103,7 @@ class GreenThreadsMissingObjectFinder(MissingObjectFinder):
             self.sha_done.add(t)
         missing_tags = want_tags.difference(have_tags)
         wants = missing_commits.union(missing_tags)
-        self.objects_to_send = {(w, None, False) for w in wants}
+        self.objects_to_send: Set[Tuple[ObjectID, Optional[bytes], Optional[int], bool]] = {(w, None, 0, False) for w in wants}
         if progress is None:
             self.progress = lambda x: None
         else:

+ 3 - 1
dulwich/mailmap.py

@@ -20,6 +20,8 @@
 
 """Mailmap file reader."""
 
+from typing import Dict, Tuple, Optional
+
 
 def parse_identity(text):
     # TODO(jelmer): Integrate this with dulwich.fastexport.split_email and
@@ -62,7 +64,7 @@ class Mailmap:
     """Class for accessing a mailmap file."""
 
     def __init__(self, map=None) -> None:
-        self._table = {}
+        self._table: Dict[Tuple[Optional[str], str], Tuple[str, str]] = {}
         if map:
             for (canonical_identity, from_identity) in map:
                 self.add_entry(canonical_identity, from_identity)

+ 13 - 8
dulwich/object_store.py

@@ -31,6 +31,7 @@ from io import BytesIO
 from typing import (
     Callable,
     Dict,
+    FrozenSet,
     Iterable,
     Iterator,
     List,
@@ -360,7 +361,7 @@ class BaseObjectStore:
 
 class PackBasedObjectStore(BaseObjectStore):
     def __init__(self, pack_compression_level=-1) -> None:
-        self._pack_cache = {}
+        self._pack_cache: Dict[str, Pack] = {}
         self.pack_compression_level = pack_compression_level
 
     def add_pack(
@@ -995,7 +996,7 @@ class MemoryObjectStore(BaseObjectStore):
 
     def __init__(self) -> None:
         super().__init__()
-        self._data = {}
+        self._data: Dict[str, ShaFile] = {}
         self.pack_compression_level = -1
 
     def _to_hexsha(self, sha):
@@ -1269,7 +1270,7 @@ class MissingObjectFinder:
 
         # in fact, what we 'want' is commits, tags, and others
         # we've found missing
-        self.objects_to_send = {
+        self.objects_to_send: Set[Tuple[ObjectID, Optional[bytes], Optional[int], bool]] = {
             (w, None, Commit.type_num, False)
             for w in missing_commits}
         missing_tags = want_tags.difference(have_tags)
@@ -1293,7 +1294,7 @@ class MissingObjectFinder:
     def add_todo(self, entries: Iterable[Tuple[ObjectID, Optional[bytes], Optional[int], bool]]):
         self.objects_to_send.update([e for e in entries if e[0] not in self.sha_done])
 
-    def __next__(self) -> Tuple[bytes, PackHint]:
+    def __next__(self) -> Tuple[bytes, Optional[PackHint]]:
         while True:
             if not self.objects_to_send:
                 self.progress(("counting objects: %d, done.\n" % len(self.sha_done)).encode("ascii"))
@@ -1321,7 +1322,11 @@ class MissingObjectFinder:
         self.sha_done.add(sha)
         if len(self.sha_done) % 1000 == 0:
             self.progress(("counting objects: %d\r" % len(self.sha_done)).encode("ascii"))
-        return (sha, (type_num, name))
+        if type_num is None:
+            pack_hint = None
+        else:
+            pack_hint = (type_num, name)
+        return (sha, pack_hint)
 
     def __iter__(self):
         return self
@@ -1344,7 +1349,7 @@ class ObjectStoreGraphWalker:
         """
         self.heads = set(local_heads)
         self.get_parents = get_parents
-        self.parents = {}
+        self.parents: Dict[ObjectID, Optional[List[ObjectID]]] = {}
         if shallow is None:
             shallow = set()
         self.shallow = shallow
@@ -1610,8 +1615,8 @@ class BucketBasedObjectStore(PackBasedObjectStore):
 def _collect_ancestors(
     store: ObjectContainer,
     heads,
-    common=frozenset(),
-    shallow=frozenset(),
+    common: FrozenSet[ObjectID] = frozenset(),
+    shallow: FrozenSet[ObjectID] = frozenset(),
     get_parents=lambda commit: commit.parents,
 ):
     """Collect all ancestors of heads up to (excluding) those in common.

+ 6 - 4
dulwich/objects.py

@@ -1076,7 +1076,7 @@ class Tree(ShaFile):
 
     def __init__(self) -> None:
         super().__init__()
-        self._entries = {}
+        self._entries: Dict[bytes, Tuple[int, bytes]] = {}
 
     @classmethod
     def from_path(cls, filename):
@@ -1381,11 +1381,11 @@ class Commit(ShaFile):
 
     def __init__(self) -> None:
         super().__init__()
-        self._parents = []
+        self._parents: List[bytes] = []
         self._encoding = None
-        self._mergetag = []
+        self._mergetag: List[Tag] = []
         self._gpgsig = None
-        self._extra = []
+        self._extra: List[Tuple[bytes, bytes]] = []
         self._author_timezone_neg_utc = False
         self._commit_timezone_neg_utc = False
 
@@ -1412,6 +1412,7 @@ class Commit(ShaFile):
             if field == _TREE_HEADER:
                 self._tree = value
             elif field == _PARENT_HEADER:
+                assert value is not None
                 self._parents.append(value)
             elif field == _AUTHOR_HEADER:
                 author_info = parse_time_entry(value)
@@ -1420,6 +1421,7 @@ class Commit(ShaFile):
             elif field == _ENCODING_HEADER:
                 self._encoding = value
             elif field == _MERGETAG_HEADER:
+                assert value is not None
                 self._mergetag.append(Tag.from_string(value + b"\n"))
             elif field == _GPGSIG_HEADER:
                 self._gpgsig = value

+ 5 - 4
dulwich/pack.py

@@ -201,6 +201,7 @@ class UnpackedObject:
     obj_chunks: Optional[List[bytes]]
     delta_base: Union[None, bytes, int]
     decomp_chunks: List[bytes]
+    comp_chunks: Optional[List[bytes]]
 
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # methods of this object.
@@ -1167,7 +1168,7 @@ class PackData:
         else:
             self._file = file
         (version, self._num_objects) = read_pack_header(self._file.read)
-        self._offset_cache = LRUSizeCache(
+        self._offset_cache = LRUSizeCache[int, Tuple[int, OldUnpackedObject]](
             1024 * 1024 * 20, compute_size=_compute_object_size
         )
 
@@ -1239,7 +1240,7 @@ class PackData:
             # Back up over unused data.
             self._file.seek(-len(unused), SEEK_CUR)
 
-    def iterentries(self, progress: Optional[ProgressFn] = None, resolve_ext_ref: Optional[ResolveExtRefFn] = None):
+    def iterentries(self, progress=None, resolve_ext_ref: Optional[ResolveExtRefFn] = None):
         """Yield entries summarizing the contents of this pack.
 
         Args:
@@ -1957,7 +1958,7 @@ class PackChunkGenerator:
 
     def __init__(self, num_records=None, records=None, progress=None, compression_level=-1, reuse_compressed=True) -> None:
         self.cs = sha1(b"")
-        self.entries = {}
+        self.entries: Dict[Union[int, bytes], Tuple[int, int]] = {}
         self._it = self._pack_data_chunks(
             num_records=num_records, records=records, progress=progress, compression_level=compression_level, reuse_compressed=reuse_compressed)
 
@@ -2607,7 +2608,7 @@ def extend_pack(f: BinaryIO, object_ids: Set[ObjectID], get_raw, *, compression_
 
 
 try:
-    from dulwich._pack import (
+    from dulwich._pack import (  # type: ignore  # noqa: F811
         apply_delta,  # type: ignore # noqa: F811
         bisect_find_sha,  # type: ignore # noqa: F811
     )

+ 3 - 3
dulwich/refs.py

@@ -23,7 +23,7 @@
 import os
 import warnings
 from contextlib import suppress
-from typing import Dict, Optional
+from typing import Dict, Optional, Set, Any
 
 from .errors import PackedRefsException, RefFormatError
 from .file import GitFile, ensure_dir_exists
@@ -442,8 +442,8 @@ class DictRefsContainer(RefsContainer):
     def __init__(self, refs, logger=None) -> None:
         super().__init__(logger=logger)
         self._refs = refs
-        self._peeled = {}
-        self._watchers = set()
+        self._peeled: Dict[bytes, ObjectID] = {}
+        self._watchers: Set[Any] = set()
 
     def allkeys(self):
         return self._refs.keys()

+ 4 - 3
dulwich/repo.py

@@ -46,6 +46,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    Any
 )
 
 if TYPE_CHECKING:
@@ -1797,10 +1798,10 @@ class MemoryRepo(BaseRepo):
     def __init__(self) -> None:
         from .config import ConfigFile
 
-        self._reflog = []
+        self._reflog: List[Any] = []
         refs_container = DictRefsContainer({}, logger=self._append_reflog)
-        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)
-        self._named_files = {}
+        BaseRepo.__init__(self, MemoryObjectStore(), refs_container)  # type: ignore
+        self._named_files: Dict[str, bytes] = {}
         self.bare = True
         self._config = ConfigFile()
         self._description = None

+ 4 - 4
dulwich/server.py

@@ -227,7 +227,7 @@ class PackHandler(Handler):
 
     def __init__(self, backend, proto, stateless_rpc=False) -> None:
         super().__init__(backend, proto, stateless_rpc)
-        self._client_capabilities = None
+        self._client_capabilities: Optional[Set[bytes]] = None
         # Flags needed for the no-done capability
         self._done_received = False
 
@@ -763,7 +763,7 @@ class SingleAckGraphWalkerImpl:
 
     def __init__(self, walker) -> None:
         self.walker = walker
-        self._common = []
+        self._common: List[bytes] = []
 
     def ack(self, have_ref):
         if not self._common:
@@ -808,7 +808,7 @@ class MultiAckGraphWalkerImpl:
     def __init__(self, walker) -> None:
         self.walker = walker
         self._found_base = False
-        self._common = []
+        self._common: List[bytes] = []
 
     def ack(self, have_ref):
         self._common.append(have_ref)
@@ -866,7 +866,7 @@ class MultiAckDetailedGraphWalkerImpl:
 
     def __init__(self, walker) -> None:
         self.walker = walker
-        self._common = []
+        self._common: List[bytes] = []
 
     def ack(self, have_ref):
         # Should only be called iff have_ref is common

+ 2 - 1
dulwich/tests/test_client.py

@@ -23,6 +23,7 @@ import os
 import shutil
 import sys
 import tempfile
+from typing import Dict
 import warnings
 from io import BytesIO
 from unittest.mock import patch
@@ -1090,7 +1091,7 @@ class HttpGitClientTests(TestCase):
         # otherwise without an active internet connection
         class PoolManagerMock:
             def __init__(self) -> None:
-                self.headers = {}
+                self.headers: Dict[str, str] = {}
 
             def request(self, method, url, fields=None, headers=None, redirect=True, preload_content=True):
                 base_url = url[: -len(tail)]

+ 2 - 1
dulwich/tests/test_pack.py

@@ -29,6 +29,7 @@ import tempfile
 import zlib
 from hashlib import sha1
 from io import BytesIO
+from typing import Set
 
 from dulwich.tests import TestCase
 
@@ -943,7 +944,7 @@ class TestPackIterator(DeltaChainIterator):
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
-        self._unpacked_offsets = set()
+        self._unpacked_offsets: Set[int] = set()
 
     def _result(self, unpacked):
         """Return entries in the same format as build_pack."""

+ 5 - 4
dulwich/tests/test_server.py

@@ -25,6 +25,7 @@ import shutil
 import sys
 import tempfile
 from io import BytesIO
+from typing import Dict, List
 
 from dulwich.tests import TestCase
 
@@ -66,8 +67,8 @@ SIX = b"6" * 40
 
 class TestProto:
     def __init__(self) -> None:
-        self._output = []
-        self._received = {0: [], 1: [], 2: [], 3: []}
+        self._output: List[bytes] = []
+        self._received: Dict[int, List[bytes]] = {0: [], 1: [], 2: [], 3: []}
 
     def set_output(self, output_lines):
         self._output = output_lines
@@ -588,8 +589,8 @@ class ProtocolGraphWalkerTestCase(TestCase):
 
 class TestProtocolGraphWalker:
     def __init__(self) -> None:
-        self.acks = []
-        self.lines = []
+        self.acks: List[bytes] = []
+        self.lines: List[bytes] = []
         self.wants_satisified = False
         self.stateless_rpc = None
         self.advertise_refs = False

+ 3 - 2
dulwich/walk.py

@@ -24,11 +24,12 @@
 import collections
 import heapq
 from itertools import chain
-from typing import Deque, List, Optional, Set, Tuple
+from typing import Deque, List, Optional, Set, Tuple, Dict
 
 from .diff_tree import (
     RENAME_CHANGE_TYPES,
     RenameDetector,
+    TreeChange,
     tree_changes,
     tree_changes_for_merge,
 )
@@ -51,7 +52,7 @@ class WalkEntry:
         self.commit = commit
         self._store = walker.store
         self._get_parents = walker.get_parents
-        self._changes = {}
+        self._changes: Dict[str, List[TreeChange]] = {}
         self._rename_detector = walker.rename_detector
 
     def changes(self, path_prefix=None):

+ 1 - 1
dulwich/web.py

@@ -258,7 +258,7 @@ class ChunkReader:
 
     def __init__(self, f) -> None:
         self._iter = _chunk_iter(f)
-        self._buffer = []
+        self._buffer: List[bytes] = []
 
     def read(self, n):
         while sum(map(len, self._buffer)) < n: