Browse Source

Merge trunk.

Jelmer Vernooij 13 years ago
parent
commit
7ec4cd22fe
5 changed files with 234 additions and 31 deletions
  1. 3 0
      NEWS
  2. 144 30
      dulwich/client.py
  3. 33 0
      dulwich/protocol.py
  4. 27 1
      dulwich/tests/test_client.py
  5. 27 0
      dulwich/tests/test_protocol.py

+ 3 - 0
NEWS

@@ -8,6 +8,9 @@
 
   * New ``Repo.clone`` method. (Jelmer Vernooij, #725369)
 
+  * ``GitClient.send_pack`` now supports the 'side-band-64k' capability.
+    (Jelmer Vernooij)
+
  CHANGES
 
   * unittest2 or python >= 2.7 is now required for the testsuite.

+ 144 - 30
dulwich/client.py

@@ -34,6 +34,7 @@ from dulwich.errors import (
     UpdateRefsError,
     )
 from dulwich.protocol import (
+    PktLineParser,
     Protocol,
     TCP_GIT_PORT,
     ZERO_SHA,
@@ -54,13 +55,72 @@ def _fileno_can_read(fileno):
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
-COMMON_CAPABILITIES = ['ofs-delta']
-FETCH_CAPABILITIES = ['multi_ack', 'side-band-64k'] + COMMON_CAPABILITIES
+COMMON_CAPABILITIES = ['ofs-delta', 'side-band-64k']
+FETCH_CAPABILITIES = ['multi_ack'] + COMMON_CAPABILITIES
 SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 
+
+class ReportStatusParser(object):
+    """Handle status as reported by servers with the 'report-status' capability.
+    """
+
+    def __init__(self):
+        self._done = False
+        self._pack_status = None
+        self._ref_status_ok = True
+        self._ref_statuses = []
+
+    def check(self):
+        """Check if there were any errors and, if so, raise exceptions.
+
+        :raise SendPackError: Raised when the server could not unpack
+        :raise UpdateRefsError: Raised when refs could not be updated
+        """
+        if self._pack_status not in ('unpack ok', None):
+            raise SendPackError(self._pack_status)
+        if not self._ref_status_ok:
+            ref_status = {}
+            ok = set()
+            for status in self._ref_statuses:
+                if ' ' not in status:
+                    # malformed response, move on to the next one
+                    continue
+                status, ref = status.split(' ', 1)
+
+                if status == 'ng':
+                    if ' ' in ref:
+                        ref, status = ref.split(' ', 1)
+                else:
+                    ok.add(ref)
+                ref_status[ref] = status
+            raise UpdateRefsError('%s failed to update' %
+                                  ', '.join([ref for ref in ref_status
+                                             if ref not in ok]),
+                                  ref_status=ref_status)
+
+    def handle_packet(self, pkt):
+        """Handle a packet.
+
+        :raise GitProtocolError: Raised when packets are received after a
+            flush packet.
+        """
+        if self._done:
+            raise GitProtocolError("received more data after status report")
+        if pkt is None:
+            self._done = True
+            return
+        if self._pack_status is None:
+            self._pack_status = pkt.strip()
+        else:
+            ref_status = pkt.strip()
+            self._ref_statuses.append(ref_status)
+            if not ref_status.startswith('ok '):
+                self._ref_status_ok = False
+
+
 # TODO(durin42): this doesn't correctly degrade if the server doesn't
 # support some capabilities. This should work properly with servers
-# that don't support side-band-64k and multi_ack.
+# that don't support multi_ack.
 class GitClient(object):
     """Git smart server client.
 
@@ -92,12 +152,14 @@ class GitClient(object):
             refs[ref] = sha
         return refs, server_capabilities
 
-    def send_pack(self, path, determine_wants, generate_pack_contents):
+    def send_pack(self, path, determine_wants, generate_pack_contents,
+                  progress=None):
         """Upload a pack to a remote repository.
 
         :param path: Repository path
         :param generate_pack_contents: Function that can return a sequence of the
             shas of the objects to upload.
+        :param progress: Optional progress function
 
         :raises SendPackError: if server rejects the pack data
         :raises UpdateRefsError: if the server supports report-status
@@ -173,6 +235,26 @@ class GitClient(object):
                                              if ref not in ok]),
                                   ref_status=ref_status)
 
