Browse Source

miscellaneous typechecking

David Hotham 2 years ago
parent
commit
0075eb8b5b
6 changed files with 69 additions and 24 deletions
  1. 18 6
      dulwich/client.py
  2. 11 5
      dulwich/config.py
  3. 13 1
      dulwich/object_store.py
  4. 2 1
      dulwich/porcelain.py
  5. 7 7
      dulwich/refs.py
  6. 18 4
      dulwich/repo.py

+ 18 - 6
dulwich/client.py

@@ -47,7 +47,7 @@ import shlex
 import socket
 import subprocess
 import sys
-from typing import Optional, Dict, Callable, Set
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
 
 from urllib.parse import (
     quote as urlquote,
@@ -112,6 +112,7 @@ from dulwich.refs import (
     ANNOTATED_TAG_SUFFIX,
     _import_remote_refs,
 )
+from dulwich.repo import Repo
 
 
 logger = logging.getLogger(__name__)
@@ -500,7 +501,6 @@ class GitClient(object):
               checkout=None, branch=None, progress=None, depth=None):
         """Clone a repository."""
         from .refs import _set_origin_head, _set_default_branch, _set_head
-        from .repo import Repo
 
         if mkdir:
             os.mkdir(target_path)
@@ -522,6 +522,7 @@ class GitClient(object):
             else:
                 encoded_path = self.get_url(path).encode('utf-8')
 
