Browse Source

Update filter interface to pass path parameter to smudge methods

Add path parameter to smudge method signatures across filter drivers to
enable path-aware filtering. This allows filters to make decisions based
on file paths when converting content during checkout.
Jelmer Vernooij cách đây 1 tháng
mục cha
commit
1ddc2c87f0
4 tập tin đã thay đổi với 174 bổ sung và 47 xóa
  1. 22 13
      dulwich/filters.py
  2. 150 32
      dulwich/lfs.py
  3. 1 1
      dulwich/line_ending.py
  4. 1 1
      tests/test_sparse_patterns.py

+ 22 - 13
dulwich/filters.py

@@ -43,7 +43,7 @@ class FilterDriver(Protocol):
         """Apply clean filter (working tree → repository)."""
         ...
 
-    def smudge(self, data: bytes) -> bytes:
+    def smudge(self, data: bytes, path: bytes = b"") -> bytes:
         """Apply smudge filter (repository → working tree)."""
         ...
 
@@ -56,10 +56,12 @@ class ProcessFilterDriver:
         clean_cmd: Optional[str] = None,
         smudge_cmd: Optional[str] = None,
         required: bool = False,
+        cwd: Optional[str] = None,
     ) -> None:
         self.clean_cmd = clean_cmd
         self.smudge_cmd = smudge_cmd
         self.required = required
+        self.cwd = cwd
 
     def clean(self, data: bytes) -> bytes:
         """Apply clean filter using external process."""
@@ -75,6 +77,7 @@ class ProcessFilterDriver:
                 input=data,
                 capture_output=True,
                 check=True,
+                cwd=self.cwd,
             )
             return result.stdout
         except subprocess.CalledProcessError as e:
@@ -84,20 +87,24 @@ class ProcessFilterDriver:
             logging.warning(f"Optional clean filter failed: {e}")
             return data
 
-    def smudge(self, data: bytes) -> bytes:
+    def smudge(self, data: bytes, path: bytes = b"") -> bytes:
         """Apply smudge filter using external process."""
         if not self.smudge_cmd:
             if self.required:
                 raise FilterError("Smudge command is required but not configured")
             return data
 
+        # Substitute %f placeholder with file path
+        cmd = self.smudge_cmd.replace("%f", path.decode("utf-8", errors="replace"))
+
         try:
             result = subprocess.run(
-                self.smudge_cmd,
+                cmd,
                 shell=True,
                 input=data,
                 capture_output=True,
                 check=True,
+                cwd=self.cwd,
             )
             return result.stdout
         except subprocess.CalledProcessError as e:
@@ -140,19 +147,19 @@ class FilterRegistry:
         if name in self._drivers:
             return self._drivers[name]
 
-        # Try to create from factory
-        if name in self._factories:
-            factory_driver = self._factories[name](self)
-            self._drivers[name] = factory_driver
-            return factory_driver
-
-        # Try to create from config
+        # Try to create from config first (respect user configuration)
         if self.config is not None:
             config_driver = self._create_from_config(name)
             if config_driver is not None:
                 self._drivers[name] = config_driver
                 return config_driver
 
+        # Try to create from factory as fallback
+        if name in self._factories:
+            factory_driver = self._factories[name](self)
+            self._drivers[name] = factory_driver
+            return factory_driver
+
         return None
 
     def _create_from_config(self, name: str) -> Optional[FilterDriver]:
@@ -187,7 +194,9 @@ class FilterRegistry:
         required = self.config.get_boolean(("filter", name), "required", False)
 
         if clean_cmd or smudge_cmd:
-            return ProcessFilterDriver(clean_cmd, smudge_cmd, required)
+            # Get repository working directory
+            repo_path = self.repo.path if self.repo else None
+            return ProcessFilterDriver(clean_cmd, smudge_cmd, required, repo_path)
 
         return None
 
@@ -205,7 +214,7 @@ class FilterRegistry:
             lfs_dir = tempfile.mkdtemp(prefix="dulwich-lfs-")
             lfs_store = LFSStore.create(lfs_dir)
 
