2
0
Эх сурвалжийг харах

Merge cleanups and use of proper per-service capabilities from Augie.

Jelmer Vernooij 15 жил өмнө
parent
commit
298f25e2cf

+ 83 - 40
dulwich/client.py

@@ -28,7 +28,8 @@ import subprocess
 
 from dulwich.errors import (
     ChecksumMismatch,
-    HangupException,
+    SendPackError,
+    UpdateRefsError,
     )
 from dulwich.protocol import (
     Protocol,
@@ -45,16 +46,19 @@ def _fileno_can_read(fileno):
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
+COMMON_CAPABILITIES = ["ofs-delta"]
+FETCH_CAPABILITIES = ["multi_ack", "side-band-64k"] + COMMON_CAPABILITIES
+SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 
-CAPABILITIES = ["multi_ack", "side-band-64k", "ofs-delta"]
-
-
+# TODO(durin42): this doesn't correctly degrade if the server doesn't
+# support some capabilities. This should work properly with servers
+# that don't support side-band-64k and multi_ack.
 class GitClient(object):
     """Git smart server client.
 
     """
 
-    def __init__(self, can_read, read, write, thin_packs=True, 
+    def __init__(self, can_read, read, write, thin_packs=True,
         report_activity=None):
         """Create a new GitClient instance.
 
@@ -68,12 +72,10 @@ class GitClient(object):
         """
         self.proto = Protocol(read, write, report_activity)
         self._can_read = can_read
-        self._capabilities = list(CAPABILITIES)
+        self._fetch_capabilities = list(FETCH_CAPABILITIES)
+        self._send_capabilities = list(SEND_CAPABILITIES)
         if thin_packs:
-            self._capabilities.append("thin-pack")
-
-    def capabilities(self):
-        return " ".join(self._capabilities)
+            self._fetch_capabilities.append("thin-pack")
 
     def read_refs(self):
         server_capabilities = None
@@ -86,14 +88,21 @@ class GitClient(object):
             refs[ref] = sha
         return refs, server_capabilities
 
+    # TODO(durin42): add side-band-64k capability support here and advertise it
     def send_pack(self, path, determine_wants, generate_pack_contents):
         """Upload a pack to a remote repository.
 
         :param path: Repository path
-        :param generate_pack_contents: Function that can return the shas of the 
+        :param generate_pack_contents: Function that can return the shas of the
             objects to upload.
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
         """
         old_refs, server_capabilities = self.read_refs()
+        if 'report-status' not in server_capabilities:
+            self._send_capabilities.remove('report-status')
         new_refs = determine_wants(old_refs)
         if not new_refs:
             self.proto.write_pkt_line(None)
@@ -106,9 +115,12 @@ class GitClient(object):
             new_sha1 = new_refs.get(refname, ZERO_SHA)
             if old_sha1 != new_sha1:
                 if sent_capabilities:
-                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1, refname))
+                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1,
+                                                            refname))
                 else:
-                    self.proto.write_pkt_line("%s %s %s\0%s" % (old_sha1, new_sha1, refname, self.capabilities()))
+                    self.proto.write_pkt_line(
+                      "%s %s %s\0%s" % (old_sha1, new_sha1, refname,
+                                        ' '.join(self._send_capabilities)))
                     sent_capabilities = True
             if not new_sha1 in (have, ZERO_SHA):
                 want.append(new_sha1)
@@ -116,20 +128,50 @@ class GitClient(object):
         if not want:
             return new_refs
         objects = generate_pack_contents(have, want)
