Просмотр исходного кода

Fix LocalGitClient to honor source repository object format during clone.

Jelmer Vernooij 1 месяц назад
Родитель
Сommit
d1f5e6c71e
2 измененных файлов с 164 добавлено и 11 удалено
  1. 116 11
      dulwich/client.py
  2. 48 0
      tests/test_client.py

+ 116 - 11
dulwich/client.py

@@ -407,7 +407,7 @@ def extract_object_format_from_capabilities(
     """
     """
     for capability in capabilities:
     for capability in capabilities:
         k, v = parse_capability(capability)
         k, v = parse_capability(capability)
-        if k == b"object-format":
+        if k == b"object-format" and v is not None:
             return v.decode("ascii")
             return v.decode("ascii")
     return None
     return None
 
 
@@ -1105,8 +1105,9 @@ class GitClient:
             os.mkdir(target_path)
             os.mkdir(target_path)
 
 
         try:
         try:
-            # Create repository with default SHA-1 format initially
-            # We'll update it based on the remote's object format after the first fetch
+            # For network clones, create repository with default SHA-1 format
+            # If remote uses SHA-256, fetch() will raise GitProtocolError
+            # Subclasses (e.g., LocalGitClient) override to detect format first
             target = None
             target = None
             if not bare:
             if not bare:
                 target = Repo.init(target_path)
                 target = Repo.init(target_path)
@@ -1147,14 +1148,11 @@ class GitClient:
                 protocol_version=protocol_version,
                 protocol_version=protocol_version,
             )
             )
 
 
-            # Update object format if the remote uses a different one
-            # This must happen before any objects are written, but fetch has already
-            # transferred them. Subclasses can override to detect format earlier.
-            if (
-                result.object_format
-                and result.object_format != target.object_format.name
-            ):
-                target._update_object_format(result.object_format)
+            # Note: For network clones, if remote uses a different object format than
+            # the default SHA-1, fetch() will raise GitProtocolError. Subclasses like
+            # LocalGitClient override clone() to detect format before creating the repo.
+            # TODO: Fix network clones to detect object format before creating target repo.
+
             if origin is not None:
             if origin is not None:
                 _import_remote_refs(
                 _import_remote_refs(
                     target.refs, origin, result.refs, message=ref_message
                     target.refs, origin, result.refs, message=ref_message
@@ -2571,6 +2569,113 @@ class LocalGitClient(GitClient):
                 refs, symrefs, object_format=target.object_format.name
                 refs, symrefs, object_format=target.object_format.name
             )
             )
 
 