-        return LFSFilterDriver(lfs_store)
+        return LFSFilterDriver(lfs_store, repo=registry.repo)
 
     def _create_text_filter(self, registry: "FilterRegistry") -> FilterDriver:
         """Create text filter driver for line ending conversion.
@@ -397,7 +406,7 @@ class FilterBlobNormalizer:
             return blob
 
         # Apply smudge filter
-        filtered_data = filter_driver.smudge(blob.data)
+        filtered_data = filter_driver.smudge(blob.data, path)
         if filtered_data == blob.data:
             return blob
 

+ 150 - 32
dulwich/lfs.py

@@ -21,16 +21,17 @@
 
 import hashlib
 import json
+import logging
 import os
 import tempfile
 from collections.abc import Iterable
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, BinaryIO, Optional, Union
-from urllib.error import HTTPError
 from urllib.parse import urljoin
 from urllib.request import Request, urlopen
 
 if TYPE_CHECKING:
+    from .config import Config
     from .repo import Repo
 
 
@@ -200,8 +201,9 @@ class LFSPointer:
 class LFSFilterDriver:
     """LFS filter driver implementation."""
 
-    def __init__(self, lfs_store: "LFSStore") -> None:
+    def __init__(self, lfs_store: "LFSStore", repo: Optional["Repo"] = None) -> None:
         self.lfs_store = lfs_store
+        self.repo = repo
 
     def clean(self, data: bytes) -> bytes:
         """Convert file content to LFS pointer (clean filter)."""
@@ -217,7 +219,7 @@ class LFSFilterDriver:
         pointer = LFSPointer(sha, len(data))
         return pointer.to_bytes()
 
-    def smudge(self, data: bytes) -> bytes:
+    def smudge(self, data: bytes, path: bytes = b"") -> bytes:
         """Convert LFS pointer to file content (smudge filter)."""
         # Try to parse as LFS pointer
         pointer = LFSPointer.from_bytes(data)
@@ -234,23 +236,145 @@ class LFSFilterDriver:
             with self.lfs_store.open_object(pointer.oid) as f:
                 return f.read()
         except KeyError:
-            # Object not found in LFS store, return pointer as-is
-            # This matches Git LFS behavior when object is missing
+            # Object not found in LFS store, try to download it
+            try:
+                content = self._download_object(pointer)
+                return content
+            except LFSError as e:
+                # Download failed, fall back to returning pointer
+                logging.warning("LFS object download failed for %s: %s", pointer.oid, e)
+                pass
+
+            # Return pointer as-is when object is missing and download failed
             return data
 
+    def _download_object(self, pointer: LFSPointer) -> bytes:
+        """Download an LFS object from the server.
+
+        Args:
+            pointer: LFS pointer containing OID and size
+
+        Returns:
+            Downloaded content
+
+        Raises:
+            LFSError: If download fails for any reason
+        """
+        if self.repo is None:
+            raise LFSError("No repository available for LFS download")
+
+        # Get LFS server URL from remote
+        lfs_url = self._get_lfs_url()
+        if not lfs_url:
+            raise LFSError("No LFS server URL configured")
+
+        # Create LFS client and download
+        config = self.repo.get_config_stack() if self.repo else None
+        client = LFSClient(lfs_url, config=config)
+        content = client.download(pointer.oid, pointer.size)
+
+        # Store the downloaded content in local LFS store
+        stored_oid = self.lfs_store.write_object([content])
+
+        # Verify the stored OID matches what we expected
+        if stored_oid != pointer.oid:
+            raise LFSError(f"Downloaded OID mismatch: expected {pointer.oid}, got {stored_oid}")
+
+        return content
+
+    def _get_lfs_url(self) -> Optional[str]:
+        """Get LFS server URL from repository configuration.
+
+        Returns:
+            LFS server URL or None if not configured
+        """
+        if self.repo is None:
+            return None
+
+        # Try to get LFS URL from config first
+        config = self.repo.get_config_stack()
+        try:
+            return config.get((b"lfs",), b"url").decode()
+        except KeyError:
+            pass
+
+        # Fall back to deriving from remote URL (same as git-lfs)
+        try:
+            remote_url = config.get((b"remote", b"origin"), b"url").decode()
+        except KeyError:
+            pass
+        else:
+            # Convert SSH URLs to HTTPS if needed
+            if remote_url.startswith("git@"):
+                # Convert git@host:user/repo.git to https://host/user/repo.git
+                if ":" in remote_url and "/" in remote_url:
+                    host_and_path = remote_url[4:]  # Remove "git@"
+                    if ":" in host_and_path:
+                        host, path = host_and_path.split(":", 1)
+                        remote_url = f"https://{host}/{path}"
+
+            # Ensure URL ends with .git for consistent LFS endpoint
+            if not remote_url.endswith(".git"):
+                remote_url = f"{remote_url}.git"
+
+            # Standard LFS endpoint is remote_url + "/info/lfs"
+            lfs_url = f"{remote_url}/info/lfs"
+            
+            # Validate URL by parsing it
+            from urllib.parse import urlparse
+            parsed = urlparse(lfs_url)
+            if not parsed.scheme or not parsed.netloc:
+                return None
+                
+            return lfs_url
+
+        return None
+
+
+def _get_lfs_user_agent(config):
+    """Get User-Agent string for LFS requests, respecting git config."""
+    try:
+        if config:
+            # Use configured user agent verbatim if set
+            return config.get(b"http", b"useragent").decode()
+    except KeyError:
+        pass
+
+    # Default LFS user agent (similar to git-lfs format)
+    from . import __version__
+
+    version_str = ".".join([str(x) for x in __version__])
+    return f"git-lfs/dulwich/{version_str}"
+
 
 class LFSClient:
     """LFS client for network operations."""
 
-    def __init__(self, url: str, auth: Optional[tuple[str, str]] = None) -> None:
+    def __init__(self, url: str, config: Optional["Config"] = None) -> None:
         """Initialize LFS client.
 
         Args:
             url: LFS server URL
