Sfoglia il codice sorgente

Some improvements to LFS config handling

Jelmer Vernooij 1 mese fa
parent
commit
e25b7e1d1c
4 ha cambiato i file con 99 aggiunte e 77 eliminazioni
  1. 2 1
      dulwich/filters.py
  2. 64 65
      dulwich/lfs.py
  3. 3 3
      dulwich/porcelain.py
  4. 30 8
      tests/test_porcelain_filters.py

+ 2 - 1
dulwich/filters.py

@@ -214,7 +214,8 @@ class FilterRegistry:
             lfs_dir = tempfile.mkdtemp(prefix="dulwich-lfs-")
             lfs_store = LFSStore.create(lfs_dir)
 
-        return LFSFilterDriver(lfs_store, repo=registry.repo)
+        config = registry.repo.get_config_stack() if registry.repo else None
+        return LFSFilterDriver(lfs_store, config=config)
 
     def _create_text_filter(self, registry: "FilterRegistry") -> FilterDriver:
         """Create text filter driver for line ending conversion.

+ 64 - 65
dulwich/lfs.py

@@ -27,7 +27,7 @@ import tempfile
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, BinaryIO, Optional, Union
-from urllib.parse import urljoin
+from urllib.parse import urljoin, urlparse
 from urllib.request import Request, urlopen
 
 if TYPE_CHECKING:
@@ -97,6 +97,13 @@ class LFSStore:
             return cls.create(lfs_dir)
         return cls(lfs_dir)
 
+    @classmethod
+    def from_controldir(cls, controldir: str, create: bool = False) -> "LFSStore":
+        lfs_dir = os.path.join(controldir, "lfs")
+        if create:
+            return cls.create(lfs_dir)
+        return cls(lfs_dir)
+
     def _sha_path(self, sha: str) -> str:
         return os.path.join(self.path, "objects", sha[0:2], sha[2:4], sha)
 
@@ -201,9 +208,11 @@ class LFSPointer:
 class LFSFilterDriver:
     """LFS filter driver implementation."""
 
-    def __init__(self, lfs_store: "LFSStore", repo: Optional["Repo"] = None) -> None:
+    def __init__(
+        self, lfs_store: "LFSStore", config: Optional["Config"] = None
+    ) -> None:
         self.lfs_store = lfs_store
-        self.repo = repo
+        self.config = config
 
     def clean(self, data: bytes) -> bytes:
         """Convert file content to LFS pointer (clean filter)."""
@@ -243,10 +252,9 @@ class LFSFilterDriver:
             except LFSError as e:
                 # Download failed, fall back to returning pointer
                 logging.warning("LFS object download failed for %s: %s", pointer.oid, e)
-                pass
 
-            # Return pointer as-is when object is missing and download failed
-            return data
+                # Return pointer as-is when object is missing and download failed
+                return data
 
     def _download_object(self, pointer: LFSPointer) -> bytes:
         """Download an LFS object from the server.
@@ -260,17 +268,13 @@ class LFSFilterDriver:
         Raises:
             LFSError: If download fails for any reason
         """
-        if self.repo is None:
-            raise LFSError("No repository available for LFS download")
-
-        # Get LFS server URL from remote
-        lfs_url = self._get_lfs_url()
-        if not lfs_url:
-            raise LFSError("No LFS server URL configured")
+        if self.config is None:
+            raise LFSError("No configuration available for LFS download")
 
         # Create LFS client and download
-        config = self.repo.get_config_stack() if self.repo else None
-        client = LFSClient(lfs_url, config=config)
+        client = LFSClient.from_config(self.config)
+        if client is None:
+            raise LFSError("No LFS client available from configuration")
         content = client.download(pointer.oid, pointer.size)
 
         # Store the downloaded content in local LFS store
@@ -278,25 +282,53 @@ class LFSFilterDriver:
 
         # Verify the stored OID matches what we expected
         if stored_oid != pointer.oid:
-            raise LFSError(f"Downloaded OID mismatch: expected {pointer.oid}, got {stored_oid}")
+            raise LFSError(
+                f"Downloaded OID mismatch: expected {pointer.oid}, got {stored_oid}"
+            )
 
         return content
 
-    def _get_lfs_url(self) -> Optional[str]:
-        """Get LFS server URL from repository configuration.
 
