Browse Source

Minor fixes (#1649)

Jelmer Vernooij 1 month ago
parent
commit
d59076b482
12 changed files with 183 additions and 184 deletions
  1. 1 1
      docs/conf.py
  2. 2 0
      docs/tutorial/remote.txt
  3. 40 120
      dulwich/client.py
  4. 103 50
      dulwich/config.py
  5. 2 2
      dulwich/objects.py
  6. 4 0
      dulwich/web.py
  7. 1 2
      examples/diff.py
  8. 1 1
      examples/gcs.py
  9. 5 0
      pyproject.toml
  10. 1 1
      setup.py
  11. 18 5
      tests/contrib/test_release_robot.py
  12. 5 2
      tests/test_config.py

+ 1 - 1
docs/conf.py

@@ -132,7 +132,7 @@ html_theme_path = ["theme"]
 # Add any paths that contain custom static files (such as style sheets) here,
 # Add any paths that contain custom static files (such as style sheets) here,
 # relative to this directory. They are copied after the builtin static files,
 # relative to this directory. They are copied after the builtin static files,
 # so a file named "default.css" will overwrite the builtin "default.css".
 # so a file named "default.css" will overwrite the builtin "default.css".
-html_static_path = []
+html_static_path: list[str] = []
 
 
 # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
 # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
 # using the given strftime format.
 # using the given strftime format.

+ 2 - 0
docs/tutorial/remote.txt

@@ -55,6 +55,8 @@ by the server. Here in the tutorial we'll just use a dummy graph walker
 which claims that the client doesn't have any objects::
 which claims that the client doesn't have any objects::
 
 
    >>> class DummyGraphWalker(object):
    >>> class DummyGraphWalker(object):
+   ...     def __init__(self):
+   ...         self.shallow = set()
    ...     def ack(self, sha): pass
    ...     def ack(self, sha): pass
    ...     def nak(self): pass
    ...     def nak(self): pass
    ...     def next(self): pass
    ...     def next(self): pass

+ 40 - 120
dulwich/client.py

@@ -317,14 +317,8 @@ def read_pkt_refs_v1(pkt_seq) -> tuple[dict[bytes, bytes], set[bytes]]:
     return refs, set(server_capabilities)
     return refs, set(server_capabilities)
 
 
 
 
-class FetchPackResult:
-    """Result of a fetch-pack operation.
-
-    Attributes:
-      refs: Dictionary with all remote refs
-      symrefs: Dictionary with remote symrefs
-      agent: User agent string
-    """
+class _DeprecatedDictProxy:
+    """Base class for result objects that provide deprecated dict-like interface."""
 
 
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
     _FORWARDED_ATTRS: ClassVar[set[str]] = {
         "clear",
         "clear",
@@ -343,34 +337,15 @@ class FetchPackResult:
         "viewvalues",
         "viewvalues",
     }
     }
 
 
-    def __init__(
-        self, refs, symrefs, agent, new_shallow=None, new_unshallow=None
-    ) -> None:
-        self.refs = refs
-        self.symrefs = symrefs
-        self.agent = agent
-        self.new_shallow = new_shallow
-        self.new_unshallow = new_unshallow
-
     def _warn_deprecated(self) -> None:
     def _warn_deprecated(self) -> None:
         import warnings
         import warnings
 
 
         warnings.warn(
         warnings.warn(
-            "Use FetchPackResult.refs instead.",
+            f"Use {self.__class__.__name__}.refs instead.",
             DeprecationWarning,
             DeprecationWarning,
             stacklevel=3,
             stacklevel=3,
         )
         )
 
 
-    def __eq__(self, other):
-        if isinstance(other, dict):
-            self._warn_deprecated()
-            return self.refs == other
-        return (
-            self.refs == other.refs
-            and self.symrefs == other.symrefs
-            and self.agent == other.agent
-        )
-
     def __contains__(self, name) -> bool:
     def __contains__(self, name) -> bool:
         self._warn_deprecated()
         self._warn_deprecated()
         return name in self.refs
         return name in self.refs
@@ -388,16 +363,48 @@ class FetchPackResult:
         return iter(self.refs)
         return iter(self.refs)
 
 
     def __getattribute__(self, name):
     def __getattribute__(self, name):
-        if name in type(self)._FORWARDED_ATTRS:
+        # Avoid infinite recursion by checking against class variable directly
+        if name != "_FORWARDED_ATTRS" and name in type(self)._FORWARDED_ATTRS:
             self._warn_deprecated()
             self._warn_deprecated()