+    def _read_side_band64k_data(self, proto, channel_callbacks):
+        """Read per-channel data.
+
+        This requires the side-band-64k capability.
+
+        :param proto: Protocol object to read from
+        :param channel_callbacks: Dictionary mapping channels to packet
+            handlers to use. None for a callback discards channel data.
+        """
+        for pkt in proto.read_pkt_seq():
+            channel = ord(pkt[0])
+            pkt = pkt[1:]
+            try:
+                cb = channel_callbacks[channel]
+            except KeyError:
+                raise AssertionError('Invalid sideband channel %d' % channel)
+            else:
+                if cb is not None:
+                    cb(pkt)
+
 
 class TraditionalGitClient(GitClient):
     """Traditional Git client."""
@@ -191,13 +273,14 @@ class TraditionalGitClient(GitClient):
         """
         raise NotImplementedError()
 
-    # TODO(durin42): add side-band-64k capability support here and advertise it
-    def send_pack(self, path, determine_wants, generate_pack_contents):
+    def send_pack(self, path, determine_wants, generate_pack_contents,
+                  progress=None):
         """Upload a pack to a remote repository.
 
         :param path: Repository path
         :param generate_pack_contents: Function that can return a sequence of the
             shas of the objects to upload.
+        :param progress: Optional callback called with progress updates
 
         :raises SendPackError: if server rejects the pack data
         :raises UpdateRefsError: if the server supports report-status
@@ -205,8 +288,9 @@ class TraditionalGitClient(GitClient):
         """
         proto, unused_can_read = self._connect('receive-pack', path)
         old_refs, server_capabilities = self._read_refs(proto)
+        negotiated_capabilities = list(self._send_capabilities)
         if 'report-status' not in server_capabilities:
-            self._send_capabilities.remove('report-status')
+            negotiated_capabilities.remove('report-status')
         new_refs = determine_wants(old_refs)
         if not new_refs:
             proto.write_pkt_line(None)
@@ -224,7 +308,7 @@ class TraditionalGitClient(GitClient):
                 else:
                     proto.write_pkt_line(
                       '%s %s %s\0%s' % (old_sha1, new_sha1, refname,
-                                        ' '.join(self._send_capabilities)))
+                                        ' '.join(negotiated_capabilities)))
                     sent_capabilities = True
             if new_sha1 not in have and new_sha1 != ZERO_SHA:
                 want.append(new_sha1)
@@ -233,9 +317,22 @@ class TraditionalGitClient(GitClient):
             return new_refs
         objects = generate_pack_contents(have, want)
         entries, sha = write_pack_objects(proto.write_file(), objects)
-
-        if 'report-status' in self._send_capabilities:
-            self._parse_status_report(proto)
+        if 'report-status' in negotiated_capabilities:
+            report_status_parser = ReportStatusParser()
+        else:
+            report_status_parser = None
+        if "side-band-64k" in negotiated_capabilities:
+            channel_callbacks = { 2: progress }
+            if 'report-status' in negotiated_capabilities:
+                channel_callbacks[1] = PktLineParser(
+                    report_status_parser.handle_packet).parse
+            self._read_side_band64k_data(proto, channel_callbacks)
+        else:
+            if 'report-status':
+                for pkt in proto.read_pkt_seq():
+                    report_status_parser.handle_packet(pkt)
+        if report_status_parser is not None:
+            report_status_parser.check()
         # wait for EOF before returning
         data = proto.read()
         if data:
@@ -243,7 +340,7 @@ class TraditionalGitClient(GitClient):
         return new_refs
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
-                   progress):
+                   progress=None):
         """Retrieve a pack from a git smart server.
 
         :param determine_wants: Callback that returns list of commits to fetch