+            assert target is not None
             target_config = target.get_config()
             target_config.set((b"remote", origin.encode('utf-8')), b"url", encoded_path)
             target_config.set(
@@ -564,7 +565,16 @@ class GitClient(object):
             raise
         return target
 
-    def fetch(self, path, target, determine_wants=None, progress=None, depth=None):
+    def fetch(
+        self,
+        path: str,
+        target: Repo,
+        determine_wants: Optional[
+            Callable[[Dict[bytes, bytes], Optional[int]], List[bytes]]
+        ] = None,
+        progress: Optional[Callable[[bytes], None]] = None,
+        depth: Optional[int] = None
+    ) -> FetchPackResult:
         """Fetch into a target repository.
 
         Args:
@@ -1285,7 +1295,7 @@ class SubprocessWrapper(object):
         self.proc.wait()
 
 
-def find_git_command():
+def find_git_command() -> List[str]:
     """Find command to run for system Git (usually C Git)."""
     if sys.platform == "win32":  # support .exe, .bat and .cmd
         try:  # to avoid overhead
@@ -1359,7 +1369,6 @@ class LocalGitClient(GitClient):
 
     @classmethod
     def _open_repo(cls, path):
-        from dulwich.repo import Repo
 
         if not isinstance(path, str):
             path = os.fsdecode(path)
@@ -2268,7 +2277,10 @@ def parse_rsync_url(location):
     return (user, host, path)
 
 
-def get_transport_and_path(location, **kwargs):
+def get_transport_and_path(
+    location: str,
+    **kwargs: Any
+) -> Tuple[GitClient, str]:
     """Obtain a git client from a URL.
 
     Args:

+ 11 - 5
dulwich/config.py

@@ -30,7 +30,7 @@ import os
 import sys
 import warnings
 
-from typing import BinaryIO, Tuple, Optional
+from typing import BinaryIO, Iterator, KeysView, Optional, Tuple, Union
 
 try:
     from collections.abc import (
@@ -87,7 +87,7 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping):
     def __len__(self):
         return len(self._keyed)
 
-    def keys(self):
+    def keys(self) -> KeysView[Tuple[bytes, ...]]:
         return self._keyed.keys()
 
     def items(self):
@@ -241,7 +241,7 @@ class Config(object):
         """
         raise NotImplementedError(self.sections)
 
-    def has_section(self, name):
+    def has_section(self, name: Tuple[bytes, ...]) -> bool:
         """Check if a specified section exists.
 
         Args:
@@ -320,7 +320,11 @@ class ConfigDict(Config, MutableMapping):
 
         return self._values[(section[0],)].get_all(name)
 
-    def get(self, section, name):
+    def get(  # type: ignore[override]
+        self,
+        section: Union[bytes, str, Tuple[Union[bytes, str], ...]],
+        name: Union[str, bytes]
+    ) -> Optional[bytes]:
         section, name = self._check_section_and_name(section, name)
 
         if len(section) > 1:
@@ -679,7 +683,7 @@ class StackedConfig(Config):
         return self.writable.set(section, name, value)
 
 
-def parse_submodules(config):
+def parse_submodules(config: ConfigFile) -> Iterator[Tuple[bytes, bytes, bytes]]:
     """Parse a gitmodules GitConfig file, returning submodules.
 
     Args:
@@ -692,5 +696,7 @@ def parse_submodules(config):
         section_kind, section_name = section
         if section_kind == b"submodule":
             sm_path = config.get(section, b"path")
+            assert sm_path is not None
             sm_url = config.get(section, b"url")
+            assert sm_url is not None
             yield (sm_path, sm_url, section_name)

+ 13 - 1
dulwich/object_store.py

@@ -27,6 +27,8 @@ import os
 import stat
 import sys
 
+from typing import Callable, Dict, List, Optional, Tuple
+
 from dulwich.diff_tree import (
     tree_changes,
     walk_trees,
@@ -79,7 +81,11 @@ PACK_MODE = 0o444 if sys.platform != "win32" else 0o644
 class BaseObjectStore(object):
     """Object store interface."""
 
-    def determine_wants_all(self, refs, depth=None):
+    def determine_wants_all(
+        self,
+        refs: Dict[bytes, bytes],
+        depth: Optional[int] = None
+    ) -> List[bytes]:
         def _want_deepen(sha):
             if not depth:
                 return False
@@ -142,6 +148,12 @@ class BaseObjectStore(object):
         """Iterate over the SHAs that are present in this store."""
         raise NotImplementedError(self.__iter__)
 
+    def add_pack(
+        self
+    ) -> Tuple[BytesIO, Callable[[], None], Callable[[], None]]:
+        """Add a new pack to this object store."""
+        raise NotImplementedError(self.add_pack)
+
     def add_object(self, obj):
         """Add a single object to this object store."""
         raise NotImplementedError(self.add_object)

+ 2 - 1
dulwich/porcelain.py

@@ -1002,6 +1002,7 @@ def get_remote_repo(
     if config.has_section(section):
         remote_name = encoded_location.decode()
         url = config.get(section, "url")
+        assert url is not None
         encoded_location = url
     else:
         remote_name = None
@@ -1614,7 +1615,7 @@ def ls_tree(
         list_tree(r.object_store, tree.id, "")
 
 
-def remote_add(repo, name, url):
+def remote_add(repo: Repo, name: Union[bytes, str], url: Union[bytes, str]):
     """Add a remote.
 
     Args:

+ 7 - 7
dulwich/refs.py

@@ -158,13 +158,13 @@ class RefsContainer(object):
 
     def import_refs(
         self,
-        base,
-        other,
-        committer=None,
-        timestamp=None,
-        timezone=None,
-        message=None,
-        prune=False,
+        base: bytes,
+        other: Dict[bytes, bytes],
+        committer: Optional[bytes] = None,
+        timestamp: Optional[bytes] = None,
+        timezone: Optional[bytes] = None,
+        message: Optional[bytes] = None,
+        prune: bool = False,
     ):
         if prune:
             to_delete = set(self.subkeys(base))

+ 18 - 4
dulwich/repo.py

@@ -1057,7 +1057,12 @@ class Repo(BaseRepo):
       bare (bool): Whether this is a bare repository
     """
 
-    def __init__(self, root, object_store=None, bare=None):
+    def __init__(
+        self,
+        root: str,
+        object_store: Optional[BaseObjectStore] = None,
+        bare: Optional[bool] = None
+    ) -> None:
         hidden_path = os.path.join(root, CONTROLDIR)
         if bare is None:
             if (os.path.isfile(hidden_path) or
@@ -1093,9 +1098,18 @@ class Repo(BaseRepo):
         self.path = root
         config = self.get_config()
         try:
-            format_version = int(config.get("core", "repositoryformatversion"))
+            repository_format_version = config.get(
+                "core",
+                "repositoryformatversion"
+            )
+            format_version = (
+                0
+                if repository_format_version is None
+                else int(repository_format_version)
+            )
         except KeyError:
             format_version = 0
+
         if format_version != 0:
             raise UnsupportedVersion(format_version)
         if object_store is None:
@@ -1485,7 +1499,7 @@ class Repo(BaseRepo):
             raise
         return target
 
-    def reset_index(self, tree=None):
+    def reset_index(self, tree: Optional[Tree] = None):
         """Reset the index back to a specific tree.
 
         Args:
@@ -1569,7 +1583,7 @@ class Repo(BaseRepo):
         return ret
 
     @classmethod
-    def init(cls, path, mkdir=False):
+    def init(cls, path: str, mkdir: bool = False) -> "Repo":
         """Create a new repository.
 
         Args: