Browse Source

Make config keys (sections and names) case-insensitive

This is done to match the behaviour of git itself.

The main change here is to add the lower_key function
and the CaseInsensitiveDict class.

lower_key coerces keys that are used in ConfigDict to
make sure all keys consist only of lower-cased strings
(tuple keys are lower-cased element by element).

CaseInsensitiveDict has altered __getitem__ and
__setitem__ methods to call lower_key on the key for
both get and set operations.

CaseInsensitiveDict.make takes care of the creation
logic found in ConfigDict.__init__ as well as converting
any passed-in dicts to CaseInsensitiveDict.

CaseInsensitiveDict.get and CaseInsensitiveDict.setdefault
fall back to a new, empty CaseInsensitiveDict when no default
is given, so callers do not have to pass one at every call site.

ConfigDict.check_section_and_name just gathers in one place
the section/name type checks that were previously repeated
in get and set.

Fixes #599
Alistair Broomhead 7 years ago
parent
commit
34b45ecb70
2 changed files with 86 additions and 12 deletions
  1. 80 11
      dulwich/config.py
  2. 6 1
      dulwich/tests/test_config.py

+ 80 - 11
dulwich/config.py

@@ -30,6 +30,7 @@ import errno
 import os
 
 from collections import (
+    Iterable,
     OrderedDict,
     MutableMapping,
     )
@@ -39,6 +40,69 @@ from dulwich.file import GitFile
 
 
 DEFAULT_ENCODING = 'utf-8'
+SENTINAL = object()
+
+
+def lower_key(key):
+    if isinstance(key, (bytes, str)):
+        return key.lower()
+
+    if isinstance(key, Iterable):
+        return type(key)(
+            map(lower_key, key)
+        )
+
+    return key
+
+
+class CaseInsensitiveDict(OrderedDict):
+    @classmethod
+    def make(cls, dict_in=None):
+
+        if isinstance(dict_in, cls):
+            return dict_in
+
+        out = cls()
+
+        if dict_in is None:
+            return out
+
+        if not isinstance(dict_in, MutableMapping):
+            raise TypeError
+
+        for key, value in dict_in.items():
+            out[key] = value
+
+        return out
+
+    def __setitem__(self, key, value, **kwargs):
+        key = lower_key(key)
+
+        super(CaseInsensitiveDict, self).__setitem__(key, value,  **kwargs)
+
+    def __getitem__(self, item):
+        key = lower_key(item)
+
+        return super(CaseInsensitiveDict, self).__getitem__(key)
+
+    def get(self, key, default=SENTINAL):
+        try:
+            return self[key]
+        except KeyError:
+            pass
+
+        if default is SENTINAL:
+            return type(self)()
+
+        return default
+
+    def setdefault(self, key, default=SENTINAL):
+        try:
+            return self[key]
+        except KeyError:
+            self[key] = self.get(key, default)
+
+        return self[key]
 
 
 class Config(object):
@@ -112,9 +176,7 @@ class ConfigDict(Config, MutableMapping):
 
     def __init__(self, values=None):
         """Create a new ConfigDict."""
-        if values is None:
-            values = OrderedDict()
-        self._values = values
+        self._values = CaseInsensitiveDict.make(values)
 
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self._values)
@@ -147,31 +209,38 @@ class ConfigDict(Config, MutableMapping):
         else:
             return (parts[0], None, parts[1])
 
-    def get(self, section, name):
+    @staticmethod
+    def check_section_and_name(section, name):
         if not isinstance(section, tuple):
             section = (section, )
         if not all([isinstance(subsection, bytes) for subsection in section]):
             raise TypeError(section)
         if not isinstance(name, bytes):
             raise TypeError(name)
+
+        return section
+
+    def get(self, section, name):
+        section = self.check_section_and_name(section, name)
+
         if len(section) > 1:
             try:
                 return self._values[section][name]
             except KeyError:
                 pass
+
         return self._values[(section[0],)][name]
 
     def set(self, section, name, value):
-        if not isinstance(section, tuple):
-            section = (section, )
-        if not isinstance(name, bytes):
-            raise TypeError(name)
+        section = self.check_section_and_name(section, name)
+
         if type(value) not in (bool, bytes):
             raise TypeError(value)
-        self._values.setdefault(section, OrderedDict())[name] = value
+
+        self._values.setdefault(section)[name] = value
 
     def iteritems(self, section):
-        return self._values.get(section, OrderedDict()).items()
+        return self._values.get(section).items()
 
     def itersections(self):
         return self._values.keys()
@@ -324,7 +393,7 @@ class ConfigFile(ConfigDict):
                             section = (pts[0], pts[1])
                         else:
                             section = (pts[0], )
-                    ret._values[section] = OrderedDict()
+                    ret._values.setdefault(section)
                 if _strip_comments(line).strip() == b"":
                     continue
                 if section is None:

+ 6 - 1
dulwich/tests/test_config.py

@@ -96,11 +96,16 @@ class ConfigFileTests(TestCase):
         self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
-    def test_from_file_section_case_insensitive(self):
+    def test_from_file_section_case_insensitive_lower(self):
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
         self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
+    def test_from_file_section_case_insensitive_mixed(self):
+        cf = self.from_file(b"[cOre]\nfOo = bar\n")
+        self.assertEqual(b"bar", cf.get((b"core", ), b"fOo"))
+        self.assertEqual(b"bar", cf.get((b"cOre", b"fOo"), b"fOo"))
+
     def test_from_file_with_mixed_quoted(self):
         cf = self.from_file(b"[core]\nfoo = \"bar\"la\n")
         self.assertEqual(b"barla", cf.get((b"core", ), b"foo"))