-        Returns:
-            LFS server URL or None if not configured
+def _get_lfs_user_agent(config):
+    """Get User-Agent string for LFS requests, respecting git config."""
+    try:
+        if config:
+            # Use configured user agent verbatim if set
+            return config.get(b"http", b"useragent").decode()
+    except KeyError:
+        pass
+
+    # Default LFS user agent (similar to git-lfs format)
+    from . import __version__
+
+    version_str = ".".join([str(x) for x in __version__])
+    return f"git-lfs/dulwich/{version_str}"
+
+
+class LFSClient:
+    """LFS client for network operations."""
+
+    def __init__(self, url: str, config: Optional["Config"] = None) -> None:
+        """Initialize LFS client.
+
+        Args:
+            url: LFS server URL
+            config: Optional git config for authentication/proxy settings
         """
-        if self.repo is None:
-            return None
+        self._base_url = url.rstrip("/") + "/"  # Ensure trailing slash for urljoin
+        self.config = config
+        self._pool_manager = None
 
+    @classmethod
+    def from_config(cls, config: "Config") -> Optional["LFSClient"]:
+        """Create LFS client from git config."""
         # Try to get LFS URL from config first
-        config = self.repo.get_config_stack()
         try:
-            return config.get((b"lfs",), b"url").decode()
+            url = config.get((b"lfs",), b"url").decode()
         except KeyError:
             pass
+        else:
+            return cls(url, config)
 
         # Fall back to deriving from remote URL (same as git-lfs)
         try:
@@ -319,47 +351,14 @@ class LFSFilterDriver:
 
             # Standard LFS endpoint is remote_url + "/info/lfs"
             lfs_url = f"{remote_url}/info/lfs"
-            
-            # Validate URL by parsing it
-            from urllib.parse import urlparse
+
             parsed = urlparse(lfs_url)
             if not parsed.scheme or not parsed.netloc:
                 return None
-                
-            return lfs_url
-
-        return None
-
-
-def _get_lfs_user_agent(config):
-    """Get User-Agent string for LFS requests, respecting git config."""
-    try:
-        if config:
-            # Use configured user agent verbatim if set
-            return config.get(b"http", b"useragent").decode()
-    except KeyError:
-        pass
-
-    # Default LFS user agent (similar to git-lfs format)
-    from . import __version__
-
-    version_str = ".".join([str(x) for x in __version__])
-    return f"git-lfs/dulwich/{version_str}"
-
 
-class LFSClient:
-    """LFS client for network operations."""
-
-    def __init__(self, url: str, config: Optional["Config"] = None) -> None:
-        """Initialize LFS client.
+            return LFSClient(lfs_url, config)
 
-        Args:
-            url: LFS server URL
-            config: Optional git config for authentication/proxy settings
-        """
-        self._base_url = url.rstrip("/") + "/"  # Ensure trailing slash for urljoin
-        self.config = config
-        self._pool_manager = None
+        return None
 
     @property
     def url(self) -> str:
@@ -369,11 +368,9 @@ class LFSClient:
     def _get_pool_manager(self):
         """Get urllib3 pool manager with git config applied."""
         if self._pool_manager is None:
-            # For now, use plain urllib3 since dulwich's version has issues with LFS
-            # TODO: Investigate why default_urllib3_manager breaks LFS requests
-            import urllib3
+            from dulwich.client import default_urllib3_manager
 
