Procházet zdrojové kódy

Factor out v1ReceivePackHeader.

Jelmer Vernooij před 2 roky
rodič
revize
a5b137aa7f
1 změnil soubory, kde provedl 73 přidání a 62 odebrání
  1. 73 62
      dulwich/client.py

+ 73 - 62
dulwich/client.py

@@ -417,6 +417,68 @@ def _read_shallow_updates(proto):
     return (new_shallow, new_unshallow)
 
 
+class _v1ReceivePackHeader(object):
+
+    def __init__(self, capabilities, old_refs, new_refs):
+        self.want = []
+        self.have = []
+        self._it = self._handle_receive_pack_head(capabilities, old_refs, new_refs)
+        self.sent_capabilities = False
+
+    def __iter__(self):
+        return self._it
+
+    def _handle_receive_pack_head(self, capabilities, old_refs, new_refs):
+        """Handle the head of a 'git-receive-pack' request.
+
+        Args:
+          proto: Protocol object to read from
+          capabilities: List of negotiated capabilities
+          old_refs: Old refs, as received from the server
+          new_refs: Refs to change
+
+        Returns:
+          (have, want) tuple
+        """
+        self.have = [x for x in old_refs.values() if not x == ZERO_SHA]
+
+        for refname in new_refs:
+            if not isinstance(refname, bytes):
+                raise TypeError("refname is not a bytestring: %r" % refname)
+            old_sha1 = old_refs.get(refname, ZERO_SHA)
+            if not isinstance(old_sha1, bytes):
+                raise TypeError(
+                    "old sha1 for %s is not a bytestring: %r" % (refname, old_sha1)
+                )
+            new_sha1 = new_refs.get(refname, ZERO_SHA)
+            if not isinstance(new_sha1, bytes):
+                raise TypeError(
+                    "old sha1 for %s is not a bytestring %r" % (refname, new_sha1)
+                )
+
+            if old_sha1 != new_sha1:
+                logger.debug(
+                    'Sending updated ref %r: %r -> %r',
+                    refname, old_sha1, new_sha1)
+                if self.sent_capabilities:
+                    yield old_sha1 + b" " + new_sha1 + b" " + refname
+                else:
+                    yield (
+                        old_sha1
+                        + b" "
+                        + new_sha1
+                        + b" "
+                        + refname
+                        + b"\0"
+                        + b" ".join(sorted(capabilities))
+                    )
+                    self.sent_capabilities = True
+            if new_sha1 not in self.have and new_sha1 != ZERO_SHA:
+                self.want.append(new_sha1)
+        yield None
+
+
+
 # TODO(durin42): this doesn't correctly degrade if the server doesn't
 # support some capabilities. This should work properly with servers
 # that don't support multi_ack.
@@ -688,58 +750,6 @@ class GitClient(object):
         # The packfile MUST NOT be sent if the only command used is delete.
         return any(sha != ZERO_SHA for sha in new_refs.values())
 
-    def _handle_receive_pack_head(self, proto, capabilities, old_refs, new_refs):
-        """Handle the head of a 'git-receive-pack' request.
-
-        Args:
-          proto: Protocol object to read from
-          capabilities: List of negotiated capabilities
-          old_refs: Old refs, as received from the server
-          new_refs: Refs to change
-
-        Returns:
-          (have, want) tuple
-        """
-        want = []
-        have = [x for x in old_refs.values() if not x == ZERO_SHA]
-        sent_capabilities = False
-
-        for refname in new_refs:
-            if not isinstance(refname, bytes):
-                raise TypeError("refname is not a bytestring: %r" % refname)
-            old_sha1 = old_refs.get(refname, ZERO_SHA)
-            if not isinstance(old_sha1, bytes):
-                raise TypeError(
-                    "old sha1 for %s is not a bytestring: %r" % (refname, old_sha1)
-                )
-            new_sha1 = new_refs.get(refname, ZERO_SHA)
-            if not isinstance(new_sha1, bytes):
-                raise TypeError(
-                    "old sha1 for %s is not a bytestring %r" % (refname, new_sha1)
-                )
-
-            if old_sha1 != new_sha1:
-                logger.debug(
-                    'Sending updated ref %r: %r -> %r',
-                    refname, old_sha1, new_sha1)
-                if sent_capabilities:
-                    proto.write_pkt_line(old_sha1 + b" " + new_sha1 + b" " + refname)
-                else:
-                    proto.write_pkt_line(
-                        old_sha1
-                        + b" "
-                        + new_sha1
-                        + b" "
-                        + refname
-                        + b"\0"
-                        + b" ".join(sorted(capabilities))
-                    )
-                    sent_capabilities = True
-            if new_sha1 not in have and new_sha1 != ZERO_SHA:
-                want.append(new_sha1)
-        proto.write_pkt_line(None)
-        return (have, want)
-
     def _negotiate_receive_pack_capabilities(self, server_capabilities):
         negotiated_capabilities = self._send_capabilities & server_capabilities
         agent = None
@@ -1052,13 +1062,14 @@ class TraditionalGitClient(GitClient):
                     ref_status = None
                 return SendPackResult(old_refs, agent=agent, ref_status=ref_status)
 
-            (have, want) = self._handle_receive_pack_head(
-                proto, negotiated_capabilities, old_refs, new_refs
-            )
+            header_handler = _v1ReceivePackHeader(negotiated_capabilities, old_refs, new_refs)
+
+            for pkt_line in header_handler:
+                proto.write_pkt_line(pkt_line)
 
             pack_data_count, pack_data = generate_pack_data(
-                have,
-                want,
+                header_handler.have,
+                header_handler.want,
                 ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities),
             )
 
@@ -2007,12 +2018,12 @@ class AbstractHttpGitClient(GitClient):
             raise NotImplementedError(self.fetch_pack)
         req_data = BytesIO()
         req_proto = Protocol(None, req_data.write)
-        (have, want) = self._handle_receive_pack_head(
-            req_proto, negotiated_capabilities, old_refs, new_refs
-        )
+        header_handler = _v1ReceivePackHeader(negotiated_capabilities, old_refs, new_refs)
+        for pkt_line in header_handler:
+            req_proto.write_pkt_line(pkt_line)
         pack_data_count, pack_data = generate_pack_data(
-            have,
-            want,
+            header_handler.have,
+            header_handler.want,
             ofs_delta=(CAPABILITY_OFS_DELTA in negotiated_capabilities),
         )
         if self._should_send_pack(new_refs):