+    def clone(
+        self,
+        path: str,
+        target_path: str,
+        mkdir: bool = True,
+        bare: bool = False,
+        origin: str | None = "origin",
+        checkout: bool | None = None,
+        branch: str | None = None,
+        progress: Callable[[bytes], None] | None = None,
+        depth: int | None = None,
+        ref_prefix: Sequence[bytes] | None = None,
+        filter_spec: bytes | None = None,
+        protocol_version: int | None = None,
+    ) -> Repo:
+        """Clone a local repository.
+
+        For local clones, we can detect the object format before creating
+        the target repository.
+        """
+        # Detect the object format from the source repository
+        with self._open_repo(path) as source_repo:
+            object_format_name = source_repo.object_format.name
+
+        if mkdir:
+            os.mkdir(target_path)
+
+        try:
+            # Create repository with the correct object format from the start
+            target = None
+            if not bare:
+                target = Repo.init(target_path, object_format=object_format_name)
+                if checkout is None:
+                    checkout = True
+            else:
+                if checkout:
+                    raise ValueError("checkout and bare are incompatible")
+                target = Repo.init_bare(target_path, object_format=object_format_name)
+
+            encoded_path = path.encode("utf-8")
+
+            assert target is not None
+            if origin is not None:
+                target_config = target.get_config()
+                target_config.set(
+                    (b"remote", origin.encode("utf-8")), b"url", encoded_path
+                )
+                target_config.set(
+                    (b"remote", origin.encode("utf-8")),
+                    b"fetch",
+                    b"+refs/heads/*:refs/remotes/" + origin.encode("utf-8") + b"/*",
+                )
+                target_config.write_to_path()
+
+            ref_message = b"clone: from " + encoded_path
+            result = self.fetch(
+                path.encode("utf-8"),
+                target,
+                progress=progress,
+                depth=depth,
+                ref_prefix=ref_prefix,
+                filter_spec=filter_spec,
+                protocol_version=protocol_version,
+            )
+
+            if origin is not None:
+                _import_remote_refs(
+                    target.refs, origin, result.refs, message=ref_message
+                )
+
+            origin_head = result.symrefs.get(HEADREF)
+            origin_sha = result.refs.get(HEADREF)
+            if origin is None or (origin_sha and not origin_head):
+                # set detached HEAD
+                if origin_sha is not None:
+                    target.refs[HEADREF] = origin_sha
+                    head = origin_sha
+                else:
+                    head = None
+            else:
+                _set_origin_head(target.refs, origin.encode("utf-8"), origin_head)
+                head_ref = _set_default_branch(
+                    target.refs,
+                    origin.encode("utf-8"),
+                    origin_head,
+                    branch.encode("utf-8") if branch is not None else None,
+                    ref_message,
+                )
+
+                # Update target head
+                if head_ref:
+                    head = _set_head(target.refs, head_ref, ref_message)
+                else:
+                    head = None
+
+            if checkout and head is not None:
+                target.get_worktree().reset_index()
+        except BaseException:
+            if target is not None:
+                target.close()
+            if mkdir:
+                import shutil
+
+                shutil.rmtree(target_path)
+            raise
+        return target
+
 
 
 class BundleClient(GitClient):
 class BundleClient(GitClient):
     """Git Client that reads from a bundle file."""
     """Git Client that reads from a bundle file."""

+ 48 - 0
tests/test_client.py

@@ -1115,6 +1115,54 @@ class LocalGitClientTests(TestCase):
         expected[b"refs/remotes/origin/master"] = expected[b"refs/heads/master"]
         expected[b"refs/remotes/origin/master"] = expected[b"refs/heads/master"]
         self.assertEqual(expected, result_repo.get_refs())
         self.assertEqual(expected, result_repo.get_refs())
 
 
+    def test_clone_sha256_local(self) -> None:
+        """Test that cloning a SHA-256 local repo creates a SHA-256 clone."""
+        client = LocalGitClient()
+
+        # Create a SHA-256 source repository
+        source_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, source_path)
+        source_repo = Repo.init(source_path, object_format="sha256")
+
+        # Verify source is SHA-256
+        self.assertEqual("sha256", source_repo.object_format.name)
+        source_repo.close()
+
+        # Clone the repository
+        target_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, target_path)
+        cloned_repo = client.clone(source_path, target_path, mkdir=False)
+        self.addCleanup(cloned_repo.close)
+
+        # Verify the clone uses SHA-256
+        self.assertEqual("sha256", cloned_repo.object_format.name)
+
+        # Verify the config has the correct objectformat extension
+        config = cloned_repo.get_config()
+        self.assertEqual(b"sha256", config.get((b"extensions",), b"objectformat"))
+
+    def test_clone_sha1_local(self) -> None:
+        """Test that cloning a SHA-1 local repo creates a SHA-1 clone."""
+        client = LocalGitClient()
+
+        # Create a SHA-1 source repository
+        source_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, source_path)
+        source_repo = Repo.init(source_path, object_format="sha1")
+
+        # Verify source is SHA-1
+        self.assertEqual("sha1", source_repo.object_format.name)
+        source_repo.close()
+
+        # Clone the repository
+        target_path = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, target_path)
+        cloned_repo = client.clone(source_path, target_path, mkdir=False)
+        self.addCleanup(cloned_repo.close)
+
+        # Verify the clone uses SHA-1
+        self.assertEqual("sha1", cloned_repo.object_format.name)
+
     def test_fetch_empty(self) -> None:
     def test_fetch_empty(self) -> None:
         c = LocalGitClient()
         c = LocalGitClient()
         s = open_repo("a.git")
         s = open_repo("a.git")