-            self._pool_manager = urllib3.PoolManager()
+            self._pool_manager = default_urllib3_manager(self.config)
         return self._pool_manager
 
     def _make_request(
@@ -397,7 +394,9 @@ class LFSClient:
         pool_manager = self._get_pool_manager()
         response = pool_manager.request(method, url, headers=req_headers, body=data)
         if response.status >= 400:
-            raise ValueError(f"HTTP {response.status}: {response.data.decode('utf-8', errors='ignore')}")
+            raise ValueError(
+                f"HTTP {response.status}: {response.data.decode('utf-8', errors='ignore')}"
+            )
         return response.data
 
     def batch(

+ 3 - 3
dulwich/porcelain.py

@@ -5068,7 +5068,7 @@ def lfs_clean(repo=".", path=None):
 
         # Get LFS store
         lfs_store = LFSStore.from_repo(r)
-        filter_driver = LFSFilterDriver(lfs_store, repo=r)
+        filter_driver = LFSFilterDriver(lfs_store, config=r.get_config())
 
         # Read file content
         full_path = os.path.join(r.path, path)
@@ -5097,7 +5097,7 @@ def lfs_smudge(repo=".", pointer_content=None):
 
         # Get LFS store
         lfs_store = LFSStore.from_repo(r)
-        filter_driver = LFSFilterDriver(lfs_store, repo=r)
+        filter_driver = LFSFilterDriver(lfs_store, config=r.get_config())
 
         # Smudge the pointer (retrieve actual content)
         return filter_driver.smudge(pointer_content)
@@ -5162,7 +5162,7 @@ def lfs_migrate(repo=".", include=None, exclude=None, everything=False):
     with open_repo_closing(repo) as r:
         # Initialize LFS if needed
         lfs_store = LFSStore.from_repo(r, create=True)
-        filter_driver = LFSFilterDriver(lfs_store, repo=r)
+        filter_driver = LFSFilterDriver(lfs_store, config=r.get_config())
 
         # Get current index
         index = r.open_index()

+ 30 - 8
tests/test_porcelain_filters.py

@@ -218,15 +218,23 @@ class PorcelainFilterTests(TestCase):
 
     def test_process_filter_priority(self) -> None:
         """Test that process filters take priority over built-in ones."""
-        # Create a custom filter script
-        filter_script = os.path.join(self.test_dir, "test-filter.sh")
-        with open(filter_script, "w") as f:
-            f.write("#!/bin/sh\necho 'FILTERED'")
-        os.chmod(filter_script, 0o755)
+        # Create a cross-platform filter command
+        import sys
+
+        if sys.platform == "win32":
+            # On Windows, use echo command directly
+            filter_cmd = "echo FILTERED"
+        else:
+            # On Unix, create a shell script
+            filter_script = os.path.join(self.test_dir, "test-filter.sh")
+            with open(filter_script, "w") as f:
+                f.write("#!/bin/sh\necho 'FILTERED'")
+            os.chmod(filter_script, 0o755)
+            filter_cmd = filter_script
 
         # Configure custom filter
         config = self.repo.get_config()
-        config.set((b"filter", b"test"), b"smudge", filter_script.encode())
+        config.set((b"filter", b"test"), b"smudge", filter_cmd.encode())
         config.write_to_path()
 
         # Create .gitattributes
@@ -247,13 +255,27 @@ class PorcelainFilterTests(TestCase):
 
         # Test smudge
         result = test_driver.smudge(b"original", b"test.txt")
-        self.assertEqual(result, b"FILTERED\n")
+        # Strip line endings to handle platform differences
+        self.assertEqual(result.rstrip(), b"FILTERED")
 
     def test_commit_with_clean_filter(self) -> None:
         """Test committing with a clean filter."""
         # Set up a custom filter in git config
         config = self.repo.get_config()
-        config.set((b"filter", b"testfilter"), b"clean", b"sed 's/SECRET/REDACTED/g'")
+        import sys
+
+        if sys.platform == "win32":
+            # On Windows, use PowerShell for string replacement
+            config.set(
+                (b"filter", b"testfilter"),
+                b"clean",
+                b"powershell -Command \"$input -replace 'SECRET', 'REDACTED'\"",
+            )
+        else:
+            # On Unix, use sed
+            config.set(
+                (b"filter", b"testfilter"), b"clean", b"sed 's/SECRET/REDACTED/g'"
+            )
         config.write_to_path()
 
         # Create .gitattributes to use the filter