Преглед на файлове

Add get_signature_vendor() factory function

Jelmer Vernooij преди 4 седмици
родител
ревизия
a510467b13
променени са 2 файла, в които са добавени 109 реда и са изтрити 1 реда
  1. 48 0
      dulwich/signature.py
  2. 61 1
      tests/test_signature.py

+ 48 - 0
dulwich/signature.py

@@ -325,5 +325,53 @@ class GPGCliSignatureVendor(SignatureVendor):
                 )
 
 
+def get_signature_vendor(
+    format: str | None = None, config: "Config | None" = None
+) -> SignatureVendor:
+    """Get a signature vendor for the specified format.
+
+    Args:
+      format: Signature format. If None, reads from config's gpg.format setting.
+              Supported values:
+              - "openpgp": Use OpenPGP/GPG signatures (default)
+              - "x509": Use X.509 signatures (not yet implemented)
+              - "ssh": Use SSH signatures (not yet implemented)
+      config: Optional Git configuration
+
+    Returns:
+      SignatureVendor instance for the requested format
+
+    Raises:
+      ValueError: if the format is not supported
+    """
+    # Determine format from config if not specified
+    if format is None:
+        if config is not None:
+            try:
+                format_bytes = config.get((b"gpg",), b"format")
+                format = format_bytes.decode("utf-8") if format_bytes else "openpgp"
+            except KeyError:
+                format = "openpgp"
+        else:
+            format = "openpgp"
+
+    format_lower = format.lower()
+
+    if format_lower == "openpgp":
+        # Try to use GPG package vendor first, fall back to CLI
+        try:
+            import gpg  # noqa: F401
+
+            return GPGSignatureVendor(config=config)
+        except ImportError:
+            return GPGCliSignatureVendor(config=config)
+    elif format_lower == "x509":
+        raise ValueError("X.509 signatures are not yet supported")
+    elif format_lower == "ssh":
+        raise ValueError("SSH signatures are not yet supported")
+    else:
+        raise ValueError(f"Unsupported signature format: {format}")
+
+
 # Default GPG vendor instance
 gpg_vendor = GPGSignatureVendor()

+ 61 - 1
tests/test_signature.py

@@ -26,7 +26,12 @@ import subprocess
 import unittest
 
 from dulwich.config import ConfigDict
-from dulwich.signature import GPGCliSignatureVendor, GPGSignatureVendor, SignatureVendor
+from dulwich.signature import (
+    GPGCliSignatureVendor,
+    GPGSignatureVendor,
+    SignatureVendor,
+    get_signature_vendor,
+)
 
 try:
     import gpg
@@ -268,3 +273,58 @@ class GPGCliSignatureVendorTests(unittest.TestCase):
         config = ConfigDict()
         vendor = GPGCliSignatureVendor(config=config)
         self.assertEqual(vendor.gpg_command, "gpg")
+
+
+class GetSignatureVendorTests(unittest.TestCase):
+    """Tests for get_signature_vendor function."""
+
+    def test_default_format(self) -> None:
+        """Test that default format is openpgp."""
+        vendor = get_signature_vendor()
+        self.assertIsInstance(vendor, (GPGSignatureVendor, GPGCliSignatureVendor))
+
+    def test_explicit_openpgp_format(self) -> None:
+        """Test explicitly requesting openpgp format."""
+        vendor = get_signature_vendor(format="openpgp")
+        self.assertIsInstance(vendor, (GPGSignatureVendor, GPGCliSignatureVendor))
+
+    def test_format_from_config(self) -> None:
+        """Test reading format from config."""
+        config = ConfigDict()
+        config.set((b"gpg",), b"format", b"openpgp")
+
+        vendor = get_signature_vendor(config=config)
+        self.assertIsInstance(vendor, (GPGSignatureVendor, GPGCliSignatureVendor))
+
+    def test_format_case_insensitive(self) -> None:
+        """Test that format is case-insensitive."""
+        vendor = get_signature_vendor(format="OpenPGP")
+        self.assertIsInstance(vendor, (GPGSignatureVendor, GPGCliSignatureVendor))
+
+    def test_x509_not_supported(self) -> None:
+        """Test that x509 format raises ValueError."""
+        with self.assertRaises(ValueError) as cm:
+            get_signature_vendor(format="x509")
+        self.assertIn("X.509", str(cm.exception))
+
+    def test_ssh_not_supported(self) -> None:
+        """Test that ssh format raises ValueError."""
+        with self.assertRaises(ValueError) as cm:
+            get_signature_vendor(format="ssh")
+        self.assertIn("SSH", str(cm.exception))
+
+    def test_invalid_format(self) -> None:
+        """Test that invalid format raises ValueError."""
+        with self.assertRaises(ValueError) as cm:
+            get_signature_vendor(format="invalid")
+        self.assertIn("Unsupported", str(cm.exception))
+
+    def test_config_passed_to_vendor(self) -> None:
+        """Test that config is passed to the vendor."""
+        config = ConfigDict()
+        config.set((b"gpg",), b"program", b"gpg2")
+
+        vendor = get_signature_vendor(format="openpgp", config=config)
+        # If CLI vendor is used, check that config was passed
+        if isinstance(vendor, GPGCliSignatureVendor):
+            self.assertEqual(vendor.gpg_command, "gpg2")