-        (entries, sha) = write_pack_data(self.proto.write_file(), objects, 
+        (entries, sha) = write_pack_data(self.proto.write_file(), objects,
                                          len(objects))
-        
-        # read the final confirmation sha
-        try:
-            client_sha = self.proto.read_pkt_line()
-        except HangupException:
-            # for git-daemon versions before v1.6.6.1-26-g38a81b4, there is
-            # nothing to read; catch this and hide from the user.
-            pass
-        else:
-            if not client_sha in (None, "", sha):
-                raise ChecksumMismatch(sha, client_sha)
 
+        if 'report-status' in self._send_capabilities:
+            unpack = self.proto.read_pkt_line().strip()
+            if unpack != 'unpack ok':
+                st = True
+                # flush remaining error data
+                while st is not None:
+                    st = self.proto.read_pkt_line()
+                raise SendPackError(unpack)
+            statuses = []
+            errs = False
+            ref_status = self.proto.read_pkt_line()
+            while ref_status:
+                ref_status = ref_status.strip()
+                statuses.append(ref_status)
+                if not ref_status.startswith('ok '):
+                    errs = True
+                ref_status = self.proto.read_pkt_line()
+
+            if errs:
+                ref_status = {}
+                ok = set()
+                for status in statuses:
+                    if ' ' not in status:
+                        # malformed response, move on to the next one
+                        continue
+                    status, ref = status.split(' ', 1)
+
+                    if status == 'ng':
+                        if ' ' in ref:
+                            ref, status = ref.split(' ', 1)
+                    else:
+                        ok.add(ref)
+                    ref_status[ref] = status
+                raise UpdateRefsError('%s failed to update' %
+                                      ', '.join([ref for ref in ref_status
+                                                 if ref not in ok]),
+                                      ref_status=ref_status)
+        # wait for EOF before returning
+        data = self.proto.read()
+        if data:
+            raise SendPackError('Unexpected response %r' % data)
         return new_refs
 
     def fetch(self, path, target, determine_wants=None, progress=None):
@@ -137,7 +179,7 @@ class GitClient(object):
 
         :param path: Path to fetch from
         :param target: Target repository to fetch into
-        :param determine_wants: Optional function to determine what refs 
+        :param determine_wants: Optional function to determine what refs
             to fetch
         :param progress: Optional progress function
         :return: remote refs
@@ -146,7 +188,7 @@ class GitClient(object):
             determine_wants = target.object_store.determine_wants_all
         f, commit = target.object_store.add_pack()
         try:
-            return self.fetch_pack(path, determine_wants, 
+            return self.fetch_pack(path, determine_wants,
                 target.get_graph_walker(), f.write, progress)
         finally:
             commit()
@@ -166,7 +208,8 @@ class GitClient(object):
             self.proto.write_pkt_line(None)
             return refs
         assert isinstance(wants, list) and type(wants[0]) == str
-        self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
+        self.proto.write_pkt_line("want %s %s\n" % (
+            wants[0], ' '.join(self._fetch_capabilities)))
         for want in wants[1:]:
             self.proto.write_pkt_line("want %s\n" % want)
         self.proto.write_pkt_line(None)
@@ -189,6 +232,8 @@ class GitClient(object):
             if len(parts) < 3 or parts[2] != "continue":
                 break
             pkt = self.proto.read_pkt_line()
+        # TODO(durin42): this is broken if the server didn't support the
+        # side-band-64k capability.
         for pkt in self.proto.read_pkt_seq():
             channel = ord(pkt[0])
             pkt = pkt[1:]
@@ -224,9 +269,9 @@ class TCPGitClient(GitClient):
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
         """Fetch a pack from the remote host.
-        
+
         :param path: Path of the reposiutory on the remote host
-        :param determine_wants: Callback that receives available refs dict and 
+        :param determine_wants: Callback that receives available refs dict and
             should return list of sha's to fetch.
         :param graph_walker: GraphWalker instance used to find missing shas
         :param pack_data: Callback for writing pack data
@@ -262,18 +307,18 @@ class SubprocessGitClient(GitClient):
 
         :param path: Path to the git repository on the server
         :param changed_refs: Dictionary with new values for the refs
-        :param generate_pack_contents: Function that returns an iterator over 
+        :param generate_pack_contents: Function that returns an iterator over
             objects to send
         """
         client = self._connect("git-receive-pack", path)
         return client.send_pack(path, changed_refs, generate_pack_contents)
 
-    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, 
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
         progress):
         """Retrieve a pack from the server
 
         :param path: Path to the git repository on the server
-        :param determine_wants: Function that receives existing refs 
+        :param determine_wants: Function that receives existing refs
             on the server and returns a list of desired shas
         :param graph_walker: GraphWalker instance
         :param pack_data: Function that can write pack data
@@ -289,12 +334,8 @@ class SSHSubprocess(object):
 
     def __init__(self, proc):
         self.proc = proc
-
-    def send(self, data):
-        return os.write(self.proc.stdin.fileno(), data)
-
-    def recv(self, count):
-        return self.proc.stdout.read(count)
+        self.read = self.recv = proc.stdout.read
+        self.write = self.send = proc.stdin.write
 
     def close(self):
         self.proc.stdin.close()
@@ -331,7 +372,9 @@ class SSHGitClient(GitClient):
         self._kwargs = kwargs
 
     def send_pack(self, path, determine_wants, generate_pack_contents):
-        remote = get_ssh_vendor().connect_ssh(self.host, ["git-receive-pack '%s'" % path], port=self.port, username=self.username)
+        remote = get_ssh_vendor().connect_ssh(
+            self.host, ["git-receive-pack '%s'" % path],
+            port=self.port, username=self.username)
         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
         return client.send_pack(path, determine_wants, generate_pack_contents)
 

+ 23 - 8
dulwich/errors.py

@@ -1,17 +1,17 @@
 # errors.py -- errors for dulwich
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -38,7 +38,7 @@ class ChecksumMismatch(Exception):
                 "Checksum mismatch: Expected %s, got %s" % (expected, got))
         else:
             Exception.__init__(self,
-                "Checksum mismatch: Expected %s, got %s; %s" % 
+                "Checksum mismatch: Expected %s, got %s; %s" %
                 (expected, got, extra))
 
 
@@ -81,21 +81,21 @@ class NotBlobError(WrongObjectException):
 
 class MissingCommitError(Exception):
     """Indicates that a commit was not found in the repository"""
-  
+
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the revision store" % sha)
 
 
 class ObjectMissing(Exception):
     """Indicates that a requested object is missing."""
-  
+
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the pack" % sha)
 
 
 class ApplyDeltaError(Exception):
     """Indicates that applying a delta failed."""
-    
+
     def __init__(self, *args, **kwargs):
         Exception.__init__(self, *args, **kwargs)
 
@@ -109,8 +109,23 @@ class NotGitRepository(Exception):
 
 class GitProtocolError(Exception):
     """Git protocol exception."""
-    
+
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args, **kwargs)
+
+
+class SendPackError(GitProtocolError):
+    """An error occurred during send_pack."""
+
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args, **kwargs)
+
+
+class UpdateRefsError(GitProtocolError):
+    """The server reported errors updating refs."""
+
     def __init__(self, *args, **kwargs):
+        self.ref_status = kwargs.pop('ref_status')
         Exception.__init__(self, *args, **kwargs)
 
 

+ 20 - 20
dulwich/object_store.py

@@ -60,7 +60,8 @@ class BaseObjectStore(object):
     """Object store interface."""
 
     def determine_wants_all(self, refs):
-	    return [sha for (ref, sha) in refs.iteritems() if not sha in self and not ref.endswith("^{}")]
+        return [sha for (ref, sha) in refs.iteritems()
+                if not sha in self and not ref.endswith("^{}")]
 
     def iter_shas(self, shas):
         """Iterate over the objects for the specified shas.
@@ -148,7 +149,7 @@ class BaseObjectStore(object):
                     newmode = None
                     newhexsha = None
                     newchildpath = None
-                if (want_unchanged or oldmode != newmode or 
+                if (want_unchanged or oldmode != newmode or
                     oldhexsha != newhexsha):
                     if stat.S_ISDIR(oldmode):
                         if newmode is None or stat.S_ISDIR(newmode):
@@ -182,7 +183,7 @@ class BaseObjectStore(object):
         while todo:
             (tid, tpath) = todo.pop()
             tree = self[tid]
-            for name, mode, hexsha in tree.iteritems(): 
+            for name, mode, hexsha in tree.iteritems():
                 path = posixpath.join(tpath, name)
                 if stat.S_ISDIR(mode):
                     todo.add((hexsha, path))
@@ -195,7 +196,7 @@ class BaseObjectStore(object):
 
         :param haves: Iterable over SHAs already in common.
         :param wants: Iterable over SHAs of objects to fetch.
-        :param progress: Simple progress function that will be called with 
+        :param progress: Simple progress function that will be called with
             updated progress strings.
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
             sha for including tags.
@@ -221,7 +222,7 @@ class BaseObjectStore(object):
 
     def get_graph_walker(self, heads):
         """Obtain a graph walker for this object store.
-        
+
         :param heads: Local heads to start search with
         :return: GraphWalker object
         """
@@ -304,7 +305,7 @@ class PackBasedObjectStore(BaseObjectStore):
                 return pack.get_raw(sha)
             except KeyError:
                 pass
-        if hexsha is None: 
+        if hexsha is None:
             hexsha = sha_to_hex(name)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
@@ -387,7 +388,7 @@ class DiskObjectStore(PackBasedObjectStore):
     def move_in_thin_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
 
         :param path: Path to the pack file.
@@ -395,13 +396,13 @@ class DiskObjectStore(PackBasedObjectStore):
         data = ThinPackData(self.get_raw, path)
 
         # Write index for the thin pack (do we really need this?)
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".tempidx")
         data.create_index_v2(temppath)
         p = Pack.from_objects(data, load_pack_index(temppath))
 
         # Write a full pack version
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".temppack")
         write_pack(temppath, ((o, None) for o in p.iterobjects()), len(p))
         pack_sha = load_pack_index(temppath+".idx").objects_sha1()