-            return getattr(self.refs, name)
+            # Direct attribute access to avoid recursion
+            refs = object.__getattribute__(self, "refs")
+            return getattr(refs, name)
         return super().__getattribute__(name)
         return super().__getattribute__(name)
 
 
+
+class FetchPackResult(_DeprecatedDictProxy):
+    """Result of a fetch-pack operation.
+
+    Attributes:
+      refs: Dictionary with all remote refs
+      symrefs: Dictionary with remote symrefs
+      agent: User agent string
+    """
+
+    def __init__(
+        self, refs, symrefs, agent, new_shallow=None, new_unshallow=None
+    ) -> None:
+        self.refs = refs
+        self.symrefs = symrefs
+        self.agent = agent
+        self.new_shallow = new_shallow
+        self.new_unshallow = new_unshallow
+
+    def __eq__(self, other):
+        if isinstance(other, dict):
+            self._warn_deprecated()
+            return self.refs == other
+        return (
+            self.refs == other.refs
+            and self.symrefs == other.symrefs
+            and self.agent == other.agent
+        )
+
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.refs!r}, {self.symrefs!r}, {self.agent!r})"
         return f"{self.__class__.__name__}({self.refs!r}, {self.symrefs!r}, {self.agent!r})"
 
 
 
 
-class LsRemoteResult:
+class LsRemoteResult(_DeprecatedDictProxy):
     """Result of a ls-remote operation.
     """Result of a ls-remote operation.
 
 
     Attributes:
     Attributes:
@@ -405,23 +412,6 @@ class LsRemoteResult:
       symrefs: Dictionary with remote symrefs
       symrefs: Dictionary with remote symrefs
     """
     """
 
 
-    _FORWARDED_ATTRS: ClassVar[set[str]] = {
-        "clear",
-        "copy",
-        "fromkeys",
-        "get",
-        "items",
-        "keys",
-        "pop",
-        "popitem",
-        "setdefault",
-        "update",
-        "values",
-        "viewitems",
-        "viewkeys",
-        "viewvalues",
-    }
-
     def __init__(self, refs, symrefs) -> None:
     def __init__(self, refs, symrefs) -> None:
         self.refs = refs
         self.refs = refs
         self.symrefs = symrefs
         self.symrefs = symrefs
@@ -442,33 +432,11 @@ class LsRemoteResult:
             return self.refs == other
             return self.refs == other
         return self.refs == other.refs and self.symrefs == other.symrefs
         return self.refs == other.refs and self.symrefs == other.symrefs
 
 
-    def __contains__(self, name) -> bool:
-        self._warn_deprecated()
-        return name in self.refs
-
-    def __getitem__(self, name):
-        self._warn_deprecated()
-        return self.refs[name]
-
-    def __len__(self) -> int:
-        self._warn_deprecated()
-        return len(self.refs)
-
-    def __iter__(self):
-        self._warn_deprecated()
-        return iter(self.refs)
-
-    def __getattribute__(self, name):
-        if name in type(self)._FORWARDED_ATTRS:
-            self._warn_deprecated()
-            return getattr(self.refs, name)
-        return super().__getattribute__(name)
-
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.refs!r}, {self.symrefs!r})"
         return f"{self.__class__.__name__}({self.refs!r}, {self.symrefs!r})"
 
 
 
 
-class SendPackResult:
+class SendPackResult(_DeprecatedDictProxy):
     """Result of a upload-pack operation.
     """Result of a upload-pack operation.
 
 
     Attributes:
     Attributes:
@@ -478,65 +446,17 @@ class SendPackResult:
         failed to update), or None if it was updated successfully
         failed to update), or None if it was updated successfully
     """
     """
 
 
-    _FORWARDED_ATTRS: ClassVar[set[str]] = {
-        "clear",
-        "copy",
-        "fromkeys",
-        "get",
-        "items",
-        "keys",
-        "pop",
-        "popitem",
-        "setdefault",
-        "update",
-        "values",
-        "viewitems",
-        "viewkeys",
-        "viewvalues",
-    }
-
     def __init__(self, refs, agent=None, ref_status=None) -> None:
     def __init__(self, refs, agent=None, ref_status=None) -> None:
         self.refs = refs
         self.refs = refs
         self.agent = agent
         self.agent = agent
         self.ref_status = ref_status
         self.ref_status = ref_status
 
 
-    def _warn_deprecated(self) -> None:
-        import warnings
-
-        warnings.warn(
-            "Use SendPackResult.refs instead.",
-            DeprecationWarning,
-            stacklevel=3,
-        )
-
     def __eq__(self, other):
     def __eq__(self, other):
         if isinstance(other, dict):
         if isinstance(other, dict):
             self._warn_deprecated()
             self._warn_deprecated()
             return self.refs == other
             return self.refs == other
         return self.refs == other.refs and self.agent == other.agent
         return self.refs == other.refs and self.agent == other.agent
 
 
