Kaynağa Gözat

Expose object-format capability in client protocol

Jelmer Vernooij 1 ay önce
ebeveyn
işleme
935a3bb5e5
1 değiştirilmiş dosya ile 89 ekleme ve 15 silme
  1. dulwich/client.py (+89 -15)

dulwich/client.py (+89 -15)

@@ -394,6 +394,24 @@ def read_server_capabilities(pkt_seq: Iterable[bytes]) -> set[bytes]:
     return set(server_capabilities)
 
 
+def extract_object_format_from_capabilities(
+    capabilities: set[bytes],
+) -> str | None:
+    """Extract object format from server capabilities.
+
+    Args:
+        capabilities: Server capabilities
+
+    Returns:
+        Object format name as string (e.g., "sha1", "sha256"), or None if not specified
+    """
+    for capability in capabilities:
+        k, v = parse_capability(capability)
+        if k == b"object-format":
+            return v.decode("ascii")
+    return None
+
+
 def read_pkt_refs_v2(
     pkt_seq: Iterable[bytes],
 ) -> tuple[dict[Ref, ObjectID | None], dict[Ref, Ref], dict[Ref, ObjectID]]:
@@ -510,11 +528,13 @@ class FetchPackResult(_DeprecatedDictProxy):
       refs: Dictionary with all remote refs
       symrefs: Dictionary with remote symrefs
       agent: User agent string
+      object_format: Object format name (e.g., "sha1", "sha256") used by the remote, or None if not specified
     """
 
     refs: dict[Ref, ObjectID | None]
     symrefs: dict[Ref, Ref]
     agent: bytes | None
+    object_format: str | None
 
     def __init__(
         self,
@@ -523,6 +543,7 @@ class FetchPackResult(_DeprecatedDictProxy):
         agent: bytes | None,
         new_shallow: set[ObjectID] | None = None,
         new_unshallow: set[ObjectID] | None = None,
+        object_format: str | None = None,
     ) -> None:
         """Initialize FetchPackResult.
 
@@ -532,12 +553,14 @@ class FetchPackResult(_DeprecatedDictProxy):
             agent: User agent string
             new_shallow: New shallow commits
             new_unshallow: New unshallow commits
+            object_format: Object format name (e.g., "sha1", "sha256") used by the remote
         """
         self.refs = refs
         self.symrefs = symrefs
         self.agent = agent
         self.new_shallow = new_shallow
         self.new_unshallow = new_unshallow
+        self.object_format = object_format
 
     def __eq__(self, other: object) -> bool:
         """Check equality with another object."""
@@ -563,21 +586,28 @@ class LsRemoteResult(_DeprecatedDictProxy):
     Attributes:
       refs: Dictionary with all remote refs
       symrefs: Dictionary with remote symrefs
+      object_format: Object format name (e.g., "sha1", "sha256") used by the remote, or None if not specified
     """
 
     symrefs: dict[Ref, Ref]
+    object_format: str | None
 
     def __init__(
-        self, refs: dict[Ref, ObjectID | None], symrefs: dict[Ref, Ref]
+        self,
+        refs: dict[Ref, ObjectID | None],
+        symrefs: dict[Ref, Ref],
+        object_format: str | None = None,
     ) -> None:
         """Initialize LsRemoteResult.
 
         Args:
             refs: Dictionary with all remote refs
             symrefs: Dictionary with remote symrefs
+            object_format: Object format name (e.g., "sha1", "sha256") used by the remote
         """
         self.refs = refs
         self.symrefs = symrefs
+        self.object_format = object_format
 
     def _warn_deprecated(self) -> None:
         import warnings
@@ -1075,6 +1105,8 @@ class GitClient:
             os.mkdir(target_path)
 
         try:
+            # Create repository with default SHA-1 format initially
+            # We'll update it based on the remote's object format after the first fetch
             target = None
             if not bare:
                 target = Repo.init(target_path)
@@ -1114,6 +1146,15 @@ class GitClient:
                 filter_spec=filter_spec,
                 protocol_version=protocol_version,
             )
