Browse Source

Fix handling of symrefs with protocol v2. Fixes #1389

Jelmer Vernooij 5 months ago
parent
commit
f1075d25f3
3 changed files with 63 additions and 39 deletions
  1. 5 0
      NEWS
  2. 54 33
      dulwich/client.py
  3. 4 6
      tests/test_client.py

+ 5 - 0
NEWS

@@ -1,3 +1,8 @@
+0.22.4	UNRELEASED
+
+ * Fix handling of symrefs with protocol v2.
+   (Jelmer Vernooij, #1389)
+
 0.22.3	2024-10-15
 
  * Improve wheel building in CI, so we can upload wheels for the next release.

+ 54 - 33
dulwich/client.py

@@ -258,7 +258,33 @@ def read_server_capabilities(pkt_seq):
     return set(server_capabilities)
 
 
-def read_pkt_refs(pkt_seq, server_capabilities=None):
+def read_pkt_refs_v2(
+    pkt_seq,
+) -> Tuple[Dict[bytes, bytes], Dict[bytes, bytes], Dict[bytes, bytes]]:
+    refs = {}
+    symrefs = {}
+    peeled = {}
+    # Receive refs from server
+    for pkt in pkt_seq:
+        parts = pkt.rstrip(b"\n").split(b" ")
+        sha = parts[0]
+        if sha == b"unborn":
+            sha = None
+        ref = parts[1]
+        for part in parts[2:]:
+            if part.startswith(b"peeled:"):
+                peeled[ref] = part[7:]
+            elif part.startswith(b"symref-target:"):
+                symrefs[ref] = part[14:]
+            else:
+                logging.warning("unknown part in pkt-ref: %s", part)
+        refs[ref] = sha
+
+    return refs, symrefs, peeled
+
+
+def read_pkt_refs_v1(pkt_seq) -> Tuple[Dict[bytes, bytes], Set[bytes]]:
+    server_capabilities = None
     refs = {}
     # Receive refs from server
     for pkt in pkt_seq:
@@ -267,24 +293,13 @@ def read_pkt_refs(pkt_seq, server_capabilities=None):
             raise GitProtocolError(ref.decode("utf-8", "replace"))
         if server_capabilities is None:
             (ref, server_capabilities) = extract_capabilities(ref)
-        else:  # Git protocol-v2:
-            try:
-                symref, target = ref.split(b" ", 1)
-            except ValueError:
-                pass
-            else:
-                if symref and target and target[:14] == b"symref-target:":
-                    server_capabilities.add(
-                        b"%s=%s:%s"
-                        % (CAPABILITY_SYMREF, symref, target.split(b":", 1)[1])
-                    )
-                    ref = symref
         refs[ref] = sha
 
     if len(refs) == 0:
         return {}, set()
     if refs == {CAPABILITIES_REF: ZERO_SHA}:
         refs = {}
+    assert server_capabilities is not None
     return refs, set(server_capabilities)
 
 
@@ -1207,7 +1222,7 @@ class TraditionalGitClient(GitClient):
         proto, unused_can_read, stderr = self._connect(b"receive-pack", path)
         with proto:
             try:
-                old_refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
+                old_refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
             except HangupException as exc:
                 raise _remote_error_from_stderr(stderr) from exc
             (
@@ -1340,7 +1355,7 @@ class TraditionalGitClient(GitClient):
                     server_capabilities = read_server_capabilities(proto.read_pkt_seq())
                     refs = None
                 else:
-                    refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
+                    refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
             except HangupException as exc:
                 raise _remote_error_from_stderr(stderr) from exc
             (
@@ -1356,9 +1371,7 @@ class TraditionalGitClient(GitClient):
                 for prefix in ref_prefix:
                     proto.write_pkt_line(b"ref-prefix " + prefix)
                 proto.write_pkt_line(None)
-                refs, server_capabilities = read_pkt_refs(
-                    proto.read_pkt_seq(), server_capabilities
-                )
+                refs, symrefs, _peeled = read_pkt_refs_v2(proto.read_pkt_seq())
 
             if refs is None:
                 proto.write_pkt_line(None)
@@ -1436,17 +1449,22 @@ class TraditionalGitClient(GitClient):
             proto.write(b"0001")  # delim-pkt
             proto.write_pkt_line(b"symrefs")
             proto.write_pkt_line(None)
+            with proto:
+                try:
+                    refs, _symrefs, _peeled = read_pkt_refs_v2(proto.read_pkt_seq())
+                except HangupException as exc:
+                    raise _remote_error_from_stderr(stderr) from exc
+                proto.write_pkt_line(None)
+                return refs
         else:
-            server_capabilities = None  # read_pkt_refs will find them
-        with proto:
-            try:
-                refs, server_capabilities = read_pkt_refs(
-                    proto.read_pkt_seq(), server_capabilities
-                )
-            except HangupException as exc:
-                raise _remote_error_from_stderr(stderr) from exc
-            proto.write_pkt_line(None)
-            return refs
+            with proto:
+                try:
+                    refs, server_capabilities = read_pkt_refs_v1(proto.read_pkt_seq())
+                except HangupException as exc:
+                    raise _remote_error_from_stderr(stderr) from exc
+                proto.write_pkt_line(None)
+                (_symrefs, _agent) = _extract_symrefs_and_agent(server_capabilities)
+                return refs
 
     def archive(
         self,
@@ -2395,6 +2413,9 @@ class AbstractHttpGitClient(GitClient):
                 self.protocol_version = server_protocol_version
                 if self.protocol_version == 2:
                     server_capabilities, resp, read, proto = begin_protocol_v2(proto)
+                    (refs, _symrefs, _peeled) = read_pkt_refs_v2(proto.read_pkt_seq())
+                    return refs, server_capabilities, base_url
+
                 else:
                     server_capabilities = None  # read_pkt_refs will find them
                     try:
@@ -2425,11 +2446,11 @@ class AbstractHttpGitClient(GitClient):
                         server_capabilities, resp, read, proto = begin_protocol_v2(
                             proto
                         )
-                (
-                    refs,
-                    server_capabilities,
-                ) = read_pkt_refs(proto.read_pkt_seq(), server_capabilities)
-                return refs, server_capabilities, base_url
+                    (
+                        refs,
+                        server_capabilities,
+                    ) = read_pkt_refs_v1(proto.read_pkt_seq())
+                    return refs, server_capabilities, base_url
             else:
                 self.protocol_version = 0  # dumb servers only support protocol v0
                 return read_info_refs(resp), set(), base_url

+ 4 - 6
tests/test_client.py

@@ -47,6 +47,7 @@ from dulwich.client import (
     SubprocessSSHVendor,
     TCPGitClient,
     TraditionalGitClient,
+    _extract_symrefs_and_agent,
     _remote_error_from_stderr,
     check_wants,
     default_urllib3_manager,
@@ -54,7 +55,6 @@ from dulwich.client import (
     get_transport_and_path,
     get_transport_and_path_from_url,
     parse_rsync_url,
-    _extract_symrefs_and_agent,
 )
 from dulwich.config import ConfigDict
 from dulwich.objects import Commit, Tree
@@ -1871,11 +1871,9 @@ And this line is just random noise, too.
 
 
 class TestExtractAgentAndSymrefs(TestCase):
-
     def test_extract_agent_and_symrefs(self):
-        (agent, symrefs) = _extract_symrefs_and_agent(
-            [b"agent=git/2.31.1", b"symref=HEAD:refs/heads/master"
-             ])
+        (symrefs, agent) = _extract_symrefs_and_agent(
+            [b"agent=git/2.31.1", b"symref=HEAD:refs/heads/master"]
+        )
         self.assertEqual(agent, b"git/2.31.1")
         self.assertEqual(symrefs, {b"HEAD": b"refs/heads/master"})
-