@@ -415,14 +416,14 @@ class DiskObjectStore(PackBasedObjectStore):
     def move_in_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
 
         :param path: Path to the pack file.
         """
         p = PackData(path)
         entries = p.sorted_entries()
-        basename = os.path.join(self.pack_dir, 
+        basename = os.path.join(self.pack_dir,
             "pack-%s" % iter_sha1(entry[0] for entry in entries))
         write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum())
         p.close()
@@ -434,7 +435,7 @@ class DiskObjectStore(PackBasedObjectStore):
     def add_thin_pack(self):
         """Add a new thin pack to this object store.
 
-        Thin packs are packs that contain deltas with parents that exist 
+        Thin packs are packs that contain deltas with parents that exist
         in a different pack.
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -449,9 +450,9 @@ class DiskObjectStore(PackBasedObjectStore):
         return f, commit
 
     def add_pack(self):
-        """Add a new pack to this object store. 
+        """Add a new pack to this object store.
 
-        :return: Fileobject to write to and a commit function to 
+        :return: Fileobject to write to and a commit function to
             call when the pack is finished.
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -607,7 +608,7 @@ class ObjectStoreIterator(ObjectIterator):
     def __contains__(self, needle):
         """Check if an object is present.
 
-        :note: This checks if the object is present in 
+        :note: This checks if the object is present in
             the underlying object store, not if it would
             be yielded by the iterator.
 
@@ -617,7 +618,7 @@ class ObjectStoreIterator(ObjectIterator):
 
     def __getitem__(self, key):
         """Find an object by SHA1.
