Pārlūkot izejas kodu

client: raise an exception when send_pack fails

Change-Id: I55f6caa6f5c050f5feb556c3c21be18d3e57fcba
Augie Fackler 15 gadi atpakaļ
vecāks
revīzija
bc59c3d5f5

+ 55 - 14
dulwich/client.py

@@ -28,7 +28,8 @@ import subprocess
 
 from dulwich.errors import (
     ChecksumMismatch,
-    HangupException,
+    SendPackError,
+    UpdateRefsError,
     )
 from dulwich.protocol import (
     Protocol,
@@ -47,8 +48,11 @@ def _fileno_can_read(fileno):
 
 COMMON_CAPABILITIES = ["ofs-delta"]
 FETCH_CAPABILITIES = ["multi_ack", "side-band-64k"] + COMMON_CAPABILITIES
-SEND_CAPABILITIES = [] + COMMON_CAPABILITIES
+SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 
+# TODO(durin42): this doesn't correctly degrade if the server doesn't
+# support some capabilities. This should work properly with servers
+# that don't support side-band-64k and multi_ack.
 class GitClient(object):
     """Git smart server client.
 
@@ -91,8 +95,14 @@ class GitClient(object):
         :param path: Repository path
         :param generate_pack_contents: Function that can return the shas of the
             objects to upload.
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
         """
         old_refs, server_capabilities = self.read_refs()
+        if 'report-status' not in server_capabilities:
+            self._send_capabilities.remove('report-status')
         new_refs = determine_wants(old_refs)
         if not new_refs:
             self.proto.write_pkt_line(None)
@@ -105,7 +115,8 @@ class GitClient(object):
             new_sha1 = new_refs.get(refname, ZERO_SHA)
             if old_sha1 != new_sha1:
                 if sent_capabilities:
-                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1, refname))
+                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1,
+                                                            refname))
                 else:
                     self.proto.write_pkt_line(
                       "%s %s %s\0%s" % (old_sha1, new_sha1, refname,
@@ -120,17 +131,47 @@ class GitClient(object):
         (entries, sha) = write_pack_data(self.proto.write_file(), objects,
                                          len(objects))
 
-        # read the final confirmation sha
-        try:
-            client_sha = self.proto.read_pkt_line()
-        except HangupException:
-            # for git-daemon versions before v1.6.6.1-26-g38a81b4, there is
-            # nothing to read; catch this and hide from the user.
-            pass
-        else:
-            if not client_sha in (None, "", sha):
-                raise ChecksumMismatch(sha, client_sha)
-
+        if 'report-status' in self._send_capabilities:
+            unpack = self.proto.read_pkt_line().strip()
+            if unpack != 'unpack ok':
+                st = True
+                # flush remaining error data
+                while st is not None:
+                    st = self.proto.read_pkt_line()
+                raise SendPackError(unpack)
+            statuses = []
+            errs = False
+            ref_status = self.proto.read_pkt_line()
+            while ref_status:
+                ref_status = ref_status.strip()
+                statuses.append(ref_status)
+                if not ref_status.startswith('ok '):
+                    errs = True
+                ref_status = self.proto.read_pkt_line()
+
+            if errs:
+                ref_status = {}
+                ok = set()
+                for status in statuses:
+                    if ' ' not in status:
+                        # malformed response, move on to the next one
+                        continue
+                    status, ref = status.split(' ', 1)
+
+                    if status == 'ng':
+                        if ' ' in ref:
+                            ref, status = ref.split(' ', 1)
+                    else:
+                        ok.add(ref)
+                    ref_status[ref] = status
+                raise UpdateRefsError('%s failed to update' %
+                                      ', '.join([ref for ref in ref_status
+                                                 if ref not in ok]),
+                                      ref_status=ref_status)
+        # wait for EOF before returning
+        data = self.proto.read()
+        if data:
+            raise SendPackError('Unexpected response %r' % data)
         return new_refs
 
     def fetch(self, path, target, determine_wants=None, progress=None):

+ 15 - 0
dulwich/errors.py

@@ -114,6 +114,21 @@ class GitProtocolError(Exception):
         Exception.__init__(self, *args, **kwargs)
 
 
+class SendPackError(GitProtocolError):
+    """An error occurred during send_pack."""
+
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args, **kwargs)
+
+
+class UpdateRefsError(GitProtocolError):
+    """The server reported errors updating refs."""
+
+    def __init__(self, *args, **kwargs):
+        self.ref_status = kwargs.pop('ref_status')
+        Exception.__init__(self, *args, **kwargs)
+
+
 class HangupException(GitProtocolError):
     """Hangup exception."""
 

+ 150 - 0
dulwich/tests/compat/test_client.py

@@ -0,0 +1,150 @@
+# test_client.py -- Compatibilty tests for git client.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibilty tests between the Dulwich client and the cgit server."""
+
+import os
+import shutil
+import signal
+import tempfile
+
+from dulwich import client
+from dulwich import errors
+from dulwich import file
+from dulwich import index
+from dulwich import protocol
+from dulwich import object_store
+from dulwich import objects
+from dulwich import repo
+from dulwich.tests import (
+    TestSkipped,
+    )
+
+from utils import (
+    CompatTestCase,
+    check_for_daemon,
+    import_repo_to_dir,
+    run_git,
+    )
+
+class DulwichClientTest(CompatTestCase):
+    """Tests for client/server compatibility."""
+
+    def setUp(self):
+        if check_for_daemon(limit=1):
+            raise TestSkipped('git-daemon was already running on port %s' %
+                              protocol.TCP_GIT_PORT)
+        CompatTestCase.setUp(self)
+        fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
+                                            suffix=".pid")
+        os.fdopen(fd).close()
+        self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
+        dest = os.path.join(self.gitroot, 'dest')
+        file.ensure_dir_exists(dest)
+        run_git(['init', '--bare'], cwd=dest)
+        run_git(
+            ['daemon', '--verbose', '--export-all',
+             '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
+             '--detach', '--reuseaddr', '--enable=receive-pack',
+             '--listen=localhost', self.gitroot], cwd=self.gitroot)
+        if not check_for_daemon():
+            raise TestSkipped('git-daemon failed to start')
+
+    def tearDown(self):
+        CompatTestCase.tearDown(self)
+        try:
+            os.kill(int(open(self.pidfile).read().strip()), signal.SIGKILL)
+            os.unlink(self.pidfile)
+        except (OSError, IOError):
+            pass
+        shutil.rmtree(self.gitroot)
+
+    def test_send_pack(self):
+        c = client.TCPGitClient('localhost')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def test_send_without_report_status(self):
+        c = client.TCPGitClient('localhost')
+        c._send_capabilities.remove('report-status')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def disable_ff_and_make_dummy_commit(self):
+        # disable non-fast-forward pushes to the server
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        run_git(['config', 'receive.denyNonFastForwards', 'true'], cwd=dest.path)
+        b = objects.Blob.from_string('hi')
+        dest.object_store.add_object(b)
+        t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
+        c = objects.Commit()
+        c.author = c.committer = 'Foo Bar <foo@example.com>'
+        c.author_time = c.commit_time = 0
+        c.author_timezone = c.commit_timezone = 0
+        c.message = 'hi'
+        c.tree = t
+        dest.object_store.add_object(c)
+        return dest, c.id
+
+    def compute_send(self):
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        return sendrefs, src.object_store.generate_pack_contents
+
+    def test_send_pack_one_error(self):
+        dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
+        dest.refs['refs/heads/master'] = dummy_commit
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/master failed to update', str(e))
+            self.assertEqual({'refs/heads/branch': 'ok',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)
+
+    def test_send_pack_multiple_errors(self):
+        dest, dummy = self.disable_ff_and_make_dummy_commit()
+        # set up for two non-ff errors
+        dest.refs['refs/heads/branch'] = dest.refs['refs/heads/master'] = dummy
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/branch, refs/heads/master failed to '
+                             'update', str(e))
+            self.assertEqual({'refs/heads/branch': 'non-fast-forward',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)

+ 45 - 4
dulwich/tests/compat/utils.py

@@ -19,12 +19,16 @@
 
 """Utilities for interacting with cgit."""
 
+import errno
 import os
+import socket
 import subprocess
 import tempfile
+import time
 import unittest
 
 from dulwich.repo import Repo
+from dulwich.protocol import TCP_GIT_PORT
 
 from dulwich.tests import (
     TestSkipped,
@@ -108,15 +112,15 @@ def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
     return stdout
 
 
-def import_repo(name):
+def import_repo_to_dir(name):
     """Import a repo from a fast-export file in a temporary directory.
 
     These are used rather than binary repos for compat tests because they are
     more compact an human-editable, and we already depend on git.
 
     :param name: The name of the repository export file, relative to
-        dulwich/tests/data/repos
-    :returns: An initialized Repo object that lives in a temporary directory.
+        dulwich/tests/data/repos.
+    :returns: The path to the imported repository.
     """
     temp_dir = tempfile.mkdtemp()
     export_path = os.path.join(os.path.dirname(__file__), os.pardir, 'data',
@@ -127,7 +131,44 @@ def import_repo(name):
     run_git_or_fail(['fast-import'], input=export_file.read(),
                     cwd=temp_repo_dir)
     export_file.close()
-    return Repo(temp_repo_dir)
+    return temp_repo_dir
+
+def import_repo(name):
+    """Import a repo from a fast-export file in a temporary directory.
+
+    :param name: The name of the repository export file, relative to
+        dulwich/tests/data/repos.
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    return Repo(import_repo_to_dir(name))
+
+
+def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
+    """Check for a running TCP daemon.
+
+    Defaults to checking 10 times with a delay of 0.1 sec between tries.
+
+    :param limit: Number of attempts before deciding no daemon is running.
+    :param delay: Delay between connection attempts.
+    :param timeout: Socket timeout for connection attempts.
+    :param port: Port on which we expect the daemon to appear.
+    :returns: A boolean, true if a daemon is running on the specified port,
+        false if not.
+    """
+    for _ in xrange(limit):
+        time.sleep(delay)
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.settimeout(delay)
+        try:
+            s.connect(('localhost', port))
+            s.close()
+            return True
+        except socket.error, e:
+            if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
+                raise
+            elif e.args[0] != errno.ECONNREFUSED:
+                raise
+    return False
 
 
 class CompatTestCase(unittest.TestCase):

+ 1 - 1
dulwich/tests/test_client.py

@@ -37,7 +37,7 @@ class GitClientTests(TestCase):
         self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
                                'thin-pack']),
                           set(self.client._fetch_capabilities))
-        self.assertEquals(set(['ofs-delta']),
+        self.assertEquals(set(['ofs-delta', 'report-status']),
                           set(self.client._send_capabilities))
 
     def test_fetch_pack_none(self):