Jelmer Vernooij преди 2 години
родител
ревизия
d8d09a4e27
променени са 1 файла, в които са добавени 29 реда и са изтрити 27 реда
  1. 29 27
      dulwich/client.py

+ 29 - 27
dulwich/client.py

@@ -102,6 +102,7 @@ from dulwich.protocol import (
     ZERO_SHA,
     extract_capabilities,
     parse_capability,
+    pkt_line,
 )
 from dulwich.pack import (
     write_pack_data,
@@ -229,11 +230,11 @@ class ReportStatusParser(object):
             self._ref_statuses.append(ref_status)
 
 
-def read_pkt_refs(proto):
+def read_pkt_refs(pkt_seq):
     server_capabilities = None
     refs = {}
     # Receive refs from server
-    for pkt in proto.read_pkt_seq():
+    for pkt in pkt_seq:
         (sha, ref) = pkt.rstrip(b"\n").split(None, 1)
         if sha == b"ERR":
             raise GitProtocolError(ref.decode("utf-8", "replace"))
@@ -404,10 +405,10 @@ class SendPackResult(object):
         return "%s(%r, %r)" % (self.__class__.__name__, self.refs, self.agent)
 
 
-def _read_shallow_updates(proto):
+def _read_shallow_updates(pkt_seq):
     new_shallow = set()
     new_unshallow = set()
-    for pkt in proto.read_pkt_seq():
+    for pkt in pkt_seq:
         cmd, sha = pkt.split(b" ", 1)
         if cmd == COMMAND_SHALLOW:
             new_shallow.add(sha.strip())
@@ -862,7 +863,7 @@ class GitClient(object):
                 )
             proto.write_pkt_line(None)
             if can_read is not None:
-                (new_shallow, new_unshallow) = _read_shallow_updates(proto)
+                (new_shallow, new_unshallow) = _read_shallow_updates(proto.read_pkt_seq())
             else:
                 new_shallow = new_unshallow = None
         else:
@@ -1017,7 +1018,7 @@ class TraditionalGitClient(GitClient):
         proto, unused_can_read, stderr = self._connect(b"receive-pack", path)
         with proto:
             try:
-                old_refs, server_capabilities = read_pkt_refs(proto)
+                old_refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
             except HangupException:
                 raise _remote_error_from_stderr(stderr)
             (
@@ -1065,8 +1066,8 @@ class TraditionalGitClient(GitClient):
 
             header_handler = _v1ReceivePackHeader(negotiated_capabilities, old_refs, new_refs)
 
-            for pkt_line in header_handler:
-                proto.write_pkt_line(pkt_line)
+            for pkt in header_handler:
+                proto.write_pkt_line(pkt)
 
             pack_data_count, pack_data = generate_pack_data(
                 header_handler.have,
@@ -1111,7 +1112,7 @@ class TraditionalGitClient(GitClient):
         proto, can_read, stderr = self._connect(b"upload-pack", path)
         with proto:
             try:
-                refs, server_capabilities = read_pkt_refs(proto)
+                refs, server_capabilities = read_pkt_refs(proto.read_pkt_seq())
             except HangupException:
                 raise _remote_error_from_stderr(stderr)
             (
@@ -1160,7 +1161,7 @@ class TraditionalGitClient(GitClient):
         proto, _, stderr = self._connect(b"upload-pack", path)
         with proto:
             try:
-                refs, _ = read_pkt_refs(proto)
+                refs, _ = read_pkt_refs(proto.read_pkt_seq())
             except HangupException:
                 raise _remote_error_from_stderr(stderr)
             proto.write_pkt_line(None)
@@ -1951,7 +1952,7 @@ class AbstractHttpGitClient(GitClient):
                     raise GitProtocolError(
                         "unexpected first line %r from smart server" % pkt
                     )
-                return read_pkt_refs(proto) + (base_url,)
+                return read_pkt_refs(proto.read_pkt_seq()) + (base_url,)
             else:
                 return read_info_refs(resp), set(), base_url
         finally:
@@ -2018,21 +2019,21 @@ class AbstractHttpGitClient(GitClient):
             return SendPackResult(new_refs, agent=agent, ref_status={})
         if self.dumb:
             raise NotImplementedError(self.fetch_pack)
-        req_data = BytesIO()
-        req_proto = Protocol(None, req_data.write)
-        header_handler = _v1ReceivePackHeader(negotiated_capabilities, old_refs, new_refs)
-        for pkt_line in header_handler:
-            req_proto.write_pkt_line(pkt_line)
-        pack_data_count, pack_data = generate_pack_data(
-            header_handler.have,
-            header_handler.want,
-            ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities),
-        )
-        if self._should_send_pack(new_refs):
-            for chunk in PackChunkGenerator(pack_data_count, pack_data):
-                req_proto.write(chunk)
+
+        def body_generator():
+            header_handler = _v1ReceivePackHeader(negotiated_capabilities, old_refs, new_refs)
+            for pkt in header_handler:
+                yield pkt_line(pkt)
+            pack_data_count, pack_data = generate_pack_data(
+                header_handler.have,
+                header_handler.want,
+                ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities),
+            )
+            if self._should_send_pack(new_refs):
+                yield from PackChunkGenerator(pack_data_count, pack_data)
+
         resp, read = self._smart_request(
-            "git-receive-pack", url, data=req_data.getvalue()
+            "git-receive-pack", url, data=b"".join(body_generator())
         )
         try:
             resp_proto = Protocol(read, None)
@@ -2101,7 +2102,7 @@ class AbstractHttpGitClient(GitClient):
         try:
             resp_proto = Protocol(read, None)
             if new_shallow is None and new_unshallow is None:
-                (new_shallow, new_unshallow) = _read_shallow_updates(resp_proto)
+                (new_shallow, new_unshallow) = _read_shallow_updates(resp_proto.read_pkt_seq())
             self._handle_upload_pack_tail(
                 resp_proto,
                 negotiated_capabilities,
@@ -2200,7 +2201,8 @@ class Urllib3HttpGitClient(AbstractHttpGitClient):
             req_headers["Accept-Encoding"] = "identity"
 
         if data is None:
-            resp = self.pool_manager.request("GET", url, headers=req_headers, preload_content=False)
+            resp = self.pool_manager.request(
+                "GET", url, headers=req_headers, preload_content=False)
         else:
             resp = self.pool_manager.request(
                 "POST", url, headers=req_headers, body=data, preload_content=False