|
@@ -29,20 +29,32 @@ import logging
|
|
|
import os
|
|
|
import re
|
|
|
import sys
|
|
|
-from collections.abc import Iterable, Iterator, KeysView, MutableMapping
|
|
|
+from collections.abc import (
|
|
|
+ ItemsView,
|
|
|
+ Iterable,
|
|
|
+ Iterator,
|
|
|
+ KeysView,
|
|
|
+ MutableMapping,
|
|
|
+ ValuesView,
|
|
|
+)
|
|
|
from contextlib import suppress
|
|
|
from pathlib import Path
|
|
|
from typing import (
|
|
|
Any,
|
|
|
BinaryIO,
|
|
|
Callable,
|
|
|
+ Generic,
|
|
|
Optional,
|
|
|
+ TypeVar,
|
|
|
Union,
|
|
|
overload,
|
|
|
)
|
|
|
|
|
|
from .file import GitFile
|
|
|
|
|
|
+ConfigKey = Union[str, bytes, tuple[Union[str, bytes], ...]]
|
|
|
+ConfigValue = Union[str, bytes, bool, int]
|
|
|
+
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Type for file opener callback
|
|
@@ -56,8 +68,6 @@ ConditionMatcher = Callable[[str], bool]
|
|
|
MAX_INCLUDE_FILE_SIZE = 1024 * 1024 # 1MB max for included config files
|
|
|
DEFAULT_MAX_INCLUDE_DEPTH = 10 # Maximum recursion depth for includes
|
|
|
|
|
|
-SENTINEL = object()
|
|
|
-
|
|
|
|
|
|
def _match_gitdir_pattern(
|
|
|
path: bytes, pattern: bytes, ignorecase: bool = False
|
|
@@ -136,31 +146,41 @@ def match_glob_pattern(value: str, pattern: str) -> bool:
|
|
|
raise ValueError(f"Invalid glob pattern {pattern!r}: {e}")
|
|
|
|
|
|
|
|
|
-def lower_key(key):
|
|
|
+def lower_key(key: ConfigKey) -> ConfigKey:
|
|
|
if isinstance(key, (bytes, str)):
|
|
|
return key.lower()
|
|
|
|
|
|
- if isinstance(key, Iterable):
|
|
|
+ if isinstance(key, tuple):
|
|
|
# For config sections, only lowercase the section name (first element)
|
|
|
# but preserve the case of subsection names (remaining elements)
|
|
|
if len(key) > 0:
|
|
|
- return (key[0].lower(), *key[1:])
|
|
|
+ first = key[0]
|
|
|
+ assert isinstance(first, (bytes, str))
|
|
|
+ return (first.lower(), *key[1:])
|
|
|
return key
|
|
|
|
|
|
- return key
|
|
|
+ raise TypeError(key)
|
|
|
|
|
|
|
|
|
-class CaseInsensitiveOrderedMultiDict(MutableMapping):
|
|
|
- def __init__(self) -> None:
|
|
|
- self._real: list[Any] = []
|
|
|
- self._keyed: dict[Any, Any] = {}
|
|
|
+K = TypeVar("K", bound=ConfigKey) # Key type must be ConfigKey
|
|
|
+V = TypeVar("V") # Value type
|
|
|
+_T = TypeVar("_T") # For get() default parameter
|
|
|
+
|
|
|
+
|
|
|
+class CaseInsensitiveOrderedMultiDict(MutableMapping[K, V], Generic[K, V]):
|
|
|
+ def __init__(self, default_factory: Optional[Callable[[], V]] = None) -> None:
|
|
|
+ self._real: list[tuple[K, V]] = []
|
|
|
+ self._keyed: dict[Any, V] = {}
|
|
|
+ self._default_factory = default_factory
|
|
|
|
|
|
@classmethod
|
|
|
- def make(cls, dict_in=None):
|
|
|
+ def make(
|
|
|
+ cls, dict_in=None, default_factory=None
|
|
|
+ ) -> "CaseInsensitiveOrderedMultiDict[K, V]":
|
|
|
if isinstance(dict_in, cls):
|
|
|
return dict_in
|
|
|
|
|
|
- out = cls()
|
|
|
+ out = cls(default_factory=default_factory)
|
|
|
|
|
|
if dict_in is None:
|
|
|
return out
|
|
@@ -176,16 +196,33 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping):
|
|
|
def __len__(self) -> int:
|
|
|
return len(self._keyed)
|
|
|
|
|
|
- def keys(self) -> KeysView[tuple[bytes, ...]]:
|
|
|
- return self._keyed.keys()
|
|
|
+ def keys(self) -> KeysView[K]:
|
|
|
+ return self._keyed.keys() # type: ignore[return-value]
|
|
|
+
|
|
|
+ def items(self) -> ItemsView[K, V]:
|
|
|
+ # Return a view that iterates over the real list to preserve order
|
|
|
+ class OrderedItemsView(ItemsView[K, V]):
|
|
|
+ def __init__(self, mapping: CaseInsensitiveOrderedMultiDict[K, V]):
|
|
|
+ self._mapping = mapping
|
|
|
|
|
|
- def items(self):
|
|
|
- return iter(self._real)
|
|
|
+ def __iter__(self) -> Iterator[tuple[K, V]]:
|
|
|
+ return iter(self._mapping._real)
|
|
|
|
|
|
- def __iter__(self):
|
|
|
- return self._keyed.__iter__()
|
|
|
+ def __len__(self) -> int:
|
|
|
+ return len(self._mapping._real)
|
|
|
|
|
|
- def values(self):
|
|
|
+ def __contains__(self, item: object) -> bool:
|
|
|
+ if not isinstance(item, tuple) or len(item) != 2:
|
|
|
+ return False
|
|
|
+ key, value = item
|
|
|
+ return any(k == key and v == value for k, v in self._mapping._real)
|
|
|
+
|
|
|
+ return OrderedItemsView(self)
|
|
|
+
|
|
|
+ def __iter__(self) -> Iterator[K]:
|
|
|
+ return iter(self._keyed)
|
|
|
+
|
|
|
+ def values(self) -> ValuesView[V]:
|
|
|
return self._keyed.values()
|
|
|
|
|
|
def __setitem__(self, key, value) -> None:
|
|
@@ -206,33 +243,39 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping):
|
|
|
if lower_key(actual) == key:
|
|
|
del self._real[i]
|
|
|
|
|
|
- def __getitem__(self, item):
|
|
|
+ def __getitem__(self, item: K) -> V:
|
|
|
return self._keyed[lower_key(item)]
|
|
|
|
|
|
- def get(self, key, default=SENTINEL):
|
|
|
+ def get(self, key: K, /, default: Union[V, _T, None] = None) -> Union[V, _T, None]: # type: ignore[override]
|
|
|
try:
|
|
|
return self[key]
|
|
|
except KeyError:
|
|
|
- pass
|
|
|
-
|
|
|
- if default is SENTINEL:
|
|
|
- return type(self)()
|
|
|
-
|
|
|
- return default
|
|
|
+ if default is not None:
|
|
|
+ return default
|
|
|
+ elif self._default_factory is not None:
|
|
|
+ return self._default_factory()
|
|
|
+ else:
|
|
|
+ return None
|
|
|
|
|
|
- def get_all(self, key):
|
|
|
- key = lower_key(key)
|
|
|
+ def get_all(self, key: K) -> Iterator[V]:
|
|
|
+ lowered_key = lower_key(key)
|
|
|
for actual, value in self._real:
|
|
|
- if lower_key(actual) == key:
|
|
|
+ if lower_key(actual) == lowered_key:
|
|
|
yield value
|
|
|
|
|
|
- def setdefault(self, key, default=SENTINEL):
|
|
|
+ def setdefault(self, key: K, default: Optional[V] = None) -> V:
|
|
|
try:
|
|
|
return self[key]
|
|
|
except KeyError:
|
|
|
- self[key] = self.get(key, default)
|
|
|
-
|
|
|
- return self[key]
|
|
|
+ if default is not None:
|
|
|
+ self[key] = default
|
|
|
+ return default
|
|
|
+ elif self._default_factory is not None:
|
|
|
+ value = self._default_factory()
|
|
|
+ self[key] = value
|
|
|
+ return value
|
|
|
+ else:
|
|
|
+ raise
|
|
|
|
|
|
|
|
|
Name = bytes
|
|
@@ -344,7 +387,7 @@ class Config:
|
|
|
return name in self.sections()
|
|
|
|
|
|
|
|
|
-class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
|
|
|
+class ConfigDict(Config):
|
|
|
"""Git configuration stored in a dictionary."""
|
|
|
|
|
|
def __init__(
|
|
@@ -358,7 +401,11 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
|
|
|
if encoding is None:
|
|
|
encoding = sys.getdefaultencoding()
|
|
|
self.encoding = encoding
|
|
|
- self._values = CaseInsensitiveOrderedMultiDict.make(values)
|
|
|
+ self._values: CaseInsensitiveOrderedMultiDict[
|
|
|
+ Section, CaseInsensitiveOrderedMultiDict[Name, Value]
|
|
|
+ ] = CaseInsensitiveOrderedMultiDict.make(
|
|
|
+ values, default_factory=CaseInsensitiveOrderedMultiDict
|
|
|
+ )
|
|
|
|
|
|
def __repr__(self) -> str:
|
|
|
return f"{self.__class__.__name__}({self._values!r})"
|
|
@@ -366,7 +413,7 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
|
|
|
def __eq__(self, other: object) -> bool:
|
|
|
return isinstance(other, self.__class__) and other._values == self._values
|
|
|
|
|
|
- def __getitem__(self, key: Section) -> MutableMapping[Name, Value]:
|
|
|
+ def __getitem__(self, key: Section) -> CaseInsensitiveOrderedMultiDict[Name, Value]:
|
|
|
return self._values.__getitem__(key)
|
|
|
|
|
|
def __setitem__(self, key: Section, value: MutableMapping[Name, Value]) -> None:
|
|
@@ -381,8 +428,11 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
|
|
|
def __len__(self) -> int:
|
|
|
return self._values.__len__()
|
|
|
|
|
|
+ def keys(self) -> KeysView[Section]:
|
|
|
+ return self._values.keys()
|
|
|
+
|
|
|
@classmethod
|
|
|
- def _parse_setting(cls, name: str):
|
|
|
+ def _parse_setting(cls, name: str) -> tuple[str, Optional[str], str]:
|
|
|
parts = name.split(".")
|
|
|
if len(parts) == 3:
|
|
|
return (parts[0], parts[1], parts[2])
|
|
@@ -420,7 +470,7 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
|
|
|
|
|
|
return self._values[(section[0],)].get_all(name)
|
|
|
|
|
|
- def get( # type: ignore[override]
|
|
|
+ def get(
|
|
|
self,
|
|
|
section: SectionLike,
|
|
|
name: NameLike,
|
|
@@ -472,13 +522,15 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
|
|
|
|
|
|
self._values.setdefault(section)[name] = value
|
|
|
|
|
|
- def items( # type: ignore[override]
|
|
|
- self, section: Section
|
|
|
- ) -> Iterator[tuple[Name, Value]]:
|
|
|
- return self._values.get(section).items()
|
|
|
+ def items(self, section: SectionLike) -> Iterator[tuple[Name, Value]]:
|
|
|
+ section_bytes, _ = self._check_section_and_name(section, b"")
|
|
|
+ section_dict = self._values.get(section_bytes)
|
|
|
+ if section_dict is not None:
|
|
|
+ return iter(section_dict.items())
|
|
|
+ return iter([])
|
|
|
|
|
|
def sections(self) -> Iterator[Section]:
|
|
|
- return self._values.keys()
|
|
|
+ return iter(self._values.keys())
|
|
|
|
|
|
|
|
|
def _format_string(value: bytes) -> bytes:
|
|
@@ -781,6 +833,7 @@ class ConfigFile(ConfigDict):
|
|
|
else:
|
|
|
continuation += line
|
|
|
value = _parse_string(continuation)
|
|
|
+ assert section is not None # Already checked above
|
|
|
ret._values[section][setting] = value
|
|
|
|
|
|
# Process include/includeIf directives
|
|
@@ -1076,7 +1129,7 @@ class ConfigFile(ConfigDict):
|
|
|
f.write(b"\t" + key + b" = " + value + b"\n")
|
|
|
|
|
|
|
|
|
-def get_xdg_config_home_path(*path_segments):
|
|
|
+def get_xdg_config_home_path(*path_segments: str) -> str:
|
|
|
xdg_config_home = os.environ.get(
|
|
|
"XDG_CONFIG_HOME",
|
|
|
os.path.expanduser("~/.config/"),
|
|
@@ -1084,7 +1137,7 @@ def get_xdg_config_home_path(*path_segments):
|
|
|
return os.path.join(xdg_config_home, *path_segments)
|
|
|
|
|
|
|
|
|
-def _find_git_in_win_path():
|
|
|
+def _find_git_in_win_path() -> Iterator[str]:
|
|
|
for exe in ("git.exe", "git.cmd"):
|
|
|
for path in os.environ.get("PATH", "").split(";"):
|
|
|
if os.path.exists(os.path.join(path, exe)):
|
|
@@ -1100,7 +1153,7 @@ def _find_git_in_win_path():
|
|
|
break
|
|
|
|
|
|
|
|
|
-def _find_git_in_win_reg():
|
|
|
+def _find_git_in_win_reg() -> Iterator[str]:
|
|
|
import platform
|
|
|
import winreg
|
|
|
|
|
@@ -1126,7 +1179,7 @@ def _find_git_in_win_reg():
|
|
|
# - %PROGRAMFILES%/Git/etc/gitconfig - Git for Windows (msysgit) config dir
|
|
|
# Used if CGit installation (Git/bin/git.exe) is found in PATH in the
|
|
|
# system registry
|
|
|
-def get_win_system_paths():
|
|
|
+def get_win_system_paths() -> Iterator[str]:
|
|
|
if "PROGRAMDATA" in os.environ:
|
|
|
yield os.path.join(os.environ["PROGRAMDATA"], "Git", "config")
|
|
|
|
|
@@ -1228,7 +1281,7 @@ def parse_submodules(config: ConfigFile) -> Iterator[tuple[bytes, bytes, bytes]]
|
|
|
list of tuples (submodule path, url, name),
|
|
|
where name is quoted part of the section's name.
|
|
|
"""
|
|
|
- for section in config.keys():
|
|
|
+ for section in config.sections():
|
|
|
section_kind, section_name = section
|
|
|
if section_kind == b"submodule":
|
|
|
try:
|