-        
+
         :note: This retrieves the object from the underlying
             object store. It will also succeed if the object would
             not be returned by the iterator.
@@ -652,7 +653,7 @@ def tree_lookup_path(lookup_obj, root_sha, path):
 class MissingObjectFinder(object):
     """Find the objects missing from another object store.
 
-    :param object_store: Object store containing at least all objects to be 
+    :param object_store: Object store containing at least all objects to be
         sent
     :param haves: SHA1s of commits not to send (already present in target)
     :param wants: SHA1s of commits to send
@@ -706,9 +707,8 @@ class MissingObjectFinder(object):
 
 
 class ObjectStoreGraphWalker(object):
-    """Graph walker that finds out what commits are missing from an object 
-    store.
-    
+    """Graph walker that finds what commits are missing from an object store.
+
     :ivar heads: Revisions without descendants in the local repo
     :ivar get_parents: Function to retrieve parents in the local repo
     """

+ 0 - 13
dulwich/tests/compat/server_utils.py

@@ -51,19 +51,6 @@ class ServerTests(object):
         tear_down_repo(self._old_repo)
         tear_down_repo(self._new_repo)
 
-    def assertReposEqual(self, repo1, repo2):
-        self.assertEqual(repo1.get_refs(), repo2.get_refs())
-        self.assertEqual(sorted(set(repo1.object_store)),
-                         sorted(set(repo2.object_store)))
-
-    def assertReposNotEqual(self, repo1, repo2):
-        refs1 = repo1.get_refs()
-        objs1 = set(repo1.object_store)
-        refs2 = repo2.get_refs()
-        objs2 = set(repo2.object_store)
-
-        self.assertFalse(refs1 == refs2 and objs1 == objs2)
-
     def test_push_to_dulwich(self):
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)