+
+            # Update object format if the remote uses a different one
+            # This must happen before any objects are written, but fetch has already
+            # transferred them. Subclasses can override to detect format earlier.
+            if (
+                result.object_format
+                and result.object_format != target.object_format.name
+            ):
+                target._update_object_format(result.object_format)
             if origin is not None:
                 _import_remote_refs(
                     target.refs, origin, result.refs, message=ref_message
@@ -1660,6 +1701,7 @@ class TraditionalGitClient(GitClient):
             refs: dict[Ref, ObjectID | None]
             symrefs: dict[Ref, Ref]
             agent: bytes | None
+            object_format: str | None
             if self.protocol_version == 2:
                 try:
                     server_capabilities = read_server_capabilities(proto.read_pkt_seq())
@@ -1670,6 +1712,9 @@ class TraditionalGitClient(GitClient):
                     symrefs,
                     agent,
                 ) = self._negotiate_upload_pack_capabilities(server_capabilities)
+                object_format = extract_object_format_from_capabilities(
+                    server_capabilities
+                )
 
                 proto.write_pkt_line(b"command=ls-refs\n")
                 proto.write(b"0001")  # delim-pkt
@@ -1695,13 +1740,18 @@ class TraditionalGitClient(GitClient):
                     symrefs,
                     agent,
                 ) = self._negotiate_upload_pack_capabilities(server_capabilities)
+                object_format = extract_object_format_from_capabilities(
+                    server_capabilities
+                )
 
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
 
             if refs is None:
                 proto.write_pkt_line(None)
-                return FetchPackResult(refs, symrefs, agent)
+                return FetchPackResult(
+                    refs, symrefs, agent, object_format=object_format
+                )
 
             try:
                 # Filter out None values (shouldn't be any in v1 protocol)
@@ -1719,7 +1769,9 @@ class TraditionalGitClient(GitClient):
                 wants = [cid for cid in wants if cid != ZERO_SHA]
             if not wants:
                 proto.write_pkt_line(None)
-                return FetchPackResult(refs, symrefs, agent)
+                return FetchPackResult(
+                    refs, symrefs, agent, object_format=object_format
+                )
             if self.protocol_version == 2:
                 proto.write_pkt_line(b"command=fetch\n")
                 proto.write(b"0001")  # delim-pkt
@@ -1757,7 +1809,9 @@ class TraditionalGitClient(GitClient):
                 progress,
                 protocol_version=self.protocol_version,
             )
