Sfoglia il codice sorgente

Make config keys (sections and names) case insensitive

This is done to match the behaviour of git itself.

The main change here is to add the lower_key function
and the CaseInsensitiveDict class.

lower_key coerces keys that are used in DictConfig to
make sure all keys consist only of lower-cased strings.

CaseInsensitiveDict has altered __getitem__ and
__setitem__ methods to call lower_key on the key for
both get and set operations.

CaseInsensitiveDict.make takes care of the creation
logic found in ConfigDict.__init__ as well as converting
any passed-in dicts to CaseInsensitiveDict.

CaseInsensitiveDict.get and CaseInsensitiveDict.setdefault
allow the default value returned to be a CaseInsensitiveDict
without repeating it everywhere.

ConfigDict.check_section_and_name just moves some logic
that was previously repeated.

Fixes #599
Alistair Broomhead 7 anni fa
parent
commit
34b45ecb70
2 ha cambiato i file con 86 aggiunte e 12 eliminazioni
  1. 80 11
      dulwich/config.py
  2. 6 1
      dulwich/tests/test_config.py

+ 80 - 11
dulwich/config.py

@@ -30,6 +30,7 @@ import errno
 import os
 
 from collections import (
+    Iterable,
     OrderedDict,
     MutableMapping,
     )
@@ -39,6 +40,69 @@ from dulwich.file import GitFile
 
 
 DEFAULT_ENCODING = 'utf-8'
+SENTINAL = object()
+
+
+def lower_key(key):
+    if isinstance(key, (bytes, str)):
+        return key.lower()
+
+    if isinstance(key, Iterable):
+        return type(key)(
+            map(lower_key, key)
+        )
+
+    return key
+
+
+class CaseInsensitiveDict(OrderedDict):
+    @classmethod
+    def make(cls, dict_in=None):
+
+        if isinstance(dict_in, cls):
+            return dict_in
+
+        out = cls()
+
+        if dict_in is None:
+            return out
+
+        if not isinstance(dict_in, MutableMapping):
+            raise TypeError
+
+        for key, value in dict_in.items():
+            out[key] = value
+
+        return out
+
+    def __setitem__(self, key, value, **kwargs):
+        key = lower_key(key)
+
+        super(CaseInsensitiveDict, self).__setitem__(key, value,  **kwargs)
+
+    def __getitem__(self, item):
+        key = lower_key(item)
+
+        return super(CaseInsensitiveDict, self).__getitem__(key)
+
+    def get(self, key, default=SENTINAL):
+        try:
+            return self[key]
+        except KeyError:
+            pass
+
+        if default is SENTINAL:
+            return type(self)()
+
+        return default
+
+    def setdefault(self, key, default=SENTINAL):
+        try:
+            return self[key]
+        except KeyError:
+            self[key] = self.get(key, default)
+
+        return self[key]
 
 
 class Config(object):
@@ -112,9 +176,7 @@ class ConfigDict(Config, MutableMapping):
 
     def __init__(self, values=None):
         """Create a new ConfigDict."""
-        if values is None:
-            values = OrderedDict()
-        self._values = values
+        self._values = CaseInsensitiveDict.make(values)
 
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self._values)
@@ -147,31 +209,38 @@ class ConfigDict(Config, MutableMapping):
         else:
             return (parts[0], None, parts[1])
 
-    def get(self, section, name):
+    @staticmethod
+    def check_section_and_name(section, name):
         if not isinstance(section, tuple):
             section = (section, )
         if not all([isinstance(subsection, bytes) for subsection in section]):
             raise TypeError(section)
         if not isinstance(name, bytes):
             raise TypeError(name)
+
+        return section
+
+    def get(self, section, name):
+        section = self.check_section_and_name(section, name)
+
         if len(section) > 1:
             try:
                 return self._values[section][name]
             except KeyError:
                 pass
+
         return self._values[(section[0],)][name]
 
     def set(self, section, name, value):
-        if not isinstance(section, tuple):
-            section = (section, )
-        if not isinstance(name, bytes):
-            raise TypeError(name)
+        section = self.check_section_and_name(section, name)
+
         if type(value) not in (bool, bytes):
             raise TypeError(value)
-        self._values.setdefault(section, OrderedDict())[name] = value
+
+        self._values.setdefault(section)[name] = value
 
     def iteritems(self, section):
-        return self._values.get(section, OrderedDict()).items()
+        return self._values.get(section).items()
 
     def itersections(self):
         return self._values.keys()
@@ -324,7 +393,7 @@ class ConfigFile(ConfigDict):
                             section = (pts[0], pts[1])
                         else:
                             section = (pts[0], )
-                    ret._values[section] = OrderedDict()
+                    ret._values.setdefault(section)
                 if _strip_comments(line).strip() == b"":
                     continue
                 if section is None:

+ 6 - 1
dulwich/tests/test_config.py

@@ -96,11 +96,16 @@ class ConfigFileTests(TestCase):
         self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
-    def test_from_file_section_case_insensitive(self):
+    def test_from_file_section_case_insensitive_lower(self):
         cf = self.from_file(b"[cOre]\nfOo = bar\n")
         self.assertEqual(b"bar", cf.get((b"core", ), b"foo"))
         self.assertEqual(b"bar", cf.get((b"core", b"foo"), b"foo"))
 
+    def test_from_file_section_case_insensitive_mixed(self):
+        cf = self.from_file(b"[cOre]\nfOo = bar\n")
+        self.assertEqual(b"bar", cf.get((b"core", ), b"fOo"))
+        self.assertEqual(b"bar", cf.get((b"cOre", b"fOo"), b"fOo"))
+
     def test_from_file_with_mixed_quoted(self):
         cf = self.from_file(b"[core]\nfoo = \"bar\"la\n")
         self.assertEqual(b"barla", cf.get((b"core", ), b"foo"))