浏览代码

consistently pass around ref_prefix and protocol_version in dulwich.client (#1421)

Jelmer Vernooij 4 月之前
父节点
当前提交
a90b9d75fe
共有 5 个文件被更改,包括 117 次插入26 次删除
  1. 94 25
      dulwich/client.py
  2. 1 0
      dulwich/porcelain.py
  3. 9 0
      dulwich/protocol.py
  4. 1 1
      dulwich/repo.py
  5. 12 0
      tests/test_protocol.py

+ 94 - 25
dulwich/client.py

@@ -115,6 +115,7 @@ from .protocol import (
     extract_capability_names,
     parse_capability,
     pkt_line,
+    pkt_seq,
 )
 from .refs import (
     PEELED_TAG_SUFFIX,
@@ -128,6 +129,12 @@ from .refs import (
 )
 from .repo import Repo
 
+# Default ref prefix, used if none is specified.
+# GitHub defaults to just sending HEAD if no ref-prefix is
+# specified, so explicitly request all refs to match
+# behaviour with v1 when no ref-prefix is specified.
+DEFAULT_REF_PREFIX = [b"HEAD", b"refs/"]
+
 ObjectID = bytes
 
 
@@ -1037,7 +1044,12 @@ class GitClient:
         """
         raise NotImplementedError(self.fetch_pack)
 
-    def get_refs(self, path):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a git smart server.
 
         Args:
@@ -1187,7 +1199,12 @@ class TraditionalGitClient(GitClient):
         self._remote_path_encoding = path_encoding
         super().__init__(**kwargs)
 
-    async def _connect(self, cmd, path, protocol_version=None):
+    def _connect(
+        self,
+        cmd: bytes,
+        path: Union[str, bytes],
+        protocol_version: Optional[int] = None,
+    ) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
         """Create a connection to the server.
 
         This method is abstract - concrete implementations should
@@ -1375,10 +1392,7 @@ class TraditionalGitClient(GitClient):
                 proto.write_pkt_line(b"symrefs")
                 proto.write_pkt_line(b"peel")
                 if ref_prefix is None:
-                    # GitHub defaults to just sending HEAD if no ref-prefix is
-                    # specified, so explicitly request all refs to match
-                    # behaviour with v1 when no ref-prefix is specified.
-                    ref_prefix = [b"HEAD", b"refs/"]
+                    ref_prefix = DEFAULT_REF_PREFIX
                 for prefix in ref_prefix:
                     proto.write_pkt_line(b"ref-prefix " + prefix)
                 proto.write_pkt_line(None)
@@ -1434,7 +1448,12 @@ class TraditionalGitClient(GitClient):
             )
             return FetchPackResult(refs, symrefs, agent, new_shallow, new_unshallow)
 
-    def get_refs(self, path, protocol_version=None):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a git smart server."""
         # stock `git ls-remote` uses upload-pack
         if (
@@ -1460,6 +1479,10 @@ class TraditionalGitClient(GitClient):
             proto.write(b"0001")  # delim-pkt
             proto.write_pkt_line(b"symrefs")
             proto.write_pkt_line(b"peel")
+            if ref_prefix is None:
+                ref_prefix = DEFAULT_REF_PREFIX
+            for prefix in ref_prefix:
+                proto.write_pkt_line(b"ref-prefix " + prefix)
             proto.write_pkt_line(None)
             with proto:
                 try:
@@ -1548,7 +1571,12 @@ class TCPGitClient(TraditionalGitClient):
             netloc += ":%d" % self._port
         return urlunsplit(("git", netloc, path, "", ""))
 
-    def _connect(self, cmd, path, protocol_version=None):
+    def _connect(
+        self,
+        cmd: bytes,
+        path: Union[str, bytes],
+        protocol_version: Optional[int] = None,
+    ) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
         if not isinstance(cmd, bytes):
             raise TypeError(cmd)
         if not isinstance(path, bytes):
@@ -1558,8 +1586,8 @@ class TCPGitClient(TraditionalGitClient):
         )
         s = None
         err = OSError(f"no address found for {self._host}")
-        for family, socktype, proto, canonname, sockaddr in sockaddrs:
-            s = socket.socket(family, socktype, proto)
+        for family, socktype, protof, canonname, sockaddr in sockaddrs:
+            s = socket.socket(family, socktype, protof)
             s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
             try:
                 s.connect(sockaddr)
@@ -1668,7 +1696,12 @@ class SubprocessGitClient(TraditionalGitClient):
 
     git_command = None
 
-    def _connect(self, service, path, protocol_version=None):
+    def _connect(
+        self,
+        service: bytes,
+        path: Union[bytes, str],
+        protocol_version: Optional[int] = None,
+    ) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
         if not isinstance(service, bytes):
             raise TypeError(service)
         if isinstance(path, bytes):
@@ -1890,7 +1923,12 @@ class LocalGitClient(GitClient):
             )
             return FetchPackResult(r.get_refs(), symrefs, agent)
 
-    def get_refs(self, path):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a local on-disk repository."""
         with self._open_repo(path) as target:
             return target.get_refs()
@@ -1952,7 +1990,7 @@ class SubprocessSSHVendor(SSHVendor):
         password=None,
         key_filename=None,
         ssh_command=None,
-        protocol_version=None,
+        protocol_version: Optional[int] = None,
     ):
         if password is not None:
             raise NotImplementedError(
@@ -2127,7 +2165,12 @@ class SSHGitClient(TraditionalGitClient):
         assert isinstance(cmd, bytes)
         return cmd
 
-    def _connect(self, cmd, path, protocol_version=None):
+    def _connect(
+        self,
+        cmd: bytes,
+        path: Union[str, bytes],
+        protocol_version: Optional[int] = None,
+    ) -> tuple[Protocol, Callable[[], bool], Optional[IO[bytes]]]:
         if not isinstance(cmd, bytes):
             raise TypeError(cmd)
         if isinstance(path, bytes):
@@ -2361,7 +2404,11 @@ class AbstractHttpGitClient(GitClient):
         raise NotImplementedError(self._http_request)
 
     def _discover_references(
-        self, service, base_url, protocol_version=None
+        self,
+        service,
+        base_url,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
     ) -> tuple[
         dict[Ref, ObjectID], set[bytes], str, dict[Ref, Ref], dict[Ref, ObjectID]
     ]:
@@ -2407,15 +2454,24 @@ class AbstractHttpGitClient(GitClient):
             if not self.dumb:
 
                 def begin_protocol_v2(proto):
+                    nonlocal ref_prefix
                     server_capabilities = read_server_capabilities(proto.read_pkt_seq())
+                    if ref_prefix is None:
+                        ref_prefix = DEFAULT_REF_PREFIX
+
+                    pkts = [
+                        b"symrefs",
+                        b"peel",
+                    ]
+                    for prefix in ref_prefix:
+                        pkts.append(b"ref-prefix " + prefix)
+
+                    body = b"".join(
+                        [pkt_line(b"command=ls-refs\n"), b"0001", pkt_seq(*pkts)]
+                    )
+
                     resp, read = self._smart_request(
-                        service.decode("ascii"),
-                        base_url,
-                        pkt_line(b"command=ls-refs\n")
-                        + b"0001"
-                        + pkt_line(b"symrefs")
-                        + pkt_line(b"peel")
-                        + b"0000",
+                        service.decode("ascii"), base_url, body
                     )
                     proto = Protocol(read, None)
                     return server_capabilities, resp, read, proto
@@ -2613,7 +2669,10 @@ class AbstractHttpGitClient(GitClient):
         """
         url = self._get_url(path)
         refs, server_capabilities, url, symrefs, peeled = self._discover_references(
-            b"git-upload-pack", url, protocol_version
+            b"git-upload-pack",
+            url,
+            protocol_version=protocol_version,
+            ref_prefix=ref_prefix,
         )
         (
             negotiated_capabilities,
@@ -2678,10 +2737,20 @@ class AbstractHttpGitClient(GitClient):
         finally:
             resp.close()
 
-    def get_refs(self, path):
+    def get_refs(
+        self,
+        path,
+        protocol_version: Optional[int] = None,
+        ref_prefix: Optional[list[Ref]] = None,
+    ):
         """Retrieve the current refs from a git smart server."""
         url = self._get_url(path)
-        refs, _, _, _, peeled = self._discover_references(b"git-upload-pack", url)
+        refs, _, _, _, peeled = self._discover_references(
+            b"git-upload-pack",
+            url,
+            protocol_version=protocol_version,
+            ref_prefix=ref_prefix,
+        )
         for refname, refvalue in peeled.items():
             refs[refname + PEELED_TAG_SUFFIX] = refvalue
         return refs

+ 1 - 0
dulwich/porcelain.py

@@ -587,6 +587,7 @@ def clone(
         depth=depth,
         filter_spec=filter_spec,
         protocol_version=protocol_version,
+        **kwargs,
     )
 
 

+ 9 - 0
dulwich/protocol.py

@@ -185,6 +185,15 @@ def pkt_line(data):
     return ("%04x" % (len(data) + 4)).encode("ascii") + data
 
 
+def pkt_seq(*seq):
+    """Wrap a sequence of data in pkt-lines.
+
+    Args:
+      seq: An iterable of strings to wrap.
+    """
+    return b"".join([pkt_line(s) for s in seq]) + pkt_line(None)
+
+
 class Protocol:
     """Class for interacting with a remote git process over the wire.
 

+ 1 - 1
dulwich/repo.py

@@ -1225,7 +1225,7 @@ class Repo(BaseRepo):
             pass
         if committer is None:
             config = self.get_config_stack()
-            committer = self._get_user_identity(config)
+            committer = get_user_identity(config)
         check_user_identity(committer)
         if timestamp is None:
             timestamp = int(time.time())

+ 12 - 0
tests/test_protocol.py

@@ -35,11 +35,23 @@ from dulwich.protocol import (
     ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
+    pkt_line,
+    pkt_seq,
 )
 
 from . import TestCase
 
 
+class PktLinetests:
+    def test_pkt_line(self):
+        self.assertEqual(b"0007bla", pkt_line(b"bla"))
+        self.assertEqual(b"0000", pkt_line(None))
+
+    def test_pkt_seq(self):
+        self.assertEqual(b"0007bla0003foo0000", pkt_seq([b"bla", b"foo"]))
+        self.assertEqual(b"0000", pkt_seq([]))
+
+
 class BaseProtocolTests:
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)