-            auth: Optional (username, password) tuple for authentication
+            config: Optional git config for authentication/proxy settings
         """
-        self.url = url.rstrip("/")
-        self.auth = auth
+        self._base_url = url.rstrip("/") + "/"  # Ensure trailing slash for urljoin
+        self.config = config
+        self._pool_manager = None
+
+    @property
+    def url(self) -> str:
+        """Get the LFS server URL without trailing slash."""
+        return self._base_url.rstrip("/")
+
+    def _get_pool_manager(self):
+        """Get urllib3 pool manager with git config applied."""
+        if self._pool_manager is None:
+            # For now, use plain urllib3 since dulwich's version has issues with LFS
+            # TODO: Investigate why default_urllib3_manager breaks LFS requests
+            import urllib3
+
+            self._pool_manager = urllib3.PoolManager()
+        return self._pool_manager
 
     def _make_request(
         self,
@@ -260,29 +384,21 @@ class LFSClient:
         headers: Optional[dict[str, str]] = None,
     ) -> bytes:
         """Make an HTTP request to the LFS server."""
-        url = urljoin(self.url, path)
+        url = urljoin(self._base_url, path)
         req_headers = {
             "Accept": "application/vnd.git-lfs+json",
             "Content-Type": "application/vnd.git-lfs+json",
+            "User-Agent": _get_lfs_user_agent(self.config),
         }
         if headers:
             req_headers.update(headers)
 
-        req = Request(url, data=data, headers=req_headers, method=method)
-
-        if self.auth:
-            import base64
-
-            auth_str = f"{self.auth[0]}:{self.auth[1]}"
-            b64_auth = base64.b64encode(auth_str.encode()).decode("ascii")
-            req.add_header("Authorization", f"Basic {b64_auth}")
-
-        try:
-            with urlopen(req) as response:
-                return response.read()
-        except HTTPError as e:
-            error_body = e.read().decode("utf-8", errors="ignore")
-            raise LFSError(f"LFS server error {e.code}: {error_body}")
+        # Use urllib3 pool manager with git config applied
+        pool_manager = self._get_pool_manager()
+        response = pool_manager.request(method, url, headers=req_headers, body=data)
+        if response.status >= 400:
+            raise ValueError(f"HTTP {response.status}: {response.data.decode('utf-8', errors='ignore')}")
+        return response.data
 
     def batch(
         self,
@@ -311,8 +427,10 @@ class LFSClient:
             data["ref"] = {"name": ref}
 
         response = self._make_request(
-            "POST", "/objects/batch", json.dumps(data).encode("utf-8")
+            "POST", "objects/batch", json.dumps(data).encode("utf-8")
         )
+        if not response:
+            raise ValueError("Empty response from LFS server")
         response_data = json.loads(response)
         return self._parse_batch_response(response_data)
 
@@ -378,14 +496,14 @@ class LFSClient:
         download_action = obj.actions["download"]
         download_url = download_action.href
 
-        # Download the object
-        req = Request(download_url)
+        # Download the object using urllib3 with git config
+        download_headers = {"User-Agent": _get_lfs_user_agent(self.config)}
         if download_action.header:
-            for name, value in download_action.header.items():
-                req.add_header(name, value)
+            download_headers.update(download_action.header)
 
-        with urlopen(req) as response:
-            content = response.read()
+        pool_manager = self._get_pool_manager()
+        response = pool_manager.request("GET", download_url, headers=download_headers)
+        content = response.data
 
         # Verify size
         if len(content) != size:

+ 1 - 1
dulwich/line_ending.py

@@ -178,7 +178,7 @@ class LineEndingFilter(FilterDriver):
 
         return self.clean_conversion(data)
 
-    def smudge(self, data: bytes) -> bytes:
+    def smudge(self, data: bytes, path: bytes = b"") -> bytes:
         """Apply line ending conversion for checkout (repository -> working tree)."""
         if self.smudge_conversion is None:
             return data

+ 1 - 1
tests/test_sparse_patterns.py

@@ -543,7 +543,7 @@ class ApplyIncludedPathsTests(TestCase):
 
         # Create a simple filter that converts content to uppercase
         class UppercaseFilter:
-            def smudge(self, input_bytes):
+            def smudge(self, input_bytes, path=b""):
                 return input_bytes.upper()
 
             def clean(self, input_bytes):