+ 150 - 0
dulwich/tests/compat/test_client.py

@@ -0,0 +1,150 @@
+# test_client.py -- Compatibilty tests for git client.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibilty tests between the Dulwich client and the cgit server."""
+
+import os
+import shutil
+import signal
+import tempfile
+
+from dulwich import client
+from dulwich import errors
+from dulwich import file
+from dulwich import index
+from dulwich import protocol
+from dulwich import object_store
+from dulwich import objects
+from dulwich import repo
+from dulwich.tests import (
+    TestSkipped,
+    )
+
+from utils import (
+    CompatTestCase,
+    check_for_daemon,
+    import_repo_to_dir,
+    run_git,
+    )
+
+class DulwichClientTest(CompatTestCase):
+    """Tests for client/server compatibility."""
+
+    def setUp(self):
+        if check_for_daemon(limit=1):
+            raise TestSkipped('git-daemon was already running on port %s' %
+                              protocol.TCP_GIT_PORT)
+        CompatTestCase.setUp(self)
+        fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
+                                            suffix=".pid")
+        os.fdopen(fd).close()
+        self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
+        dest = os.path.join(self.gitroot, 'dest')
+        file.ensure_dir_exists(dest)
+        run_git(['init', '--bare'], cwd=dest)
+        run_git(
+            ['daemon', '--verbose', '--export-all',
+             '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
+             '--detach', '--reuseaddr', '--enable=receive-pack',
+             '--listen=localhost', self.gitroot], cwd=self.gitroot)
+        if not check_for_daemon():
+            raise TestSkipped('git-daemon failed to start')
+
+    def tearDown(self):
+        CompatTestCase.tearDown(self)
+        try:
+            os.kill(int(open(self.pidfile).read().strip()), signal.SIGKILL)
+            os.unlink(self.pidfile)
+        except (OSError, IOError):
+            pass
+        shutil.rmtree(self.gitroot)
+
+    def test_send_pack(self):
+        c = client.TCPGitClient('localhost')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def test_send_without_report_status(self):
+        c = client.TCPGitClient('localhost')
+        c._send_capabilities.remove('report-status')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def disable_ff_and_make_dummy_commit(self):
+        # disable non-fast-forward pushes to the server
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        run_git(['config', 'receive.denyNonFastForwards', 'true'], cwd=dest.path)
+        b = objects.Blob.from_string('hi')
+        dest.object_store.add_object(b)
+        t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
+        c = objects.Commit()
+        c.author = c.committer = 'Foo Bar <foo@example.com>'
+        c.author_time = c.commit_time = 0
+        c.author_timezone = c.commit_timezone = 0
+        c.message = 'hi'
+        c.tree = t
+        dest.object_store.add_object(c)
+        return dest, c.id
+
+    def compute_send(self):
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        return sendrefs, src.object_store.generate_pack_contents
+
+    def test_send_pack_one_error(self):
+        dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
+        dest.refs['refs/heads/master'] = dummy_commit
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/master failed to update', str(e))
+            self.assertEqual({'refs/heads/branch': 'ok',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)
+
+    def test_send_pack_multiple_errors(self):
+        dest, dummy = self.disable_ff_and_make_dummy_commit()
+        # set up for two non-ff errors
+        dest.refs['refs/heads/branch'] = dest.refs['refs/heads/master'] = dummy
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/branch, refs/heads/master failed to '
+                             'update', str(e))
+            self.assertEqual({'refs/heads/branch': 'non-fast-forward',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)

+ 57 - 4
dulwich/tests/compat/utils.py

@@ -19,12 +19,16 @@
 
 """Utilities for interacting with cgit."""
 
+import errno
 import os
+import socket
 import subprocess
 import tempfile
+import time
 import unittest
 
 from dulwich.repo import Repo
+from dulwich.protocol import TCP_GIT_PORT
 
 from dulwich.tests import (
     TestSkipped,
@@ -108,15 +112,15 @@ def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
     return stdout
 
 
-def import_repo(name):
+def import_repo_to_dir(name):
     """Import a repo from a fast-export file in a temporary directory.
 
     These are used rather than binary repos for compat tests because they are
     more compact an human-editable, and we already depend on git.
 
     :param name: The name of the repository export file, relative to
-        dulwich/tests/data/repos
-    :returns: An initialized Repo object that lives in a temporary directory.
+        dulwich/tests/data/repos.
+    :returns: The path to the imported repository.
     """
     temp_dir = tempfile.mkdtemp()
     export_path = os.path.join(os.path.dirname(__file__), os.pardir, 'data',
@@ -127,7 +131,44 @@ def import_repo(name):
     run_git_or_fail(['fast-import'], input=export_file.read(),
                     cwd=temp_repo_dir)
     export_file.close()
-    return Repo(temp_repo_dir)
+    return temp_repo_dir
+
+def import_repo(name):
+    """Import a repo from a fast-export file in a temporary directory.
+
+    :param name: The name of the repository export file, relative to
+        dulwich/tests/data/repos.
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    return Repo(import_repo_to_dir(name))
+
+
+def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
+    """Check for a running TCP daemon.
+
+    Defaults to checking 10 times with a delay of 0.1 sec between tries.
+
+    :param limit: Number of attempts before deciding no daemon is running.
+    :param delay: Delay between connection attempts.
+    :param timeout: Socket timeout for connection attempts.
+    :param port: Port on which we expect the daemon to appear.
+    :returns: A boolean, true if a daemon is running on the specified port,
+        false if not.
+    """
+    for _ in xrange(limit):
+        time.sleep(delay)
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.settimeout(delay)
+        try:
+            s.connect(('localhost', port))
+            s.close()
+            return True
+        except socket.error, e:
+            if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
+                raise
+            elif e.args[0] != errno.ECONNREFUSED:
+                raise
+    return False
 
 
 class CompatTestCase(unittest.TestCase):
@@ -141,3 +182,15 @@ class CompatTestCase(unittest.TestCase):
 
     def setUp(self):
         require_git_version(self.min_git_version)
+
+    def assertReposEqual(self, repo1, repo2):
+        self.assertEqual(repo1.get_refs(), repo2.get_refs())
+        self.assertEqual(sorted(set(repo1.object_store)),
+                         sorted(set(repo2.object_store)))
+
+    def assertReposNotEqual(self, repo1, repo2):
+        refs1 = repo1.get_refs()
+        objs1 = set(repo1.object_store)
+        refs2 = repo2.get_refs()
+        objs2 = set(repo2.object_store)
+        self.assertFalse(refs1 == refs2 and objs1 == objs2)

+ 11 - 5
dulwich/tests/test_client.py

@@ -1,16 +1,16 @@
 # test_client.py -- Tests for the git protocol, client side
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -23,16 +23,22 @@ from dulwich.client import (
     GitClient,
     )
 
+
+# TODO(durin42): add unit-level tests of GitClient
 class GitClientTests(TestCase):
 
     def setUp(self):
         self.rout = StringIO()
         self.rin = StringIO()
-        self.client = GitClient(lambda x: True, self.rin.read, 
+        self.client = GitClient(lambda x: True, self.rin.read,
             self.rout.write)
 
     def test_caps(self):
-        self.assertEquals(['multi_ack', 'side-band-64k', 'ofs-delta', 'thin-pack'], self.client._capabilities)
+        self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
+                               'thin-pack']),
+                          set(self.client._fetch_capabilities))
+        self.assertEquals(set(['ofs-delta', 'report-status']),
+                          set(self.client._send_capabilities))
 
     def test_fetch_pack_none(self):
         self.rin.write(

+ 2 - 2
dulwich/tests/test_pack.py

@@ -136,8 +136,8 @@ class TestPackDeltas(unittest.TestCase):
     test_string_big = 'Z' * 8192
 
     def _test_roundtrip(self, base, target):
-        self.assertEquals([target],
-                          apply_delta(base, create_delta(base, target)))
+        self.assertEquals(target,
+                          ''.join(apply_delta(base, create_delta(base, target))))
 
     def test_nochange(self):
         self._test_roundtrip(self.test_string1, self.test_string1)