-            return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow)
+            return FetchPackResult(
+                refs, symrefs, agent, new_shallow, new_unshallow, object_format
+            )
 
     def get_refs(
         self,
@@ -1785,6 +1839,7 @@ class TraditionalGitClient(GitClient):
         self.protocol_version = server_protocol_version
         if self.protocol_version == 2:
             server_capabilities = read_server_capabilities(proto.read_pkt_seq())
+            object_format = extract_object_format_from_capabilities(server_capabilities)
             proto.write_pkt_line(b"command=ls-refs\n")
             proto.write(b"0001")  # delim-pkt
             proto.write_pkt_line(b"symrefs")
@@ -1802,7 +1857,7 @@ class TraditionalGitClient(GitClient):
                 proto.write_pkt_line(None)
                 for refname, refvalue in peeled.items():
                     refs[Ref(refname + PEELED_TAG_SUFFIX)] = refvalue
-                return LsRemoteResult(refs, symrefs)
+                return LsRemoteResult(refs, symrefs, object_format=object_format)
         else:
             with proto:
                 try:
@@ -1814,10 +1869,13 @@ class TraditionalGitClient(GitClient):
                 except HangupException as exc:
                     raise _remote_error_from_stderr(stderr) from exc
                 proto.write_pkt_line(None)
+                object_format = extract_object_format_from_capabilities(
+                    server_capabilities
+                )
                 (symrefs, _agent) = _extract_symrefs_and_agent(server_capabilities)
                 if ref_prefix is not None:
                     refs = filter_ref_prefix(refs, ref_prefix)
-                return LsRemoteResult(refs, symrefs)
+                return LsRemoteResult(refs, symrefs, object_format=object_format)
 
     def archive(
         self,
@@ -2396,7 +2454,10 @@ class LocalGitClient(GitClient):
                 depth=depth,
             )
             return FetchPackResult(
-                _to_optional_dict(refs), r.refs.get_symrefs(), agent_string()
+                _to_optional_dict(refs),
+                r.refs.get_symrefs(),
+                agent_string(),
+                object_format=r.object_format.name,
             )
 
     def fetch_pack(
@@ -2454,7 +2515,9 @@ class LocalGitClient(GitClient):
             # Did the process short-circuit (e.g. in a stateless RPC call)?
             # Note that the client still expects a 0-object pack in most cases.
             if object_ids is None:
-                return FetchPackResult(None, symrefs, agent)
+                return FetchPackResult(
+                    None, symrefs, agent, object_format=r.object_format.name
+                )
             write_pack_from_container(
                 pack_data,  # type: ignore[arg-type]
                 r.object_store,
@@ -2463,7 +2526,12 @@ class LocalGitClient(GitClient):
                 object_format=r.object_format,
             )
             # Convert refs to Optional type for FetchPackResult
-            return FetchPackResult(_to_optional_dict(r.get_refs()), symrefs, agent)
+            return FetchPackResult(
+                _to_optional_dict(r.get_refs()),
+                symrefs,
+                agent,
+                object_format=r.object_format.name,
+            )
 
     def get_refs(
         self,
@@ -2489,7 +2557,9 @@ class LocalGitClient(GitClient):
                 except (KeyError, ValueError):
                     # Not a symbolic ref or error reading it
                     pass
-            return LsRemoteResult(refs, symrefs)
+            return LsRemoteResult(
+                refs, symrefs, object_format=target.object_format.name
+            )
 
 
 class BundleClient(GitClient):
@@ -3880,6 +3950,7 @@ class AbstractHttpGitClient(GitClient):
             capa_symrefs,
             agent,
         ) = self._negotiate_upload_pack_capabilities(server_capabilities)
+        object_format = extract_object_format_from_capabilities(server_capabilities)
         if not symrefs and capa_symrefs:
             symrefs = capa_symrefs
         # Filter out None values from refs for determine_wants
@@ -3891,7 +3962,7 @@ class AbstractHttpGitClient(GitClient):
         if wants is not None:
             wants = [cid for cid in wants if cid != ZERO_SHA]
         if not wants and not self.dumb:
-            return FetchPackResult(refs, symrefs, agent)
+            return FetchPackResult(refs, symrefs, agent, object_format=object_format)
         elif self.dumb:
             # Use dumb HTTP protocol
             from .dumb import DumbRemoteHTTPRepo
@@ -3930,7 +4001,7 @@ class AbstractHttpGitClient(GitClient):
                     object_format=DEFAULT_OBJECT_FORMAT,
                 )
 
-            return FetchPackResult(refs, symrefs, agent)
+            return FetchPackResult(refs, symrefs, agent, object_format=object_format)
         req_data = BytesIO()
         req_proto = Protocol(None, req_data.write)  # type: ignore
         (new_shallow, new_unshallow) = _handle_upload_pack_head(
@@ -3977,7 +4048,9 @@ class AbstractHttpGitClient(GitClient):
                 progress,
                 protocol_version=self.protocol_version,
             )
-            return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow)
+            return FetchPackResult(
+                refs, symrefs, agent, new_shallow, new_unshallow, object_format
+            )
         finally:
             resp.close()
 
@@ -3989,15 +4062,16 @@ class AbstractHttpGitClient(GitClient):
     ) -> LsRemoteResult:
         """Retrieve the current refs from a git smart server."""
         url = self._get_url(path)
-        refs, _, _, symrefs, peeled = self._discover_references(
+        refs, server_capabilities, _, symrefs, peeled = self._discover_references(
             b"git-upload-pack",
             url,
             protocol_version=protocol_version,
             ref_prefix=ref_prefix,
         )
+        object_format = extract_object_format_from_capabilities(server_capabilities)
         for refname, refvalue in peeled.items():
             refs[Ref(refname + PEELED_TAG_SUFFIX)] = refvalue
-        return LsRemoteResult(refs, symrefs)
+        return LsRemoteResult(refs, symrefs, object_format=object_format)
 
     def get_url(self, path: str) -> str:
         """Get the HTTP URL for a path."""