@@ -253,13 +350,14 @@ class TraditionalGitClient(GitClient):
         """
         proto, can_read = self._connect('upload-pack', path)
         (refs, server_capabilities) = self._read_refs(proto)
+        negotiated_capabilities = list(self._fetch_capabilities)
         wants = determine_wants(refs)
         if not wants:
             proto.write_pkt_line(None)
             return refs
         assert isinstance(wants, list) and type(wants[0]) == str
         proto.write_pkt_line('want %s %s\n' % (
-            wants[0], ' '.join(self._fetch_capabilities)))
+            wants[0], ' '.join(negotiated_capabilities)))
         for want in wants[1:]:
             proto.write_pkt_line('want %s\n' % want)
         proto.write_pkt_line(None)
@@ -282,18 +380,15 @@ class TraditionalGitClient(GitClient):
             if len(parts) < 3 or parts[2] != 'continue':
                 break
             pkt = proto.read_pkt_line()
-        # TODO(durin42): this is broken if the server didn't support the
-        # side-band-64k capability.
-        for pkt in proto.read_pkt_seq():
-            channel = ord(pkt[0])
-            pkt = pkt[1:]
-            if channel == 1:
-                pack_data(pkt)
-            elif channel == 2:
-                if progress is not None:
-                    progress(pkt)
-            else:
-                raise AssertionError('Invalid sideband channel %d' % channel)
+        if "side-band-64k" in negotiated_capabilities:
+            self._read_side_band64k_data(proto, {1: pack_data, 2: progress})
+            # wait for EOF before returning
+            data = proto.read()
+            if data:
+                raise Exception('Unexpected response %r' % data)
+        else:
+            # FIXME: Buffering?
+            pack_data(self.read())
         return refs
 
 
@@ -464,12 +559,14 @@ class HttpGitClient(GitClient):
             raise ValueError("Invalid content-type from server: %s" % resp.info().gettype())
         return resp
 
-    def send_pack(self, path, determine_wants, generate_pack_contents):
+    def send_pack(self, path, determine_wants, generate_pack_contents,
+                  progress=None):
         """Upload a pack to a remote repository.
 
         :param path: Repository path
         :param generate_pack_contents: Function that can return a sequence of the
             shas of the objects to upload.
+        :param progress: Optional progress function
 
         :raises SendPackError: if server rejects the pack data
         :raises UpdateRefsError: if the server supports report-status