-    def __contains__(self, name) -> bool:
-        self._warn_deprecated()
-        return name in self.refs
-
-    def __getitem__(self, name):
-        self._warn_deprecated()
-        return self.refs[name]
-
-    def __len__(self) -> int:
-        self._warn_deprecated()
-        return len(self.refs)
-
-    def __iter__(self):
-        self._warn_deprecated()
-        return iter(self.refs)
-
-    def __getattribute__(self, name):
-        if name in type(self)._FORWARDED_ATTRS:
-            self._warn_deprecated()
-            return getattr(self.refs, name)
-        return super().__getattribute__(name)
-
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"
         return f"{self.__class__.__name__}({self.refs!r}, {self.agent!r})"
 
 
@@ -673,7 +593,7 @@ def _handle_upload_pack_head(
     proto.write_pkt_line(wantcmd)
     proto.write_pkt_line(wantcmd)
     for want in wants[1:]:
     for want in wants[1:]:
         proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n")
         proto.write_pkt_line(COMMAND_WANT + b" " + want + b"\n")
-    if depth not in (0, None) or getattr(graph_walker, "shallow", None):
+    if depth not in (0, None) or graph_walker.shallow:
         if protocol_version == 2:
         if protocol_version == 2:
             if not find_capability(capabilities, CAPABILITY_FETCH, CAPABILITY_SHALLOW):
             if not find_capability(capabilities, CAPABILITY_FETCH, CAPABILITY_SHALLOW):
                 raise GitProtocolError(
                 raise GitProtocolError(

+ 103 - 50
dulwich/config.py

@@ -29,20 +29,32 @@ import logging
 import os
 import os
 import re
 import re
 import sys
 import sys
-from collections.abc import Iterable, Iterator, KeysView, MutableMapping
+from collections.abc import (
+    ItemsView,
+    Iterable,
+    Iterator,
+    KeysView,
+    MutableMapping,
+    ValuesView,
+)
 from contextlib import suppress
 from contextlib import suppress
 from pathlib import Path
 from pathlib import Path
 from typing import (
 from typing import (
     Any,
     Any,
     BinaryIO,
     BinaryIO,
     Callable,
     Callable,
+    Generic,
     Optional,
     Optional,
+    TypeVar,
     Union,
     Union,
     overload,
     overload,
 )
 )
 
 
 from .file import GitFile
 from .file import GitFile
 
 
+ConfigKey = Union[str, bytes, tuple[Union[str, bytes], ...]]
+ConfigValue = Union[str, bytes, bool, int]
+
 logger = logging.getLogger(__name__)
 logger = logging.getLogger(__name__)
 
 
 # Type for file opener callback
 # Type for file opener callback
@@ -56,8 +68,6 @@ ConditionMatcher = Callable[[str], bool]
 MAX_INCLUDE_FILE_SIZE = 1024 * 1024  # 1MB max for included config files
 MAX_INCLUDE_FILE_SIZE = 1024 * 1024  # 1MB max for included config files
 DEFAULT_MAX_INCLUDE_DEPTH = 10  # Maximum recursion depth for includes
 DEFAULT_MAX_INCLUDE_DEPTH = 10  # Maximum recursion depth for includes
 
 
-SENTINEL = object()
-
 
 
 def _match_gitdir_pattern(
 def _match_gitdir_pattern(
     path: bytes, pattern: bytes, ignorecase: bool = False
     path: bytes, pattern: bytes, ignorecase: bool = False
@@ -136,31 +146,41 @@ def match_glob_pattern(value: str, pattern: str) -> bool:
         raise ValueError(f"Invalid glob pattern {pattern!r}: {e}")
         raise ValueError(f"Invalid glob pattern {pattern!r}: {e}")
 
 
 
 
-def lower_key(key):
+def lower_key(key: ConfigKey) -> ConfigKey:
     if isinstance(key, (bytes, str)):
     if isinstance(key, (bytes, str)):
         return key.lower()
         return key.lower()
 
 
-    if isinstance(key, Iterable):
+    if isinstance(key, tuple):
         # For config sections, only lowercase the section name (first element)
         # For config sections, only lowercase the section name (first element)
         # but preserve the case of subsection names (remaining elements)
         # but preserve the case of subsection names (remaining elements)
         if len(key) > 0:
         if len(key) > 0:
-            return (key[0].lower(), *key[1:])
+            first = key[0]
+            assert isinstance(first, (bytes, str))
+            return (first.lower(), *key[1:])
         return key
         return key
 
 
-    return key
+    raise TypeError(key)
 
 
 
 
-class CaseInsensitiveOrderedMultiDict(MutableMapping):
-    def __init__(self) -> None:
-        self._real: list[Any] = []
-        self._keyed: dict[Any, Any] = {}
+K = TypeVar("K", bound=ConfigKey)  # Key type must be ConfigKey
+V = TypeVar("V")  # Value type
+_T = TypeVar("_T")  # For get() default parameter
+
+
+class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
+    def __init__(self, default_factory: Optional[Callable[[], V]] = None) -> None:
+        self._real: list[tuple[K, V]] = []
+        self._keyed: dict[Any, V] = {}
+        self._default_factory = default_factory
 
 
     @classmethod
     @classmethod
-    def make(cls, dict_in=None):
+    def make(
+        cls, dict_in=None, default_factory=None
+    ) -> "CaseInsensitiveOrderedMultiDict[K, V]":
         if isinstance(dict_in, cls):
         if isinstance(dict_in, cls):
             return dict_in
             return dict_in
 
 
-        out = cls()
+        out = cls(default_factory=default_factory)
 
 
         if dict_in is None:
         if dict_in is None:
             return out
             return out
@@ -176,16 +196,33 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping):
     def __len__(self) -> int:
     def __len__(self) -> int:
         return len(self._keyed)
         return len(self._keyed)
 
 
-    def keys(self) -> KeysView[tuple[bytes, ...]]:
-        return self._keyed.keys()
+    def keys(self) -> KeysView[K]:
+        return self._keyed.keys()  # type: ignore[return-value]
+
+    def items(self) -> ItemsView[K, V]:
+        # Return a view that iterates over the real list to preserve order
+        class OrderedItemsView(ItemsView[K, V]):
+            def __init__(self, mapping: CaseInsensitiveOrderedMultiDict[K, V]):
+                self._mapping = mapping
 
 
-    def items(self):
-        return iter(self._real)
+            def __iter__(self) -> Iterator[tuple[K, V]]:
+                return iter(self._mapping._real)
 
 
-    def __iter__(self):
-        return self._keyed.__iter__()
+            def __len__(self) -> int:
+                return len(self._mapping._real)
 
 
-    def values(self):
+            def __contains__(self, item: object) -> bool:
+                if not isinstance(item, tuple) or len(item) != 2:
+                    return False
+                key, value = item
+                return any(k == key and v == value for k, v in self._mapping._real)
+
+        return OrderedItemsView(self)
+
+    def __iter__(self) -> Iterator[K]:
+        return iter(self._keyed)
+
+    def values(self) -> ValuesView[V]:
         return self._keyed.values()
         return self._keyed.values()
 
 
     def __setitem__(self, key, value) -> None:
     def __setitem__(self, key, value) -> None:
@@ -206,33 +243,39 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping):
             if lower_key(actual) == key:
             if lower_key(actual) == key:
                 del self._real[i]
                 del self._real[i]
 
 
-    def __getitem__(self, item):
+    def __getitem__(self, item: K) -> V:
         return self._keyed[lower_key(item)]
         return self._keyed[lower_key(item)]
 
 
-    def get(self, key, default=SENTINEL):
+    def get(self, key: K, /, default: Union[V, _T, None] = None) -> Union[V, _T, None]:  # type: ignore[override]
         try:
         try:
             return self[key]
             return self[key]
         except KeyError:
         except KeyError:
-            pass
-
-        if default is SENTINEL:
-            return type(self)()
-
-        return default
+            if default is not None:
+                return default
+            elif self._default_factory is not None:
+                return self._default_factory()
+            else:
+                return None
 
 
-    def get_all(self, key):
-        key = lower_key(key)
+    def get_all(self, key: K) -> Iterator[V]:
+        lowered_key = lower_key(key)
         for actual, value in self._real:
         for actual, value in self._real:
-            if lower_key(actual) == key:
+            if lower_key(actual) == lowered_key:
                 yield value
                 yield value
 
 
-    def setdefault(self, key, default=SENTINEL):
+    def setdefault(self, key: K, default: Optional[V] = None) -> V:
         try:
         try:
             return self[key]
             return self[key]
         except KeyError:
         except KeyError:
-            self[key] = self.get(key, default)
-
-        return self[key]
+            if default is not None:
+                self[key] = default
+                return default
+            elif self._default_factory is not None:
+                value = self._default_factory()
+                self[key] = value
+                return value
+            else:
+                raise
 
 
 
 
 Name = bytes
 Name = bytes
@@ -344,7 +387,7 @@ class Config:
         return name in self.sections()
         return name in self.sections()
 
 
 
 
-class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
+class ConfigDict(Config):
     """Git configuration stored in a dictionary."""
     """Git configuration stored in a dictionary."""
 
 
     def __init__(
     def __init__(
@@ -358,7 +401,11 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
         if encoding is None:
         if encoding is None:
             encoding = sys.getdefaultencoding()
             encoding = sys.getdefaultencoding()
         self.encoding = encoding
         self.encoding = encoding
-        self._values = CaseInsensitiveOrderedMultiDict.make(values)
+        self._values: CaseInsensitiveOrderedMultiDict[
+            Section, CaseInsensitiveOrderedMultiDict[Name, Value]
+        ] = CaseInsensitiveOrderedMultiDict.make(
+            values, default_factory=CaseInsensitiveOrderedMultiDict
+        )
 
 
     def __repr__(self) -> str:
     def __repr__(self) -> str:
         return f"{self.__class__.__name__}({self._values!r})"
         return f"{self.__class__.__name__}({self._values!r})"
@@ -366,7 +413,7 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
     def __eq__(self, other: object) -> bool:
     def __eq__(self, other: object) -> bool:
         return isinstance(other, self.__class__) and other._values == self._values
         return isinstance(other, self.__class__) and other._values == self._values
 
 
-    def __getitem__(self, key: Section) -> MutableMapping[Name, Value]:
+    def __getitem__(self, key: Section) -> CaseInsensitiveOrderedMultiDict[Name, Value]:
         return self._values.__getitem__(key)
         return self._values.__getitem__(key)
 
 
     def __setitem__(self, key: Section, value: MutableMapping[Name, Value]) -> None:
     def __setitem__(self, key: Section, value: MutableMapping[Name, Value]) -> None:
@@ -381,8 +428,11 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
     def __len__(self) -> int:
     def __len__(self) -> int:
         return self._values.__len__()
         return self._values.__len__()
 
 
+    def keys(self) -> KeysView[Section]:
+        return self._values.keys()
+
     @classmethod
     @classmethod
-    def _parse_setting(cls, name: str):
+    def _parse_setting(cls, name: str) -> tuple[str, Optional[str], str]:
         parts = name.split(".")
         parts = name.split(".")
         if len(parts) == 3:
         if len(parts) == 3:
             return (parts[0], parts[1], parts[2])
             return (parts[0], parts[1], parts[2])
@@ -420,7 +470,7 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
 
 
         return self._values[(section[0],)].get_all(name)
         return self._values[(section[0],)].get_all(name)
 
 
-    def get(  # type: ignore[override]
+    def get(
         self,
         self,
         section: SectionLike,
         section: SectionLike,
         name: NameLike,
         name: NameLike,
@@ -472,13 +522,15 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
 
 
         self._values.setdefault(section)[name] = value
         self._values.setdefault(section)[name] = value
 
 
-    def items(  # type: ignore[override]
-        self, section: Section
-    ) -> Iterator[tuple[Name, Value]]:
-        return self._values.get(section).items()
+    def items(self, section: SectionLike) -> Iterator[tuple[Name, Value]]:
+        section_bytes, _ = self._check_section_and_name(section, b"")
+        section_dict = self._values.get(section_bytes)
+        if section_dict is not None:
+            return iter(section_dict.items())
+        return iter([])
 
 
     def sections(self) -> Iterator[Section]:
     def sections(self) -> Iterator[Section]:
-        return self._values.keys()
+        return iter(self._values.keys())
 
 
 
 
 def _format_string(value: bytes) -> bytes:
 def _format_string(value: bytes) -> bytes:
@@ -781,6 +833,7 @@ class ConfigFile(ConfigDict):
                 else:
                 else:
                     continuation += line
                     continuation += line
                     value = _parse_string(continuation)
                     value = _parse_string(continuation)
+                    assert section is not None  # Already checked above
                     ret._values[section][setting] = value
                     ret._values[section][setting] = value
 
 
                     # Process include/includeIf directives
                     # Process include/includeIf directives
@@ -1076,7 +1129,7 @@ class ConfigFile(ConfigDict):
                 f.write(b"\t" + key + b" = " + value + b"\n")
                 f.write(b"\t" + key + b" = " + value + b"\n")
 
 
 
 
-def get_xdg_config_home_path(*path_segments):
+def get_xdg_config_home_path(*path_segments: str) -> str:
     xdg_config_home = os.environ.get(
     xdg_config_home = os.environ.get(
         "XDG_CONFIG_HOME",
         "XDG_CONFIG_HOME",
         os.path.expanduser("~/.config/"),
         os.path.expanduser("~/.config/"),
@@ -1084,7 +1137,7 @@ def get_xdg_config_home_path(*path_segments):
     return os.path.join(xdg_config_home, *path_segments)
     return os.path.join(xdg_config_home, *path_segments)
 
 
 
 
-def _find_git_in_win_path():
+def _find_git_in_win_path() -> Iterator[str]:
     for exe in ("git.exe", "git.cmd"):
     for exe in ("git.exe", "git.cmd"):
         for path in os.environ.get("PATH", "").split(";"):
         for path in os.environ.get("PATH", "").split(";"):
             if os.path.exists(os.path.join(path, exe)):
             if os.path.exists(os.path.join(path, exe)):
@@ -1100,7 +1153,7 @@ def _find_git_in_win_path():
                 break
                 break
 
 
 
 
-def _find_git_in_win_reg():
+def _find_git_in_win_reg() -> Iterator[str]:
     import platform
     import platform
     import winreg
     import winreg
 
 
@@ -1126,7 +1179,7 @@ def _find_git_in_win_reg():
 #   - %PROGRAMFILES%/Git/etc/gitconfig - Git for Windows (msysgit) config dir
 #   - %PROGRAMFILES%/Git/etc/gitconfig - Git for Windows (msysgit) config dir
 #     Used if CGit installation (Git/bin/git.exe) is found in PATH in the
 #     Used if CGit installation (Git/bin/git.exe) is found in PATH in the
 #     system registry
 #     system registry
-def get_win_system_paths():
+def get_win_system_paths() -> Iterator[str]:
     if "PROGRAMDATA" in os.environ:
     if "PROGRAMDATA" in os.environ:
         yield os.path.join(os.environ["PROGRAMDATA"], "Git", "config")
         yield os.path.join(os.environ["PROGRAMDATA"], "Git", "config")
 
 
@@ -1228,7 +1281,7 @@ def parse_submodules(config: ConfigFile) -> Iterator[tuple[bytes, bytes, bytes]]
       list of tuples (submodule path, url, name),
       list of tuples (submodule path, url, name),
         where name is quoted part of the section's name.
         where name is quoted part of the section's name.
     """
     """
-    for section in config.keys():
+    for section in config.sections():
         section_kind, section_name = section
         section_kind, section_name = section
         if section_kind == b"submodule":
         if section_kind == b"submodule":
             try:
             try:

+ 2 - 2
dulwich/objects.py

@@ -146,7 +146,7 @@ def hex_to_filename(
     # os.path.join accepts bytes or unicode, but all args must be of the same
     # os.path.join accepts bytes or unicode, but all args must be of the same
     # type. Make sure that hex which is expected to be bytes, is the same type
     # type. Make sure that hex which is expected to be bytes, is the same type
     # as path.
     # as path.
-    if type(path) is not type(hex) and getattr(path, "encode", None) is not None:
+    if type(path) is not type(hex) and isinstance(path, str):
         hex = hex.decode("ascii")  # type: ignore
         hex = hex.decode("ascii")  # type: ignore
     dir = hex[:2]
     dir = hex[:2]
     file = hex[2:]
     file = hex[2:]
@@ -263,7 +263,7 @@ class FixedSha:
     __slots__ = ("_hexsha", "_sha")
     __slots__ = ("_hexsha", "_sha")
 
 
     def __init__(self, hexsha: Union[str, bytes]) -> None:
     def __init__(self, hexsha: Union[str, bytes]) -> None:
-        if getattr(hexsha, "encode", None) is not None:
+        if isinstance(hexsha, str):
             hexsha = hexsha.encode("ascii")  # type: ignore
             hexsha = hexsha.encode("ascii")  # type: ignore
         if not isinstance(hexsha, bytes):
         if not isinstance(hexsha, bytes):
             raise TypeError(f"Expected bytes for hexsha, got {hexsha!r}")
             raise TypeError(f"Expected bytes for hexsha, got {hexsha!r}")

+ 4 - 0
dulwich/web.py

@@ -258,6 +258,8 @@ def _chunk_iter(f):
 
 
 
 
 class ChunkReader:
 class ChunkReader:
+    """Reader for chunked transfer encoding streams."""
+
     def __init__(self, f) -> None:
     def __init__(self, f) -> None:
         self._iter = _chunk_iter(f)
         self._iter = _chunk_iter(f)
         self._buffer: list[bytes] = []
         self._buffer: list[bytes] = []
@@ -557,6 +559,8 @@ class WSGIRequestHandlerLogger(WSGIRequestHandler):
 
 
 
 
 class WSGIServerLogger(WSGIServer):
 class WSGIServerLogger(WSGIServer):
+    """WSGIServer that uses dulwich's logger for error handling."""
+
     def handle_error(self, request, client_address) -> None:
     def handle_error(self, request, client_address) -> None:
         """Handle an error."""
         """Handle an error."""
         logger.exception(
         logger.exception(

+ 1 - 2
examples/diff.py

@@ -19,5 +19,4 @@ r = Repo(repo_path)
 
 
 commit = r[commit_id]
 commit = r[commit_id]
 parent_commit = r[commit.parents[0]]
 parent_commit = r[commit.parents[0]]
-outstream = getattr(sys.stdout, "buffer", sys.stdout)
-write_tree_diff(outstream, r.object_store, parent_commit.tree, commit.tree)
+write_tree_diff(sys.stdout.buffer, r.object_store, parent_commit.tree, commit.tree)

+ 1 - 1
examples/gcs.py

@@ -3,7 +3,7 @@
 
 
 import tempfile
 import tempfile
 
 
-from google.cloud import storage
+from google.cloud import storage  # type: ignore[attr-defined]
 
 
 from dulwich.cloud.gcs import GcsObjectStore
 from dulwich.cloud.gcs import GcsObjectStore
 from dulwich.repo import Repo
 from dulwich.repo import Repo

+ 5 - 0
pyproject.toml

@@ -47,6 +47,7 @@ dev = [
     "dissolve>=0.1.1"
     "dissolve>=0.1.1"
 ]
 ]
 merge = ["merge3"]
 merge = ["merge3"]
+fuzzing = ["atheris"]
 
 
 [project.scripts]
 [project.scripts]
 dulwich = "dulwich.cli:main"
 dulwich = "dulwich.cli:main"
@@ -54,6 +55,10 @@ dulwich = "dulwich.cli:main"
 [tool.mypy]
 [tool.mypy]
 ignore_missing_imports = true
 ignore_missing_imports = true
 
 
+[[tool.mypy.overrides]]
+module = "atheris"
+ignore_missing_imports = true
+
 [tool.setuptools]
 [tool.setuptools]
 packages = [
 packages = [
     "dulwich",
     "dulwich",

+ 1 - 1
setup.py

@@ -40,7 +40,7 @@ if "PURE" in os.environ or "--pure" in sys.argv:
     if "--pure" in sys.argv:
     if "--pure" in sys.argv:
         sys.argv.remove("--pure")
         sys.argv.remove("--pure")
     setup_requires = []
     setup_requires = []
-    rust_extensions = []
+    rust_extensions = []  # type: list["RustExtension"]
 else:
 else:
     setup_requires = ["setuptools_rust"]
     setup_requires = ["setuptools_rust"]
     # We check for egg_info since that indicates we are running prepare_metadata_for_build_*
     # We check for egg_info since that indicates we are running prepare_metadata_for_build_*

+ 18 - 5
tests/contrib/test_release_robot.py

@@ -33,6 +33,7 @@ from typing import ClassVar, Optional
 from unittest.mock import MagicMock, patch
 from unittest.mock import MagicMock, patch
 
 
 from dulwich.contrib import release_robot
 from dulwich.contrib import release_robot
+from dulwich.objects import Commit, Tag
 from dulwich.repo import Repo
 from dulwich.repo import Repo
 from dulwich.tests.utils import make_commit, make_tag
 from dulwich.tests.utils import make_commit, make_tag
 
 
@@ -63,6 +64,7 @@ class TagPatternTests(unittest.TestCase):
         }
         }
         for testcase, version in test_cases.items():
         for testcase, version in test_cases.items():
             matches = re.match(release_robot.PATTERN, testcase)
             matches = re.match(release_robot.PATTERN, testcase)
+            assert matches is not None
             self.assertEqual(matches.group(1), version)
             self.assertEqual(matches.group(1), version)
 
 
     def test_pattern_no_match(self) -> None:
     def test_pattern_no_match(self) -> None:
@@ -93,6 +95,14 @@ class GetRecentTagsTest(unittest.TestCase):
         test_tags[1]: (1484788314, b"1" * 40, (1484788401, b"2" * 40)),
         test_tags[1]: (1484788314, b"1" * 40, (1484788401, b"2" * 40)),
     }
     }
 
 
+    # Class attributes set in setUpClass
+    projdir: ClassVar[str]
+    repo: ClassVar[Repo]
+    c1: ClassVar[Commit]
+    c2: ClassVar[Commit]
+    t1: ClassVar[bytes]
+    t2: ClassVar[Tag]
+
     @classmethod
     @classmethod
     def setUpClass(cls) -> None:
     def setUpClass(cls) -> None:
         cls.projdir = tempfile.mkdtemp()  # temporary project directory
         cls.projdir = tempfile.mkdtemp()  # temporary project directory
@@ -119,11 +129,14 @@ class GetRecentTagsTest(unittest.TestCase):
         )
         )
         obj_store.add_object(cls.c2)
         obj_store.add_object(cls.c2)
         # tag 2: annotated ('2017-01-19T01:13:21')
         # tag 2: annotated ('2017-01-19T01:13:21')
+        tag_data = cls.tag_test_data[cls.test_tags[1]][2]
+        if tag_data is None:
+            raise AssertionError("test_tags[1] should have annotated tag data")
         cls.t2 = make_tag(
         cls.t2 = make_tag(
             cls.c2,
             cls.c2,
-            id=cls.tag_test_data[cls.test_tags[1]][2][1],
+            id=tag_data[1],
             name=cls.test_tags[1],
             name=cls.test_tags[1],
-            tag_time=cls.tag_test_data[cls.test_tags[1]][2][0],
+            tag_time=tag_data[0],
         )
         )
         obj_store.add_object(cls.t2)
         obj_store.add_object(cls.t2)
         cls.repo[b"refs/heads/master"] = cls.c2.id
         cls.repo[b"refs/heads/master"] = cls.c2.id
@@ -138,8 +151,8 @@ class GetRecentTagsTest(unittest.TestCase):
         """Test get recent tags."""
         """Test get recent tags."""
         tags = release_robot.get_recent_tags(self.projdir)  # get test tags
         tags = release_robot.get_recent_tags(self.projdir)  # get test tags
         for tag, metadata in tags:
         for tag, metadata in tags:
-            tag = tag.encode("utf-8")
-            test_data = self.tag_test_data[tag]  # test data tag
+            tag_bytes = tag.encode("utf-8")
+            test_data = self.tag_test_data[tag_bytes]  # test data tag
             # test commit date, id and author name
             # test commit date, id and author name
             self.assertEqual(metadata[0], gmtime_to_datetime(test_data[0]))
             self.assertEqual(metadata[0], gmtime_to_datetime(test_data[0]))
             self.assertEqual(metadata[1].encode("utf-8"), test_data[1])
             self.assertEqual(metadata[1].encode("utf-8"), test_data[1])
@@ -151,7 +164,7 @@ class GetRecentTagsTest(unittest.TestCase):
             # tag date, id and name
             # tag date, id and name
             self.assertEqual(metadata[3][0], gmtime_to_datetime(tag_obj[0]))
             self.assertEqual(metadata[3][0], gmtime_to_datetime(tag_obj[0]))
             self.assertEqual(metadata[3][1].encode("utf-8"), tag_obj[1])
             self.assertEqual(metadata[3][1].encode("utf-8"), tag_obj[1])
-            self.assertEqual(metadata[3][2].encode("utf-8"), tag)
+            self.assertEqual(metadata[3][2], tag)
 
 
     def test_get_recent_tags_sorting(self) -> None:
     def test_get_recent_tags_sorting(self) -> None:
         """Test that tags are sorted by commit time from newest to oldest."""
         """Test that tags are sorted by commit time from newest to oldest."""

+ 5 - 2
tests/test_config.py

@@ -1102,8 +1102,11 @@ class CaseInsensitiveConfigTests(TestCase):
         self.assertEqual("value", config.get(("core",)))
         self.assertEqual("value", config.get(("core",)))
         self.assertEqual("value", config.get(("CORE",)))
         self.assertEqual("value", config.get(("CORE",)))
         self.assertEqual("default", config.get(("missing",), "default"))
         self.assertEqual("default", config.get(("missing",), "default"))
-        # Test SENTINEL behavior
-        result = config.get(("missing",))
+        # Test default_factory behavior
+        config_with_factory = CaseInsensitiveOrderedMultiDict(
+            default_factory=CaseInsensitiveOrderedMultiDict
+        )
+        result = config_with_factory.get(("missing",))
         self.assertIsInstance(result, CaseInsensitiveOrderedMultiDict)
         self.assertIsInstance(result, CaseInsensitiveOrderedMultiDict)
         self.assertEqual(0, len(result))
         self.assertEqual(0, len(result))