Add more typing

Jelmer Vernooij 2 years ago
parent commit b6b217bdb9
11 changed files with 149 additions and 80 deletions
  1. dulwich/config.py  +5 -5
  2. dulwich/contrib/diffstat.py  +6 -2
  3. dulwich/diff_tree.py  +4 -2
  4. dulwich/errors.py  +2 -0
  5. dulwich/graph.py  +2 -1
  6. dulwich/lru_cache.py  +56 -32
  7. dulwich/objects.py  +18 -4
  8. dulwich/pack.py  +46 -27
  9. dulwich/patch.py  +7 -6
  10. dulwich/refs.py  +2 -1
  11. setup.cfg  +1 -0

+ 5 - 5
dulwich/config.py

@@ -52,7 +52,7 @@ def lower_key(key):
         return key.lower()
 
     if isinstance(key, Iterable):
-        return type(key)(map(lower_key, key))
+        return type(key)(map(lower_key, key))  # type: ignore
 
     return key
 
@@ -651,11 +651,11 @@ def _find_git_in_win_reg():
             "Uninstall\\Git_is1"
         )
 
-    for key in (winreg.HKEY_CURRENT_USER, winreg.HKEY_LOCAL_MACHINE):
+    for key in (winreg.HKEY_CURRENT_USER, winreg.HKEY_LOCAL_MACHINE):  # type: ignore
         try:
-            with winreg.OpenKey(key, subkey) as k:
-                val, typ = winreg.QueryValueEx(k, "InstallLocation")
-                if typ == winreg.REG_SZ:
+            with winreg.OpenKey(key, subkey) as k:  # type: ignore
+                val, typ = winreg.QueryValueEx(k, "InstallLocation")  # type: ignore
+                if typ == winreg.REG_SZ:  # type: ignore
                     yield val
         except OSError:
             pass
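
Note: winreg exists only on Windows, so a mypy run on any other platform cannot resolve its attributes; the commit suppresses each affected line rather than the whole file. A minimal sketch of the same pattern, assuming a platform-guarded import (not dulwich's actual import code):

    import sys

    if sys.platform == "win32":
        import winreg

    def registry_roots():
        # Empty on non-Windows; a checker run on another platform cannot
        # see winreg's attributes, so each use carries its own narrow
        # suppression instead of a file-wide one.
        if sys.platform != "win32":
            return
        yield winreg.HKEY_CURRENT_USER  # type: ignore
        yield winreg.HKEY_LOCAL_MACHINE  # type: ignore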

+ 6 - 2
dulwich/contrib/diffstat.py

@@ -35,6 +35,7 @@
 
 import sys
 import re
+from typing import Optional
 
 # only needs to detect git style diffs as this is for
 # use with dulwich
@@ -66,7 +67,7 @@ def _parse_patch(lines):
     nametypes = []
     counts = []
     in_patch_chunk = in_git_header = binaryfile = False
-    currentfile = None
+    currentfile: Optional[bytes] = None
     added = deleted = 0
     for line in lines:
         if line.startswith(_GIT_HEADER_START):
@@ -74,7 +75,9 @@ def _parse_patch(lines):
                 names.append(currentfile)
                 nametypes.append(binaryfile)
                 counts.append((added, deleted))
-            currentfile = _git_header_name.search(line).group(2)
+            m = _git_header_name.search(line)
+            assert m
+            currentfile = m.group(2)
             binaryfile = False
             added = deleted = 0
             in_git_header = True
@@ -85,6 +88,7 @@ def _parse_patch(lines):
         elif line.startswith(_GIT_RENAMEFROM_START) and in_git_header:
             currentfile = line[12:]
         elif line.startswith(_GIT_RENAMETO_START) and in_git_header:
+            assert currentfile
             currentfile += b" => %s" % line[10:]
         elif line.startswith(_GIT_CHUNK_START) and (in_patch_chunk or in_git_header):
             in_patch_chunk = True
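
Note: re.search returns Optional[Match], so chaining .group(2) directly is a mypy error even when the surrounding logic guarantees a match; the assert narrows the type. A minimal sketch of the narrowing (hypothetical pattern and function, not the diffstat regex):

    import re

    _header = re.compile(rb"^diff --git a/(.*) b/(.*)$")

    def header_name(line: bytes) -> bytes:
        m = _header.search(line)
        assert m  # narrows Optional[Match[bytes]] to Match[bytes]
        return m.group(2)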

+ 4 - 2
dulwich/diff_tree.py

@@ -28,10 +28,12 @@ from collections import (
 from io import BytesIO
 from itertools import chain
 import stat
+from typing import List
 
 from dulwich.objects import (
     S_ISGITLINK,
     TreeEntry,
+    Tree,
 )
 
 
@@ -65,8 +67,8 @@ class TreeChange(namedtuple("TreeChange", ["type", "old", "new"])):
         return cls(CHANGE_DELETE, old, _NULL_ENTRY)
 
 
-def _tree_entries(path, tree):
-    result = []
+def _tree_entries(path: str, tree: Tree) -> List[TreeEntry]:
+    result: List[TreeEntry] = []
     if not tree:
         return result
     for entry in tree.iteritems(name_order=True):
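
Note: annotating the full signature lets mypy check both the callers and the body; the separate annotation on result is needed because the list can be returned while still empty, which gives inference nothing to work with. A condensed sketch of the pattern (hypothetical helper):

    from typing import List

    def line_lengths(lines: List[str]) -> List[int]:
        result: List[int] = []  # a bare [] alone carries no element type
        if not lines:
            return result
        for line in lines:
            result.append(len(line))
        return result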

+ 2 - 0
dulwich/errors.py

@@ -61,6 +61,8 @@ class WrongObjectException(Exception):
     was expected if they were raised.
     """
 
+    type_name: str
+
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not a %s" % (sha, self.type_name))
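
Note: the base class never assigns type_name itself; the bare annotation declares the attribute so mypy accepts self.type_name in the shared __init__, while each concrete subclass supplies the value. Roughly (subclass sketched for illustration):

    class WrongObjectException(Exception):
        type_name: str  # declared on the base, assigned by subclasses

        def __init__(self, sha, *args, **kwargs):
            Exception.__init__(self, "%s is not a %s" % (sha, self.type_name))

    class NotCommitError(WrongObjectException):
        type_name = "commit"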
 
 

+ 2 - 1
dulwich/graph.py

@@ -23,6 +23,7 @@
 Implementation of merge-base following the approach of git
 """
 
+from typing import Deque
 from collections import deque
 
 
@@ -44,7 +45,7 @@ def _find_lcas(lookup_parents, c1, c2s):
         return False
 
     # initialize the working list
-    wlst = deque()
+    wlst: Deque[int] = deque()
     cstates[c1] = _ANC_OF_1
     wlst.append(c1)
     for c2 in c2s:
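
Note: like an empty list, a bare deque() carries no element type, so typing.Deque supplies one; on Python 3.9+ the same annotation can be written without the import as collections.deque[int] (PEP 585). A quick illustration:

    from collections import deque
    from typing import Deque

    wlst: Deque[int] = deque()  # a plain deque() would be untyped
    wlst.append(1)
    wlst.append("x")  # runs fine, but mypy flags it: str is not int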

+ 56 - 32
dulwich/lru_cache.py

@@ -1,5 +1,6 @@
 # lru_cache.py -- Simple LRU cache for dulwich
 # Copyright (C) 2006, 2008 Canonical Ltd
+# Copyright (C) 2022 Jelmer Vernooij <jelmer@jelmer.uk>
 #
 # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
 # General Public License as public by the Free Software Foundation; version 2.0
@@ -20,17 +21,28 @@
 
 """A simple least-recently-used (LRU) cache."""
 
+from typing import Generic, TypeVar, Optional, Callable, Dict, Iterable, Iterator
+
+
 _null_key = object()
 
 
-class _LRUNode(object):
+K = TypeVar('K')
+V = TypeVar('V')
+
+
+class _LRUNode(Generic[K, V]):
     """This maintains the linked-list which is the lru internals."""
 
     __slots__ = ("prev", "next_key", "key", "value", "cleanup", "size")
 
-    def __init__(self, key, value, cleanup=None):
+    prev: Optional["_LRUNode[K, V]"]
+    next_key: K
+    size: Optional[int]
+
+    def __init__(self, key: K, value: V, cleanup=None):
         self.prev = None
-        self.next_key = _null_key
+        self.next_key = _null_key  # type: ignore
         self.key = key
         self.value = value
         self.cleanup = cleanup
@@ -59,21 +71,24 @@ class _LRUNode(object):
         self.value = None
 
 
-class LRUCache(object):
+class LRUCache(Generic[K, V]):
     """A class which manages a cache of entries, removing unused ones."""
 
-    def __init__(self, max_cache=100, after_cleanup_count=None):
-        self._cache = {}
+    _least_recently_used: Optional[_LRUNode[K, V]]
+    _most_recently_used: Optional[_LRUNode[K, V]]
+
+    def __init__(self, max_cache: int = 100, after_cleanup_count: Optional[int] = None) -> None:
+        self._cache: Dict[K, _LRUNode[K, V]] = {}
         # The "HEAD" of the lru linked list
         self._most_recently_used = None
         # The "TAIL" of the lru linked list
         self._least_recently_used = None
         self._update_max_cache(max_cache, after_cleanup_count)
 
-    def __contains__(self, key):
+    def __contains__(self, key: K) -> bool:
         return key in self._cache
 
-    def __getitem__(self, key):
+    def __getitem__(self, key: K) -> V:
         cache = self._cache
         node = cache[key]
         # Inlined from _record_access to decrease the overhead of __getitem__
@@ -96,6 +111,8 @@ class LRUCache(object):
         else:
             node_next = cache[next_key]
             node_next.prev = node_prev
+        assert node_prev
+        assert mru
         node_prev.next_key = next_key
         # Insert this node at the front of the list
         node.next_key = mru.key
@@ -104,10 +121,10 @@ class LRUCache(object):
         node.prev = None
         return node.value
 
-    def __len__(self):
+    def __len__(self) -> int:
         return len(self._cache)
 
-    def _walk_lru(self):
+    def _walk_lru(self) -> Iterator[_LRUNode[K, V]]:
         """Walk the LRU list, only meant to be used in tests."""
         node = self._most_recently_used
         if node is not None:
@@ -144,7 +161,7 @@ class LRUCache(object):
             yield node
             node = node_next
 
-    def add(self, key, value, cleanup=None):
+    def add(self, key: K, value: V, cleanup: Optional[Callable[[K, V], None]] = None) -> None:
         """Add a new value to the cache.
 
         Also, if the entry is ever removed from the cache, call
@@ -172,18 +189,18 @@ class LRUCache(object):
             # Trigger the cleanup
             self.cleanup()
 
-    def cache_size(self):
+    def cache_size(self) -> int:
         """Get the number of entries we will cache."""
         return self._max_cache
 
-    def get(self, key, default=None):
+    def get(self, key: K, default: Optional[V] = None) -> Optional[V]:
         node = self._cache.get(key, None)
         if node is None:
            return default
         self._record_access(node)
         return node.value
 
-    def keys(self):
+    def keys(self) -> Iterable[K]:
         """Get the list of keys currently cached.
 
         Note that values returned here may not be available by the time you
@@ -194,7 +211,7 @@ class LRUCache(object):
         """
         return self._cache.keys()
 
-    def items(self):
+    def items(self) -> Dict[K, V]:
         """Get the key:value pairs as a dict."""
         return {k: n.value for k, n in self._cache.items()}
 
@@ -208,11 +225,11 @@ class LRUCache(object):
         while len(self._cache) > self._after_cleanup_count:
             self._remove_lru()
 
-    def __setitem__(self, key, value):
+    def __setitem__(self, key: K, value: V) -> None:
         """Add a value to the cache, there will be no cleanup function."""
         self.add(key, value, cleanup=None)
 
-    def _record_access(self, node):
+    def _record_access(self, node: _LRUNode[K, V]) -> None:
         """Record that key was accessed."""
         # Move 'node' to the front of the queue
         if self._most_recently_used is None:
@@ -238,7 +255,7 @@ class LRUCache(object):
         self._most_recently_used = node
         node.prev = None
 
-    def _remove_node(self, node):
+    def _remove_node(self, node: _LRUNode[K, V]) -> None:
         if node is self._least_recently_used:
             self._least_recently_used = node.prev
         self._cache.pop(node.key)
@@ -254,23 +271,24 @@ class LRUCache(object):
             node_next.prev = node.prev
         # And remove this node's pointers
         node.prev = None
-        node.next_key = _null_key
+        node.next_key = _null_key  # type: ignore
 
-    def _remove_lru(self):
+    def _remove_lru(self) -> None:
         """Remove one entry from the lru, and handle consequences.
 
         If there are no more references to the lru, then this entry should be
         removed from the cache.
         """
+        assert self._least_recently_used
         self._remove_node(self._least_recently_used)
 
-    def clear(self):
+    def clear(self) -> None:
         """Clear out all of the cache."""
         # Clean up in LRU order
         while self._cache:
             self._remove_lru()
 
-    def resize(self, max_cache, after_cleanup_count=None):
+    def resize(self, max_cache: int, after_cleanup_count: Optional[int] = None) -> None:
         """Change the number of entries that will be cached."""
         self._update_max_cache(max_cache, after_cleanup_count=after_cleanup_count)
 
@@ -283,7 +301,7 @@ class LRUCache(object):
         self.cleanup()
 
 
-class LRUSizeCache(LRUCache):
+class LRUSizeCache(LRUCache[K, V]):
     """An LRUCache that removes things based on the size of the values.
 
     This differs in that it doesn't care how many actual items there are,
@@ -293,9 +311,12 @@ class LRUSizeCache(LRUCache):
     defaults to len() if not supplied.
     """
 
+    _compute_size: Callable[[V], int]
+
     def __init__(
-        self, max_size=1024 * 1024, after_cleanup_size=None, compute_size=None
-    ):
+            self, max_size: int = 1024 * 1024, after_cleanup_size: Optional[int] = None,
+            compute_size: Optional[Callable[[V], int]] = None
+    ) -> None:
         """Create a new LRUSizeCache.
 
         Args:
@@ -311,13 +332,14 @@ class LRUSizeCache(LRUCache):
             If not supplied, it defaults to 'len()'
         """
         self._value_size = 0
-        self._compute_size = compute_size
         if compute_size is None:
-            self._compute_size = len
+            self._compute_size = len  # type: ignore
+        else:
+            self._compute_size = compute_size
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
         LRUCache.__init__(self, max_cache=max(int(max_size / 512), 1))
 
-    def add(self, key, value, cleanup=None):
+    def add(self, key: K, value: V, cleanup: Optional[Callable[[K, V], None]] = None) -> None:
         """Add a new value to the cache.
 
         Also, if the entry is ever removed from the cache, call
@@ -346,6 +368,7 @@ class LRUSizeCache(LRUCache):
             node = _LRUNode(key, value, cleanup=cleanup)
             self._cache[key] = node
         else:
+            assert node.size is not None
             self._value_size -= node.size
         node.size = value_len
         self._value_size += value_len
@@ -355,7 +378,7 @@ class LRUSizeCache(LRUCache):
             # Time to cleanup
             self.cleanup()
 
-    def cleanup(self):
+    def cleanup(self) -> None:
         """Clear the cache until it shrinks to the requested size.
 
         This does not completely wipe the cache, just makes sure it is under
@@ -365,17 +388,18 @@ class LRUSizeCache(LRUCache):
         while self._value_size > self._after_cleanup_size:
             self._remove_lru()
 
-    def _remove_node(self, node):
+    def _remove_node(self, node: _LRUNode[K, V]) -> None:
+        assert node.size is not None
         self._value_size -= node.size
         LRUCache._remove_node(self, node)
 
-    def resize(self, max_size, after_cleanup_size=None):
+    def resize(self, max_size: int, after_cleanup_size: Optional[int] = None) -> None:
         """Change the number of bytes that will be cached."""
         self._update_max_size(max_size, after_cleanup_size=after_cleanup_size)
         max_cache = max(int(max_size / 512), 1)
         self._update_max_cache(max_cache)
 
-    def _update_max_size(self, max_size, after_cleanup_size=None):
+    def _update_max_size(self, max_size: int, after_cleanup_size: Optional[int] = None) -> None:
         self._max_size = max_size
         if after_cleanup_size is None:
             self._after_cleanup_size = self._max_size * 8 // 10
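
Note: with the K/V type variables and Generic[K, V], one annotation on the cache fixes the key and value types everywhere: __getitem__ returns V, get returns Optional[V], and mismatched insertions are flagged. A usage sketch (hypothetical values):

    from dulwich.lru_cache import LRUCache

    cache: LRUCache[str, bytes] = LRUCache(max_cache=2)
    cache["HEAD"] = b"abc123"
    refsha = cache.get("HEAD")  # inferred as Optional[bytes]
    cache["count"] = 42  # runs, but mypy flags it: int is not bytes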

+ 18 - 4
dulwich/objects.py

@@ -37,6 +37,7 @@ from typing import (
     List,
 )
 import zlib
+from _hashlib import HASH
 from hashlib import sha1
 
 from dulwich.errors import (
@@ -104,7 +105,7 @@ def _decompress(string):
 def sha_to_hex(sha):
     """Takes a string and returns the hex of the sha within"""
     hexsha = binascii.hexlify(sha)
-    assert len(hexsha) == 40, "Incorrect length of sha1 string: %s" % hexsha
+    assert len(hexsha) == 40, "Incorrect length of sha1 string: %r" % hexsha
     return hexsha
 
 
@@ -273,6 +274,7 @@ class ShaFile(object):
     type_name: bytes
     type_num: int
     _chunked_text: Optional[List[bytes]]
+    _sha: Union[FixedSha, None, HASH]
 
     @staticmethod
     def _parse_legacy_object_header(magic, f) -> "ShaFile":
@@ -454,7 +456,10 @@ class ShaFile(object):
           string: The raw uncompressed contents.
           sha: Optional known sha for the object
         """
-        obj = object_class(type_num)()
+        cls = object_class(type_num)
+        if cls is None:
+            raise AssertionError("unsupported class type num: %d" % type_num)
+        obj = cls()
         obj.set_raw_string(string, sha)
         return obj
 
@@ -542,6 +547,8 @@ class ShaFile(object):
     def copy(self):
         """Create a new copy of this SHA1 object from its raw string"""
         obj_class = object_class(self.type_num)
+        if obj_class is None:
+            raise AssertionError('invalid type num %d' % self.type_num)
         return obj_class.from_raw_string(self.type_num, self.as_raw_string(), self.id)
 
     @property
@@ -581,6 +588,8 @@ class Blob(ShaFile):
     type_name = b"blob"
     type_num = 3
 
+    _chunked_text: List[bytes]
+
     def __init__(self):
         super(Blob, self).__init__()
         self._chunked_text = []
@@ -599,7 +608,7 @@ class Blob(ShaFile):
     def _get_chunked(self):
         return self._chunked_text
 
-    def _set_chunked(self, chunks):
+    def _set_chunked(self, chunks: List[bytes]):
         self._chunked_text = chunks
 
     def _serialize(self):
@@ -729,6 +738,8 @@ class Tag(ShaFile):
         "_signature",
     )
 
+    _tagger: Optional[bytes]
+
     def __init__(self):
         super(Tag, self).__init__()
         self._tagger = None
@@ -751,6 +762,7 @@ class Tag(ShaFile):
           ObjectFormatException: if the object is malformed in some way
         """
         super(Tag, self).check()
+        assert self._chunked_text is not None
         self._check_has_member("_object_sha", "missing object sha")
         self._check_has_member("_object_class", "missing object type")
         self._check_has_member("_name", "missing tag name")
@@ -760,7 +772,7 @@ class Tag(ShaFile):
 
         check_hexsha(self._object_sha, "invalid object sha")
 
-        if getattr(self, "_tagger", None):
+        if self._tagger is not None:
             check_identity(self._tagger, "invalid tagger")
 
         self._check_has_member("_tag_time", "missing tag time")
@@ -1141,6 +1153,7 @@ class Tree(ShaFile):
           ObjectFormatException: if the object is malformed in some way
         """
         super(Tree, self).check()
+        assert self._chunked_text is not None
         last = None
         allowed_modes = (
             stat.S_IFREG | 0o755,
@@ -1395,6 +1408,7 @@ class Commit(ShaFile):
           ObjectFormatException: if the object is malformed in some way
         """
         super(Commit, self).check()
+        assert self._chunked_text is not None
         self._check_has_member("_tree", "missing tree")
         self._check_has_member("_author", "missing author")
         self._check_has_member("_committer", "missing committer")

+ 46 - 27
dulwich/pack.py

@@ -49,7 +49,7 @@ from itertools import chain
 
 import os
 import sys
-from typing import Optional, Callable, Tuple, List
+from typing import Optional, Callable, Tuple, List, Deque, Union
 import warnings
 
 from hashlib import sha1
@@ -96,13 +96,13 @@ DELTA_TYPES = (OFS_DELTA, REF_DELTA)
 DEFAULT_PACK_DELTA_WINDOW_SIZE = 10
 
 
-def take_msb_bytes(read, crc32=None):
+def take_msb_bytes(read: Callable[[int], bytes], crc32: Optional[int] = None) -> Tuple[List[int], Optional[int]]:
     """Read bytes marked with most significant bit.
 
     Args:
       read: Read function
     """
-    ret = []
+    ret: List[int] = []
     while len(ret) == 0 or ret[-1] & 0x80:
         b = read(1)
         if crc32 is not None:
@@ -140,6 +140,9 @@ class UnpackedObject(object):
         "crc32",  # CRC32.
     ]
 
+    obj_type_num: Optional[int]
+    obj_chunks: Optional[List[bytes]]
+
     # TODO(dborowitz): read_zlib_chunks and unpack_object could very well be
     # methods of this object.
     def __init__(self, pack_type_num, delta_base, decomp_len, crc32):
@@ -148,7 +151,7 @@ class UnpackedObject(object):
         self.pack_type_num = pack_type_num
         self.delta_base = delta_base
         self.comp_chunks = None
-        self.decomp_chunks = []
+        self.decomp_chunks: List[bytes] = []
         self.decomp_len = decomp_len
         self.crc32 = crc32
 
@@ -168,6 +171,7 @@ class UnpackedObject(object):
 
     def sha_file(self):
         """Return a ShaFile from this object."""
+        assert self.obj_type_num is not None and self.obj_chunks is not None
         return ShaFile.from_raw_chunks(self.obj_type_num, self.obj_chunks)
 
     # Only provided for backwards compatibility with code that expects either
@@ -199,8 +203,10 @@ _ZLIB_BUFSIZE = 4096
 
 
 def read_zlib_chunks(
-    read_some, unpacked, include_comp=False, buffer_size=_ZLIB_BUFSIZE
-):
+        read_some: Callable[[int], bytes],
+        unpacked: UnpackedObject, include_comp: bool = False,
+        buffer_size: int = _ZLIB_BUFSIZE
+) -> bytes:
     """Read zlib data from a buffer.
 
     This function requires that the buffer have additional data following the
@@ -298,7 +304,7 @@ def _load_file_contents(f, size=None):
         if has_mmap:
             try:
                 contents = mmap.mmap(fd, size, access=mmap.ACCESS_READ)
-            except mmap.error:
+            except OSError:
                 # Perhaps a socket?
                 pass
             else:
@@ -431,6 +437,9 @@ class PackIndex(object):
         """Yield all the SHA1's of the objects in the index, sorted."""
         raise NotImplementedError(self._itersha)
 
+    def close(self):
+        pass
+
 
 class MemoryPackIndex(PackIndex):
     """Pack index that is stored entirely in memory."""
@@ -726,8 +735,8 @@ def chunks_length(chunks):
 
 
 def unpack_object(
-    read_all,
-    read_some=None,
+    read_all: Callable[[int], bytes],
+    read_some: Optional[Callable[[int], bytes]] = None,
     compute_crc32=False,
     include_comp=False,
     zlib_bufsize=_ZLIB_BUFSIZE,
@@ -762,28 +771,30 @@ def unpack_object(
     else:
         crc32 = None
 
-    bytes, crc32 = take_msb_bytes(read_all, crc32=crc32)
-    type_num = (bytes[0] >> 4) & 0x07
-    size = bytes[0] & 0x0F
-    for i, byte in enumerate(bytes[1:]):
+    raw, crc32 = take_msb_bytes(read_all, crc32=crc32)
+    type_num = (raw[0] >> 4) & 0x07
+    size = raw[0] & 0x0F
+    for i, byte in enumerate(raw[1:]):
         size += (byte & 0x7F) << ((i * 7) + 4)
 
-    raw_base = len(bytes)
+    delta_base: Union[int, bytes, None]
+    raw_base = len(raw)
     if type_num == OFS_DELTA:
-        bytes, crc32 = take_msb_bytes(read_all, crc32=crc32)
-        raw_base += len(bytes)
-        if bytes[-1] & 0x80:
+        raw, crc32 = take_msb_bytes(read_all, crc32=crc32)
+        raw_base += len(raw)
+        if raw[-1] & 0x80:
             raise AssertionError
-        delta_base_offset = bytes[0] & 0x7F
-        for byte in bytes[1:]:
+        delta_base_offset = raw[0] & 0x7F
+        for byte in raw[1:]:
             delta_base_offset += 1
             delta_base_offset <<= 7
             delta_base_offset += byte & 0x7F
         delta_base = delta_base_offset
     elif type_num == REF_DELTA:
-        delta_base = read_all(20)
-        if compute_crc32:
-            crc32 = binascii.crc32(delta_base, crc32)
+        delta_base_obj = read_all(20)
+        if crc32 is not None:
+            crc32 = binascii.crc32(delta_base_obj, crc32)
+        delta_base = delta_base_obj
         raw_base += 20
     else:
         delta_base = None
@@ -823,7 +834,7 @@ class PackStreamReader(object):
         self._offset = 0
         self._rbuf = BytesIO()
         # trailer is a deque to avoid memory allocation on small reads
-        self._trailer = deque()
+        self._trailer: Deque[bytes] = deque()
         self._zlib_bufsize = zlib_bufsize
 
     def _read(self, read, size):
@@ -1671,7 +1682,7 @@ def deltify_pack_objects(objects, window_size=None, reuse_pack=None):
         magic.append((obj.type_num, path, -obj.raw_length(), obj))
     magic.sort()
 
-    possible_bases = deque()
+    possible_bases: Deque[Tuple[bytes, int, bytes]] = deque()
 
     for type_num, path, neg_length, o in magic:
         raw = o.as_raw_chunks()
@@ -2038,7 +2049,7 @@ def write_pack_index_v2(f, entries, pack_checksum):
     for (name, offset, entry_checksum) in entries:
         fan_out_table[ord(name[:1])] += 1
     # Fan-out table
-    largetable = []
+    largetable: List[int] = []
    for i in range(0x100):
         f.write(struct.pack(b">L", fan_out_table[i]))
         fan_out_table[i + 1] += fan_out_table[i]
@@ -2079,6 +2090,12 @@ class _PackTupleIterable(object):
 class Pack(object):
     """A Git pack object."""
 
+    _data_load: Optional[Callable[[], PackData]]
+    _idx_load: Optional[Callable[[], PackIndex]]
+
+    _data: Optional[PackData]
+    _idx: Optional[PackIndex]
+
     def __init__(self, basename, resolve_ext_ref: Optional[
             Callable[[bytes], Tuple[int, UnpackedObject]]] = None):
         self._basename = basename
@@ -2115,20 +2132,22 @@ class Pack(object):
         return self.index.objects_sha1()
 
     @property
-    def data(self):
+    def data(self) -> PackData:
         """The pack data object being used."""
         if self._data is None:
+            assert self._data_load
            self._data = self._data_load()
             self.check_length_and_checksum()
         return self._data
 
     @property
-    def index(self):
+    def index(self) -> PackIndex:
         """The index being used.
 
         Note: This may be an in-memory index
         """
         if self._idx is None:
+            assert self._idx_load
             self._idx = self._idx_load()
         return self._idx
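
Note: the bytes-to-raw rename in unpack_object matters for typing, not just style: a local named bytes shadows the builtin, so any later use of bytes as a type in that scope refers to the variable. A minimal illustration (hypothetical functions):

    def shadowed(read):
        bytes = read(4)             # shadows the builtin type
        header: bytes = bytes[:2]   # mypy: Variable "bytes" is not valid as a type

    def renamed(read):
        raw = read(4)
        header: bytes = raw[:2]     # the annotation refers to the real builtin again

Similarly, delta_base gets a declared Union[int, bytes, None] because the three branches assign it three different types.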
 
 

+ 7 - 6
dulwich/patch.py

@@ -27,6 +27,7 @@ on.
 from difflib import SequenceMatcher
 import email.parser
 import time
+from typing import Union, TextIO, BinaryIO, Optional
 
 from dulwich.objects import (
     Blob,
@@ -338,7 +339,7 @@ def write_tree_diff(f, store, old_tree, new_tree, diff_binary=False):
         )
 
 
-def git_am_patch_split(f, encoding=None):
+def git_am_patch_split(f: Union[TextIO, BinaryIO], encoding: Optional[str] = None):
     """Parse a git-am-style patch and split it up into bits.
 
     Args:
@@ -349,12 +350,12 @@ def git_am_patch_split(f, encoding=None):
     encoding = encoding or getattr(f, "encoding", "ascii")
     encoding = encoding or "ascii"
     contents = f.read()
-    if isinstance(contents, bytes) and getattr(email.parser, "BytesParser", None):
-        parser = email.parser.BytesParser()
-        msg = parser.parsebytes(contents)
+    if isinstance(contents, bytes):
+        bparser = email.parser.BytesParser()
+        msg = bparser.parsebytes(contents)
     else:
-        parser = email.parser.Parser()
-        msg = parser.parsestr(contents)
+        uparser = email.parser.Parser()
+        msg = uparser.parsestr(contents)
     return parse_patch_message(msg, encoding)
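
Note: reusing one parser name for both a BytesParser and a Parser would force mypy to unify two incompatible types, so each branch gets its own name; the getattr(email.parser, "BytesParser", None) probe guarded against Pythons without BytesParser (in the stdlib since 3.2) and can simply go. The isinstance check also narrows the Union input. A condensed sketch (hypothetical function):

    import email.parser
    from typing import Union

    def parse_message(contents: Union[bytes, str]):
        if isinstance(contents, bytes):  # narrows contents to bytes here
            bparser = email.parser.BytesParser()
            return bparser.parsebytes(contents)
        uparser = email.parser.Parser()  # and to str on this path
        return uparser.parsestr(contents)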
 
 
 
 

+ 2 - 1
dulwich/refs.py

@@ -22,6 +22,7 @@
 """Ref handling.
 
 """
+from contextlib import suppress
 import os
 from typing import Dict, Optional
 
@@ -815,7 +816,7 @@ class DiskRefsContainer(RefsContainer):
                 return
 
             del self._packed_refs[name]
-            if name in self._peeled_refs:
+            with suppress(KeyError):
                 del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
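
Note: suppress(KeyError) folds the check-then-delete into one step; the delete either happens or is silently skipped, with no separate membership test to fall out of sync with the dict. A quick illustration:

    from contextlib import suppress

    peeled = {"refs/tags/v1": "abc"}
    with suppress(KeyError):
        del peeled["refs/tags/v2"]  # missing key: the KeyError is swallowed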

+ 1 - 0
setup.cfg

@@ -1,5 +1,6 @@
 [mypy]
 ignore_missing_imports = True
+#check_untyped_defs = True
 
 [metadata]
 name = dulwich
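
Note: check_untyped_defs is added commented out, presumably as a marker for the next step: once enabled, mypy also type-checks the bodies of functions that still lack annotations, which only becomes practical as commits like this one chip away at the remaining errors. Uncommented, the section would read:

    [mypy]
    ignore_missing_imports = True
    check_untyped_defs = True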