@@ -477,13 +574,12 @@ class HttpGitClient(GitClient):
         """
         url = urlparse.urljoin(self.url, path)
         old_refs, server_capabilities = self._discover_references("git-receive-pack", url)
+        negotiated_capabilities = list(self._send_capabilities)
         new_refs = determine_wants(old_refs)
         if not new_refs:
             return {}
         if self.dumb:
             raise NotImplementedError(self.fetch_pack)
-        if 'report-status' not in server_capabilities:
-            raise ValueError("Server does not support report-status")
         req_data = StringIO()
         req_proto = Protocol(None, req_data.write)
         want = []
@@ -514,8 +610,26 @@ class HttpGitClient(GitClient):
             raise ValueError("invalid http response during git-receive-pack: %d"
                              % resp.getcode())
         resp_proto = Protocol(resp.read, None)
-        if 'report-status' in self._send_capabilities:
-            self._parse_status_report(resp_proto)
+        if 'report-status' in negotiated_capabilities:
+            report_status_parser = ReportStatusParser()
+        else:
+            report_status_parser = None
+        if "side-band-64k" in negotiated_capabilities:
+            channel_callbacks = { 2: progress }
+            if 'report-status' in negotiated_capabilities:
+                channel_callbacks[1] = PktLineParser(
+                    report_status_parser.handle_packet).parse
+            self._read_side_band64k_data(resp_proto, channel_callbacks)
+        else:
+            if 'report-status':
+                for pkt in resp_proto.read_pkt_seq():
+                    report_status_parser.handle_packet(pkt)
+        if report_status_parser is not None:
+            report_status_parser.check()
+        # wait for EOF before returning
+        data = resp_proto.read()
+        if data:
+            raise SendPackError('Unexpected response %r' % data)
         return new_refs
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data,

+ 33 - 0
dulwich/protocol.py

@@ -406,3 +406,36 @@ class BufferedPktLineWriter(object):
             self._write(data)
         self._len = 0
         self._wbuf = StringIO()
+
+
+class PktLineParser(object):
+    """Packet line parser that hands completed packets off to a callback.
+    """
+
+    def __init__(self, handle_pkt):
+        self.handle_pkt = handle_pkt
+        self._readahead = StringIO()
+
+    def parse(self, data):
+        """Parse a fragment of data and call back for any completed packets.
+        """
+        self._readahead.write(data)
+        buf = self._readahead.getvalue()
+        if len(buf) < 4:
+            return
+        while len(buf) >= 4:
+            size = int(buf[:4], 16)
+            if size == 0:
+                self.handle_pkt(None)
+                buf = buf[4:]
+            elif size <= len(buf):
+                self.handle_pkt(buf[4:size])
+                buf = buf[size:]
+            else:
+                break
+        self._readahead = StringIO()
+        self._readahead.write(buf)
+
+    def get_tail(self):
+        """Read back any unused data."""
+        return self._readahead.getvalue()

+ 27 - 1
dulwich/tests/test_client.py

@@ -23,6 +23,9 @@ from dulwich.client import (
     TCPGitClient,
     SubprocessGitClient,
     SSHGitClient,
+    ReportStatusParser,
+    SendPackError,
+    UpdateRefsError,
     get_transport_and_path,
     )
 from dulwich.tests import (
@@ -60,7 +63,7 @@ class GitClientTests(TestCase):
         self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
                                'thin-pack']),
                           set(self.client._fetch_capabilities))
-        self.assertEquals(set(['ofs-delta', 'report-status']),
+        self.assertEquals(set(['ofs-delta', 'report-status', 'side-band-64k']),
                           set(self.client._send_capabilities))
 
     def test_fetch_pack_none(self):
@@ -151,3 +154,26 @@ class SSHGitClientTests(TestCase):
         self.assertEquals('/usr/lib/git/git-upload-pack',
             self.client._get_cmd_path('upload-pack'))
 
+
+class ReportStatusParserTests(TestCase):
+
+    def test_invalid_pack(self):
+        parser = ReportStatusParser()
+        parser.handle_packet("unpack error - foo bar")
+        parser.handle_packet("ok refs/foo/bar")
+        parser.handle_packet(None)
+        self.assertRaises(SendPackError, parser.check)
+
+    def test_update_refs_error(self):
+        parser = ReportStatusParser()
+        parser.handle_packet("unpack ok")
+        parser.handle_packet("ng refs/foo/bar need to pull")
+        parser.handle_packet(None)
+        self.assertRaises(UpdateRefsError, parser.check)
+
+    def test_ok(self):
+        parser = ReportStatusParser()
+        parser.handle_packet("unpack ok")
+        parser.handle_packet("ok refs/foo/bar")
+        parser.handle_packet(None)
+        parser.check()

+ 27 - 0
dulwich/tests/test_protocol.py

@@ -25,6 +25,7 @@ from dulwich.errors import (
     HangupException,
     )
 from dulwich.protocol import (
+    PktLineParser,
     Protocol,
     ReceivableProtocol,
     extract_capabilities,
@@ -280,3 +281,29 @@ class BufferedPktLineWriterTests(TestCase):
         self._writer.write('z')
         self._writer.flush()
         self.assertOutputEquals('0005z')
+
+
+class PktLineParserTests(TestCase):
+
+    def test_none(self):
+        pktlines = []
+        parser = PktLineParser(pktlines.append)
+        parser.parse("0000")
+        self.assertEquals(pktlines, [None])
+        self.assertEquals("", parser.get_tail())
+
+    def test_small_fragments(self):
+        pktlines = []
+        parser = PktLineParser(pktlines.append)
+        parser.parse("00")
+        parser.parse("05")
+        parser.parse("z0000")
+        self.assertEquals(pktlines, ["z", None])
+        self.assertEquals("", parser.get_tail())
+
+    def test_multiple_packets(self):
+        pktlines = []
+        parser = PktLineParser(pktlines.append)
+        parser.parse("0005z0006aba")
+        self.assertEquals(pktlines, ["z", "ab"])
+        self.assertEquals("a", parser.get_tail())