Browse Source

consistently pass ref_prefix, protocol_version around

Jelmer Vernooij 4 months ago
parent
commit
77d0e0eff4
1 changed files with 66 additions and 19 deletions
  1. 66 19
      dulwich/client.py

+ 66 - 19
dulwich/client.py

@@ -115,6 +115,7 @@ from .protocol import (
     extract_capability_names,
     parse_capability,
     pkt_line,
+    pkt_seq,
 )
 from .refs import (
     PEELED_TAG_SUFFIX,
@@ -1043,7 +1044,12 @@ class GitClient:
         """
         raise NotImplementedError(self.fetch_pack)
 
-    def get_refs(self, path):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a git smart server.
 
         Args:
@@ -1193,7 +1199,7 @@ class TraditionalGitClient(GitClient):
         self._remote_path_encoding = path_encoding
         super().__init__(**kwargs)
 
-    async def _connect(self, cmd, path, protocol_version=None):
+    async def _connect(self, cmd, path, protocol_version: Optional[int] = None):
         """Create a connection to the server.
 
         This method is abstract - concrete implementations should
@@ -1437,7 +1443,12 @@ class TraditionalGitClient(GitClient):
             )
             return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow)
 
-    def get_refs(self, path, protocol_version=None):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a git smart server."""
         # stock `git ls-remote` uses upload-pack
         if (
@@ -1463,6 +1474,10 @@ class TraditionalGitClient(GitClient):
             proto.write(b"0001")  # delim-pkt
             proto.write_pkt_line(b"symrefs")
             proto.write_pkt_line(b"peel")
+            if ref_prefix is None:
+                ref_prefix = DEFAULT_REF_PREFIX
+            for prefix in ref_prefix:
+                proto.write_pkt_line(b"ref-prefix " + prefix)
             proto.write_pkt_line(None)
             with proto:
                 try:
@@ -1551,7 +1566,7 @@ class TCPGitClient(TraditionalGitClient):
             netloc += ":%d" % self._port
         return urlunsplit(("git", netloc, path, "", ""))
 
-    def _connect(self, cmd, path, protocol_version=None):
+    def _connect(self, cmd, path, protocol_version: Optional[int] = None):
         if not isinstance(cmd, bytes):
             raise TypeError(cmd)
         if not isinstance(path, bytes):
@@ -1671,7 +1686,7 @@ class SubprocessGitClient(TraditionalGitClient):
 
     git_command = None
 
-    def _connect(self, service, path, protocol_version=None):
+    def _connect(self, service, path, protocol_version: Optional[int] = None):
         if not isinstance(service, bytes):
             raise TypeError(service)
         if isinstance(path, bytes):
@@ -1893,7 +1908,12 @@ class LocalGitClient(GitClient):
             )
             return FetchPackResult(r.get_refs(), symrefs, agent)
 
-    def get_refs(self, path):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a local on-disk repository."""
         with self._open_repo(path) as target:
             return target.get_refs()
@@ -1955,7 +1975,7 @@ class SubprocessSSHVendor(SSHVendor):
         password=None,
         key_filename=None,
         ssh_command=None,
-        protocol_version=None,
+        protocol_version: Optional[int] = None,
     ):
         if password is not None:
             raise NotImplementedError(
@@ -2130,7 +2150,7 @@ class SSHGitClient(TraditionalGitClient):
         assert isinstance(cmd, bytes)
         return cmd
 
-    def _connect(self, cmd, path, protocol_version=None):
+    def _connect(self, cmd, path, protocol_version: Optional[int] = None):
         if not isinstance(cmd, bytes):
             raise TypeError(cmd)
         if isinstance(path, bytes):
@@ -2364,7 +2384,11 @@ class AbstractHttpGitClient(GitClient):
         raise NotImplementedError(self._http_request)
 
     def _discover_references(
-        self, service, base_url, protocol_version=None
+        self,
+        service,
+        base_url,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
     ) -> tuple[
         dict[Ref, ObjectID], set[bytes], str, dict[Ref, Ref], dict[Ref, ObjectID]
     ]:
@@ -2410,15 +2434,25 @@ class AbstractHttpGitClient(GitClient):
             if not self.dumb:
 
                 def begin_protocol_v2(proto):
+                    nonlocal ref_prefix
                     server_capabilities = read_server_capabilities(proto.read_pkt_seq())
+                    if ref_prefix is None:
+                        ref_prefix = DEFAULT_REF_PREFIX
+
+                    pkts = [
+                        b"symrefs",
+                        b"peel",
+                    ]
+                    for prefix in ref_prefix:
+                        pkts.append(b"ref-prefix " + prefix)
+
+                    body = b"".join([
+                        pkt_line(b"command=ls-refs\n"),
+                        b"0001",
+                        pkt_seq(*pkts)])
+
                     resp, read = self._smart_request(
-                        service.decode("ascii"),
-                        base_url,
-                        pkt_line(b"command=ls-refs\n")
-                        + b"0001"
-                        + pkt_line(b"symrefs")
-                        + pkt_line(b"peel")
-                        + b"0000",
+                        service.decode("ascii"), base_url, body
                     )
                     proto = Protocol(read, None)
                     return server_capabilities, resp, read, proto
@@ -2616,7 +2650,10 @@ class AbstractHttpGitClient(GitClient):
         """
         url = self._get_url(path)
         refs, server_capabilities, url, symrefs, peeled = self._discover_references(
-            b"git-upload-pack", url, protocol_version
+            b"git-upload-pack",
+            url,
+            protocol_version=protocol_version,
+            ref_prefix=ref_prefix,
         )
         (
             negotiated_capabilities,
@@ -2681,10 +2718,20 @@ class AbstractHttpGitClient(GitClient):
         finally:
             resp.close()
 
-    def get_refs(self, path):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a git smart server."""
         url = self._get_url(path)
-        refs, _, _, _, peeled = self._discover_references(b"git-upload-pack", url)
+        refs, _, _, _, peeled = self._discover_references(
+            b"git-upload-pack",
+            url,
+            protocol_version=protocol_version,
+            ref_prefix=ref_prefix,
+        )
         for refname, refvalue in peeled.items():
             refs[refname + PEELED_TAG_SUFFIX] = refvalue
         return refs