2
0
Эх сурвалжийг харах

Split out Config.set and Config.add (#1549)

Fixes #1545
Jelmer Vernooij 2 долоо хоног өмнө
parent
commit
eee9d0725a
3 өөрчлөгдсөн 194 нэмэгдсэн , 3 устгасан
  1. 4 1
      NEWS
  2. 28 0
      dulwich/config.py
  3. 162 2
      tests/test_config.py

+ 4 - 1
NEWS

@@ -7,7 +7,10 @@
 
  * Fix wheels workflow. (Jelmer Vernooij)
 
- * Bump PyO3 to 0.25. (Jelmer Vernooij)
+ * ``Config.set`` replaces values by default, ``Config.add``
+   appends them. (Jelmer Vernooij, #1545)
+
+* Bump PyO3 to 0.25. (Jelmer Vernooij)
 
 0.22.8	2025-03-02
 

+ 28 - 0
dulwich/config.py

@@ -96,6 +96,13 @@ class CaseInsensitiveOrderedMultiDict(MutableMapping):
         self._real.append((key, value))
         self._keyed[lower_key(key)] = value
 
+    def set(self, key, value) -> None:
+        # This method replaces all existing values for the key
+        lower = lower_key(key)
+        self._real = [(k, v) for k, v in self._real if lower_key(k) != lower]
+        self._real.append((key, value))
+        self._keyed[lower] = value
+
     def __delitem__(self, key) -> None:
         key = lower_key(key)
         del self._keyed[key]
@@ -340,6 +347,27 @@ class ConfigDict(Config, MutableMapping[Section, MutableMapping[Name, Value]]):
     ) -> None:
         section, name = self._check_section_and_name(section, name)
 
+        if isinstance(value, bool):
+            value = b"true" if value else b"false"
+
+        if not isinstance(value, bytes):
+            value = value.encode(self.encoding)
+
+        section_dict = self._values.setdefault(section)
+        if hasattr(section_dict, "set"):
+            section_dict.set(name, value)
+        else:
+            section_dict[name] = value
+
+    def add(
+        self,
+        section: SectionLike,
+        name: NameLike,
+        value: Union[ValueLike, bool],
+    ) -> None:
+        """Add a value to a configuration setting, creating a multivar if needed."""
+        section, name = self._check_section_and_name(section, name)
+
         if isinstance(value, bool):
             value = b"true" if value else b"false"
 

+ 162 - 2
tests/test_config.py

@@ -28,6 +28,7 @@ from unittest import skipIf
 from unittest.mock import patch
 
 from dulwich.config import (
+    CaseInsensitiveOrderedMultiDict,
     ConfigDict,
     ConfigFile,
     StackedConfig,
@@ -185,6 +186,14 @@ class ConfigFileTests(TestCase):
         c.write_to_file(f)
         self.assertEqual(b"[core]\n\tfoo = bar\n", f.getvalue())
 
+    def test_write_to_file_section_multiple(self) -> None:
+        c = ConfigFile()
+        c.set((b"core",), b"foo", b"old")
+        c.set((b"core",), b"foo", b"new")
+        f = BytesIO()
+        c.write_to_file(f)
+        self.assertEqual(b"[core]\n\tfoo = new\n", f.getvalue())
+
     def test_write_to_file_subsection(self) -> None:
         c = ConfigFile()
         c.set((b"branch", b"blie"), b"foo", b"bar")
@@ -306,6 +315,20 @@ class ConfigDictTests(TestCase):
 
         self.assertEqual([(b"core2",)], list(cd.sections()))
 
+    def test_set_vs_add(self) -> None:
+        cd = ConfigDict()
+        # Test add() creates multivars
+        cd.add((b"core",), b"foo", b"value1")
+        cd.add((b"core",), b"foo", b"value2")
+        self.assertEqual(
+            [b"value1", b"value2"], list(cd.get_multivar((b"core",), b"foo"))
+        )
+
+        # Test set() replaces values
+        cd.set((b"core",), b"foo", b"value3")
+        self.assertEqual([b"value3"], list(cd.get_multivar((b"core",), b"foo")))
+        self.assertEqual(b"value3", cd.get((b"core",), b"foo"))
+
 
 class StackedConfigTests(TestCase):
     def test_default_backends(self) -> None:
@@ -482,8 +505,8 @@ class ApplyInsteadOfTests(TestCase):
 
     def test_apply_multiple(self) -> None:
         config = ConfigDict()
-        config.set(("url", "https://samba.org/"), "insteadOf", "https://blah.com/")
-        config.set(("url", "https://samba.org/"), "insteadOf", "https://example.com/")
+        config.add(("url", "https://samba.org/"), "insteadOf", "https://blah.com/")
+        config.add(("url", "https://samba.org/"), "insteadOf", "https://example.com/")
         self.assertEqual(
             [b"https://blah.com/", b"https://example.com/"],
             list(config.get_multivar(("url", "https://samba.org/"), "insteadOf")),
@@ -491,3 +514,140 @@ class ApplyInsteadOfTests(TestCase):
         self.assertEqual(
             "https://samba.org/", apply_instead_of(config, "https://example.com/")
         )
+
+
+class CaseInsensitiveConfigTests(TestCase):
+    def test_case_insensitive(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("core",)] = "value"
+        self.assertEqual("value", config[("CORE",)])
+        self.assertEqual("value", config[("CoRe",)])
+        self.assertEqual([("core",)], list(config.keys()))
+
+    def test_multiple_set(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("core",)] = "value1"
+        config[("core",)] = "value2"
+        # The second set overwrites the first one
+        self.assertEqual("value2", config[("core",)])
+        self.assertEqual("value2", config[("CORE",)])
+
+    def test_get_all(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("core",)] = "value1"
+        config[("CORE",)] = "value2"
+        config[("CoRe",)] = "value3"
+        self.assertEqual(
+            ["value1", "value2", "value3"], list(config.get_all(("core",)))
+        )
+        self.assertEqual(
+            ["value1", "value2", "value3"], list(config.get_all(("CORE",)))
+        )
+
+    def test_delitem(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("core",)] = "value1"
+        config[("CORE",)] = "value2"
+        config[("other",)] = "value3"
+        del config[("core",)]
+        self.assertNotIn(("core",), config)
+        self.assertNotIn(("CORE",), config)
+        self.assertEqual("value3", config[("other",)])
+        self.assertEqual(1, len(config))
+
+    def test_len(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        self.assertEqual(0, len(config))
+        config[("core",)] = "value1"
+        self.assertEqual(1, len(config))
+        config[("CORE",)] = "value2"
+        self.assertEqual(1, len(config))  # Same key, case insensitive
+        config[("other",)] = "value3"
+        self.assertEqual(2, len(config))
+
+    def test_make_from_dict(self) -> None:
+        original = {("core",): "value1", ("other",): "value2"}
+        config = CaseInsensitiveOrderedMultiDict.make(original)
+        self.assertEqual("value1", config[("core",)])
+        self.assertEqual("value1", config[("CORE",)])
+        self.assertEqual("value2", config[("other",)])
+
+    def test_make_from_self(self) -> None:
+        config1 = CaseInsensitiveOrderedMultiDict()
+        config1[("core",)] = "value"
+        config2 = CaseInsensitiveOrderedMultiDict.make(config1)
+        self.assertIs(config1, config2)
+
+    def test_make_invalid_type(self) -> None:
+        self.assertRaises(TypeError, CaseInsensitiveOrderedMultiDict.make, "invalid")
+
+    def test_get_with_default(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("core",)] = "value"
+        self.assertEqual("value", config.get(("core",)))
+        self.assertEqual("value", config.get(("CORE",)))
+        self.assertEqual("default", config.get(("missing",), "default"))
+        # Test SENTINEL behavior
+        result = config.get(("missing",))
+        self.assertIsInstance(result, CaseInsensitiveOrderedMultiDict)
+        self.assertEqual(0, len(result))
+
+    def test_setdefault(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        # Set new value
+        result1 = config.setdefault(("core",), "value1")
+        self.assertEqual("value1", result1)
+        self.assertEqual("value1", config[("core",)])
+        # Try to set again with different case - should return existing
+        result2 = config.setdefault(("CORE",), "value2")
+        self.assertEqual("value1", result2)
+        self.assertEqual("value1", config[("core",)])
+
+    def test_values(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("core",)] = "value1"
+        config[("other",)] = "value2"
+        config[("CORE",)] = "value3"  # Overwrites previous core value
+        self.assertEqual({"value3", "value2"}, set(config.values()))
+
+    def test_items_iteration(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("core",)] = "value1"
+        config[("other",)] = "value2"
+        config[("CORE",)] = "value3"
+        items = list(config.items())
+        self.assertEqual(3, len(items))
+        self.assertEqual((("core",), "value1"), items[0])
+        self.assertEqual((("other",), "value2"), items[1])
+        self.assertEqual((("CORE",), "value3"), items[2])
+
+    def test_str_keys(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config["core"] = "value"
+        self.assertEqual("value", config["CORE"])
+        self.assertEqual("value", config["CoRe"])
+
+    def test_nested_tuple_keys(self) -> None:
+        config = CaseInsensitiveOrderedMultiDict()
+        config[("branch", "master")] = "value"
+        self.assertEqual("value", config[("BRANCH", "MASTER")])
+        self.assertEqual("value", config[("Branch", "Master")])
+
+
+class ConfigFileSetTests(TestCase):
+    def test_set_replaces_value(self) -> None:
+        # Test that set() replaces the value instead of appending
+        cf = ConfigFile()
+        cf.set((b"core",), b"sshCommand", b"ssh -i ~/.ssh/id_rsa1")
+        cf.set((b"core",), b"sshCommand", b"ssh -i ~/.ssh/id_rsa2")
+
+        # Should only have one value
+        self.assertEqual(b"ssh -i ~/.ssh/id_rsa2", cf.get((b"core",), b"sshCommand"))
+
+        # When written to file, should only have one entry
+        f = BytesIO()
+        cf.write_to_file(f)
+        content = f.getvalue()
+        self.assertEqual(1, content.count(b"sshCommand"))
+        self.assertIn(b"sshCommand = ssh -i ~/.ssh/id_rsa2", content)
+        self.assertNotIn(b"id_rsa1", content)