2
0
Эх сурвалжийг харах

Merge cleanups and use of proper per-service capabilities from Augie.

Jelmer Vernooij 15 жил өмнө
parent
commit
298f25e2cf

+ 83 - 40
dulwich/client.py

@@ -28,7 +28,8 @@ import subprocess
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     ChecksumMismatch,
     ChecksumMismatch,
-    HangupException,
+    SendPackError,
+    UpdateRefsError,
     )
     )
 from dulwich.protocol import (
 from dulwich.protocol import (
     Protocol,
     Protocol,
@@ -45,16 +46,19 @@ def _fileno_can_read(fileno):
     """Check if a file descriptor is readable."""
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
 
+COMMON_CAPABILITIES = ["ofs-delta"]
+FETCH_CAPABILITIES = ["multi_ack", "side-band-64k"] + COMMON_CAPABILITIES
+SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 
 
-CAPABILITIES = ["multi_ack", "side-band-64k", "ofs-delta"]
-
-
+# TODO(durin42): this doesn't correctly degrade if the server doesn't
+# support some capabilities. This should work properly with servers
+# that don't support side-band-64k and multi_ack.
 class GitClient(object):
 class GitClient(object):
     """Git smart server client.
     """Git smart server client.
 
 
     """
     """
 
 
-    def __init__(self, can_read, read, write, thin_packs=True, 
+    def __init__(self, can_read, read, write, thin_packs=True,
         report_activity=None):
         report_activity=None):
         """Create a new GitClient instance.
         """Create a new GitClient instance.
 
 
@@ -68,12 +72,10 @@ class GitClient(object):
         """
         """
         self.proto = Protocol(read, write, report_activity)
         self.proto = Protocol(read, write, report_activity)
         self._can_read = can_read
         self._can_read = can_read
-        self._capabilities = list(CAPABILITIES)
+        self._fetch_capabilities = list(FETCH_CAPABILITIES)
+        self._send_capabilities = list(SEND_CAPABILITIES)
         if thin_packs:
         if thin_packs:
-            self._capabilities.append("thin-pack")
-
-    def capabilities(self):
-        return " ".join(self._capabilities)
+            self._fetch_capabilities.append("thin-pack")
 
 
     def read_refs(self):
     def read_refs(self):
         server_capabilities = None
         server_capabilities = None
@@ -86,14 +88,21 @@ class GitClient(object):
             refs[ref] = sha
             refs[ref] = sha
         return refs, server_capabilities
         return refs, server_capabilities
 
 
+    # TODO(durin42): add side-band-64k capability support here and advertise it
     def send_pack(self, path, determine_wants, generate_pack_contents):
     def send_pack(self, path, determine_wants, generate_pack_contents):
         """Upload a pack to a remote repository.
         """Upload a pack to a remote repository.
 
 
         :param path: Repository path
         :param path: Repository path
-        :param generate_pack_contents: Function that can return the shas of the 
+        :param generate_pack_contents: Function that can return the shas of the
             objects to upload.
             objects to upload.
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
         """
         """
         old_refs, server_capabilities = self.read_refs()
         old_refs, server_capabilities = self.read_refs()
+        if 'report-status' not in server_capabilities:
+            self._send_capabilities.remove('report-status')
         new_refs = determine_wants(old_refs)
         new_refs = determine_wants(old_refs)
         if not new_refs:
         if not new_refs:
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
@@ -106,9 +115,12 @@ class GitClient(object):
             new_sha1 = new_refs.get(refname, ZERO_SHA)
             new_sha1 = new_refs.get(refname, ZERO_SHA)
             if old_sha1 != new_sha1:
             if old_sha1 != new_sha1:
                 if sent_capabilities:
                 if sent_capabilities:
-                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1, refname))
+                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1,
+                                                            refname))
                 else:
                 else:
-                    self.proto.write_pkt_line("%s %s %s\0%s" % (old_sha1, new_sha1, refname, self.capabilities()))
+                    self.proto.write_pkt_line(
+                      "%s %s %s\0%s" % (old_sha1, new_sha1, refname,
+                                        ' '.join(self._send_capabilities)))
                     sent_capabilities = True
                     sent_capabilities = True
             if not new_sha1 in (have, ZERO_SHA):
             if not new_sha1 in (have, ZERO_SHA):
                 want.append(new_sha1)
                 want.append(new_sha1)
@@ -116,20 +128,50 @@ class GitClient(object):
         if not want:
         if not want:
             return new_refs
             return new_refs
         objects = generate_pack_contents(have, want)
         objects = generate_pack_contents(have, want)
-        (entries, sha) = write_pack_data(self.proto.write_file(), objects, 
+        (entries, sha) = write_pack_data(self.proto.write_file(), objects,
                                          len(objects))
                                          len(objects))
-        
-        # read the final confirmation sha
-        try:
-            client_sha = self.proto.read_pkt_line()
-        except HangupException:
-            # for git-daemon versions before v1.6.6.1-26-g38a81b4, there is
-            # nothing to read; catch this and hide from the user.
-            pass
-        else:
-            if not client_sha in (None, "", sha):
-                raise ChecksumMismatch(sha, client_sha)
 
 
+        if 'report-status' in self._send_capabilities:
+            unpack = self.proto.read_pkt_line().strip()
+            if unpack != 'unpack ok':
+                st = True
+                # flush remaining error data
+                while st is not None:
+                    st = self.proto.read_pkt_line()
+                raise SendPackError(unpack)
+            statuses = []
+            errs = False
+            ref_status = self.proto.read_pkt_line()
+            while ref_status:
+                ref_status = ref_status.strip()
+                statuses.append(ref_status)
+                if not ref_status.startswith('ok '):
+                    errs = True
+                ref_status = self.proto.read_pkt_line()
+
+            if errs:
+                ref_status = {}
+                ok = set()
+                for status in statuses:
+                    if ' ' not in status:
+                        # malformed response, move on to the next one
+                        continue
+                    status, ref = status.split(' ', 1)
+
+                    if status == 'ng':
+                        if ' ' in ref:
+                            ref, status = ref.split(' ', 1)
+                    else:
+                        ok.add(ref)
+                    ref_status[ref] = status
+                raise UpdateRefsError('%s failed to update' %
+                                      ', '.join([ref for ref in ref_status
+                                                 if ref not in ok]),
+                                      ref_status=ref_status)
+        # wait for EOF before returning
+        data = self.proto.read()
+        if data:
+            raise SendPackError('Unexpected response %r' % data)
         return new_refs
         return new_refs
 
 
     def fetch(self, path, target, determine_wants=None, progress=None):
     def fetch(self, path, target, determine_wants=None, progress=None):
@@ -137,7 +179,7 @@ class GitClient(object):
 
 
         :param path: Path to fetch from
         :param path: Path to fetch from
         :param target: Target repository to fetch into
         :param target: Target repository to fetch into
-        :param determine_wants: Optional function to determine what refs 
+        :param determine_wants: Optional function to determine what refs
             to fetch
             to fetch
         :param progress: Optional progress function
         :param progress: Optional progress function
         :return: remote refs
         :return: remote refs
@@ -146,7 +188,7 @@ class GitClient(object):
             determine_wants = target.object_store.determine_wants_all
             determine_wants = target.object_store.determine_wants_all
         f, commit = target.object_store.add_pack()
         f, commit = target.object_store.add_pack()
         try:
         try:
-            return self.fetch_pack(path, determine_wants, 
+            return self.fetch_pack(path, determine_wants,
                 target.get_graph_walker(), f.write, progress)
                 target.get_graph_walker(), f.write, progress)
         finally:
         finally:
             commit()
             commit()
@@ -166,7 +208,8 @@ class GitClient(object):
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
             return refs
             return refs
         assert isinstance(wants, list) and type(wants[0]) == str
         assert isinstance(wants, list) and type(wants[0]) == str
-        self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
+        self.proto.write_pkt_line("want %s %s\n" % (
+            wants[0], ' '.join(self._fetch_capabilities)))
         for want in wants[1:]:
         for want in wants[1:]:
             self.proto.write_pkt_line("want %s\n" % want)
             self.proto.write_pkt_line("want %s\n" % want)
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
@@ -189,6 +232,8 @@ class GitClient(object):
             if len(parts) < 3 or parts[2] != "continue":
             if len(parts) < 3 or parts[2] != "continue":
                 break
                 break
             pkt = self.proto.read_pkt_line()
             pkt = self.proto.read_pkt_line()
+        # TODO(durin42): this is broken if the server didn't support the
+        # side-band-64k capability.
         for pkt in self.proto.read_pkt_seq():
         for pkt in self.proto.read_pkt_seq():
             channel = ord(pkt[0])
             channel = ord(pkt[0])
             pkt = pkt[1:]
             pkt = pkt[1:]
@@ -224,9 +269,9 @@ class TCPGitClient(GitClient):
 
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
         """Fetch a pack from the remote host.
         """Fetch a pack from the remote host.
-        
+
         :param path: Path of the reposiutory on the remote host
         :param path: Path of the reposiutory on the remote host
-        :param determine_wants: Callback that receives available refs dict and 
+        :param determine_wants: Callback that receives available refs dict and
             should return list of sha's to fetch.
             should return list of sha's to fetch.
         :param graph_walker: GraphWalker instance used to find missing shas
         :param graph_walker: GraphWalker instance used to find missing shas
         :param pack_data: Callback for writing pack data
         :param pack_data: Callback for writing pack data
@@ -262,18 +307,18 @@ class SubprocessGitClient(GitClient):
 
 
         :param path: Path to the git repository on the server
         :param path: Path to the git repository on the server
         :param changed_refs: Dictionary with new values for the refs
         :param changed_refs: Dictionary with new values for the refs
-        :param generate_pack_contents: Function that returns an iterator over 
+        :param generate_pack_contents: Function that returns an iterator over
             objects to send
             objects to send
         """
         """
         client = self._connect("git-receive-pack", path)
         client = self._connect("git-receive-pack", path)
         return client.send_pack(path, changed_refs, generate_pack_contents)
         return client.send_pack(path, changed_refs, generate_pack_contents)
 
 
-    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, 
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
         progress):
         progress):
         """Retrieve a pack from the server
         """Retrieve a pack from the server
 
 
         :param path: Path to the git repository on the server
         :param path: Path to the git repository on the server
-        :param determine_wants: Function that receives existing refs 
+        :param determine_wants: Function that receives existing refs
             on the server and returns a list of desired shas
             on the server and returns a list of desired shas
         :param graph_walker: GraphWalker instance
         :param graph_walker: GraphWalker instance
         :param pack_data: Function that can write pack data
         :param pack_data: Function that can write pack data
@@ -289,12 +334,8 @@ class SSHSubprocess(object):
 
 
     def __init__(self, proc):
     def __init__(self, proc):
         self.proc = proc
         self.proc = proc
-
-    def send(self, data):
-        return os.write(self.proc.stdin.fileno(), data)
-
-    def recv(self, count):
-        return self.proc.stdout.read(count)
+        self.read = self.recv = proc.stdout.read
+        self.write = self.send = proc.stdin.write
 
 
     def close(self):
     def close(self):
         self.proc.stdin.close()
         self.proc.stdin.close()
@@ -331,7 +372,9 @@ class SSHGitClient(GitClient):
         self._kwargs = kwargs
         self._kwargs = kwargs
 
 
     def send_pack(self, path, determine_wants, generate_pack_contents):
     def send_pack(self, path, determine_wants, generate_pack_contents):
-        remote = get_ssh_vendor().connect_ssh(self.host, ["git-receive-pack '%s'" % path], port=self.port, username=self.username)
+        remote = get_ssh_vendor().connect_ssh(
+            self.host, ["git-receive-pack '%s'" % path],
+            port=self.port, username=self.username)
         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
         return client.send_pack(path, determine_wants, generate_pack_contents)
         return client.send_pack(path, determine_wants, generate_pack_contents)
 
 

+ 23 - 8
dulwich/errors.py

@@ -1,17 +1,17 @@
 # errors.py -- errors for dulwich
 # errors.py -- errors for dulwich
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -38,7 +38,7 @@ class ChecksumMismatch(Exception):
                 "Checksum mismatch: Expected %s, got %s" % (expected, got))
                 "Checksum mismatch: Expected %s, got %s" % (expected, got))
         else:
         else:
             Exception.__init__(self,
             Exception.__init__(self,
-                "Checksum mismatch: Expected %s, got %s; %s" % 
+                "Checksum mismatch: Expected %s, got %s; %s" %
                 (expected, got, extra))
                 (expected, got, extra))
 
 
 
 
@@ -81,21 +81,21 @@ class NotBlobError(WrongObjectException):
 
 
 class MissingCommitError(Exception):
 class MissingCommitError(Exception):
     """Indicates that a commit was not found in the repository"""
     """Indicates that a commit was not found in the repository"""
-  
+
     def __init__(self, sha, *args, **kwargs):
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the revision store" % sha)
         Exception.__init__(self, "%s is not in the revision store" % sha)
 
 
 
 
 class ObjectMissing(Exception):
 class ObjectMissing(Exception):
     """Indicates that a requested object is missing."""
     """Indicates that a requested object is missing."""
-  
+
     def __init__(self, sha, *args, **kwargs):
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the pack" % sha)
         Exception.__init__(self, "%s is not in the pack" % sha)
 
 
 
 
 class ApplyDeltaError(Exception):
 class ApplyDeltaError(Exception):
     """Indicates that applying a delta failed."""
     """Indicates that applying a delta failed."""
-    
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         Exception.__init__(self, *args, **kwargs)
         Exception.__init__(self, *args, **kwargs)
 
 
@@ -109,8 +109,23 @@ class NotGitRepository(Exception):
 
 
 class GitProtocolError(Exception):
 class GitProtocolError(Exception):
     """Git protocol exception."""
     """Git protocol exception."""
-    
+
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args, **kwargs)
+
+
+class SendPackError(GitProtocolError):
+    """An error occurred during send_pack."""
+
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args, **kwargs)
+
+
+class UpdateRefsError(GitProtocolError):
+    """The server reported errors updating refs."""
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
+        self.ref_status = kwargs.pop('ref_status')
         Exception.__init__(self, *args, **kwargs)
         Exception.__init__(self, *args, **kwargs)
 
 
 
 

+ 20 - 20
dulwich/object_store.py

@@ -60,7 +60,8 @@ class BaseObjectStore(object):
     """Object store interface."""
     """Object store interface."""
 
 
     def determine_wants_all(self, refs):
     def determine_wants_all(self, refs):
-	    return [sha for (ref, sha) in refs.iteritems() if not sha in self and not ref.endswith("^{}")]
+        return [sha for (ref, sha) in refs.iteritems()
+                if not sha in self and not ref.endswith("^{}")]
 
 
     def iter_shas(self, shas):
     def iter_shas(self, shas):
         """Iterate over the objects for the specified shas.
         """Iterate over the objects for the specified shas.
@@ -148,7 +149,7 @@ class BaseObjectStore(object):
                     newmode = None
                     newmode = None
                     newhexsha = None
                     newhexsha = None
                     newchildpath = None
                     newchildpath = None
-                if (want_unchanged or oldmode != newmode or 
+                if (want_unchanged or oldmode != newmode or
                     oldhexsha != newhexsha):
                     oldhexsha != newhexsha):
                     if stat.S_ISDIR(oldmode):
                     if stat.S_ISDIR(oldmode):
                         if newmode is None or stat.S_ISDIR(newmode):
                         if newmode is None or stat.S_ISDIR(newmode):
@@ -182,7 +183,7 @@ class BaseObjectStore(object):
         while todo:
         while todo:
             (tid, tpath) = todo.pop()
             (tid, tpath) = todo.pop()
             tree = self[tid]
             tree = self[tid]
-            for name, mode, hexsha in tree.iteritems(): 
+            for name, mode, hexsha in tree.iteritems():
                 path = posixpath.join(tpath, name)
                 path = posixpath.join(tpath, name)
                 if stat.S_ISDIR(mode):
                 if stat.S_ISDIR(mode):
                     todo.add((hexsha, path))
                     todo.add((hexsha, path))
@@ -195,7 +196,7 @@ class BaseObjectStore(object):
 
 
         :param haves: Iterable over SHAs already in common.
         :param haves: Iterable over SHAs already in common.
         :param wants: Iterable over SHAs of objects to fetch.
         :param wants: Iterable over SHAs of objects to fetch.
-        :param progress: Simple progress function that will be called with 
+        :param progress: Simple progress function that will be called with
             updated progress strings.
             updated progress strings.
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
         :param get_tagged: Function that returns a dict of pointed-to sha -> tag
             sha for including tags.
             sha for including tags.
@@ -221,7 +222,7 @@ class BaseObjectStore(object):
 
 
     def get_graph_walker(self, heads):
     def get_graph_walker(self, heads):
         """Obtain a graph walker for this object store.
         """Obtain a graph walker for this object store.
-        
+
         :param heads: Local heads to start search with
         :param heads: Local heads to start search with
         :return: GraphWalker object
         :return: GraphWalker object
         """
         """
@@ -304,7 +305,7 @@ class PackBasedObjectStore(BaseObjectStore):
                 return pack.get_raw(sha)
                 return pack.get_raw(sha)
             except KeyError:
             except KeyError:
                 pass
                 pass
-        if hexsha is None: 
+        if hexsha is None:
             hexsha = sha_to_hex(name)
             hexsha = sha_to_hex(name)
         ret = self._get_loose_object(hexsha)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
         if ret is not None:
@@ -387,7 +388,7 @@ class DiskObjectStore(PackBasedObjectStore):
     def move_in_thin_pack(self, path):
     def move_in_thin_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
         """Move a specific file containing a pack into the pack directory.
 
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
             packs directory.
 
 
         :param path: Path to the pack file.
         :param path: Path to the pack file.
@@ -395,13 +396,13 @@ class DiskObjectStore(PackBasedObjectStore):
         data = ThinPackData(self.get_raw, path)
         data = ThinPackData(self.get_raw, path)
 
 
         # Write index for the thin pack (do we really need this?)
         # Write index for the thin pack (do we really need this?)
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".tempidx")
             sha_to_hex(urllib2.randombytes(20))+".tempidx")
         data.create_index_v2(temppath)
         data.create_index_v2(temppath)
         p = Pack.from_objects(data, load_pack_index(temppath))
         p = Pack.from_objects(data, load_pack_index(temppath))
 
 
         # Write a full pack version
         # Write a full pack version
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".temppack")
             sha_to_hex(urllib2.randombytes(20))+".temppack")
         write_pack(temppath, ((o, None) for o in p.iterobjects()), len(p))
         write_pack(temppath, ((o, None) for o in p.iterobjects()), len(p))
         pack_sha = load_pack_index(temppath+".idx").objects_sha1()
         pack_sha = load_pack_index(temppath+".idx").objects_sha1()
@@ -415,14 +416,14 @@ class DiskObjectStore(PackBasedObjectStore):
     def move_in_pack(self, path):
     def move_in_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
         """Move a specific file containing a pack into the pack directory.
 
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
             packs directory.
 
 
         :param path: Path to the pack file.
         :param path: Path to the pack file.
         """
         """
         p = PackData(path)
         p = PackData(path)
         entries = p.sorted_entries()
         entries = p.sorted_entries()
-        basename = os.path.join(self.pack_dir, 
+        basename = os.path.join(self.pack_dir,
             "pack-%s" % iter_sha1(entry[0] for entry in entries))
             "pack-%s" % iter_sha1(entry[0] for entry in entries))
         write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum())
         write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum())
         p.close()
         p.close()
@@ -434,7 +435,7 @@ class DiskObjectStore(PackBasedObjectStore):
     def add_thin_pack(self):
     def add_thin_pack(self):
         """Add a new thin pack to this object store.
         """Add a new thin pack to this object store.
 
 
-        Thin packs are packs that contain deltas with parents that exist 
+        Thin packs are packs that contain deltas with parents that exist
         in a different pack.
         in a different pack.
         """
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -449,9 +450,9 @@ class DiskObjectStore(PackBasedObjectStore):
         return f, commit
         return f, commit
 
 
     def add_pack(self):
     def add_pack(self):
-        """Add a new pack to this object store. 
+        """Add a new pack to this object store.
 
 
-        :return: Fileobject to write to and a commit function to 
+        :return: Fileobject to write to and a commit function to
             call when the pack is finished.
             call when the pack is finished.
         """
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -607,7 +608,7 @@ class ObjectStoreIterator(ObjectIterator):
     def __contains__(self, needle):
     def __contains__(self, needle):
         """Check if an object is present.
         """Check if an object is present.
 
 
-        :note: This checks if the object is present in 
+        :note: This checks if the object is present in
             the underlying object store, not if it would
             the underlying object store, not if it would
             be yielded by the iterator.
             be yielded by the iterator.
 
 
@@ -617,7 +618,7 @@ class ObjectStoreIterator(ObjectIterator):
 
 
     def __getitem__(self, key):
     def __getitem__(self, key):
         """Find an object by SHA1.
         """Find an object by SHA1.
-        
+
         :note: This retrieves the object from the underlying
         :note: This retrieves the object from the underlying
             object store. It will also succeed if the object would
             object store. It will also succeed if the object would
             not be returned by the iterator.
             not be returned by the iterator.
@@ -652,7 +653,7 @@ def tree_lookup_path(lookup_obj, root_sha, path):
 class MissingObjectFinder(object):
 class MissingObjectFinder(object):
     """Find the objects missing from another object store.
     """Find the objects missing from another object store.
 
 
-    :param object_store: Object store containing at least all objects to be 
+    :param object_store: Object store containing at least all objects to be
         sent
         sent
     :param haves: SHA1s of commits not to send (already present in target)
     :param haves: SHA1s of commits not to send (already present in target)
     :param wants: SHA1s of commits to send
     :param wants: SHA1s of commits to send
@@ -706,9 +707,8 @@ class MissingObjectFinder(object):
 
 
 
 
 class ObjectStoreGraphWalker(object):
 class ObjectStoreGraphWalker(object):
-    """Graph walker that finds out what commits are missing from an object 
-    store.
-    
+    """Graph walker that finds what commits are missing from an object store.
+
     :ivar heads: Revisions without descendants in the local repo
     :ivar heads: Revisions without descendants in the local repo
     :ivar get_parents: Function to retrieve parents in the local repo
     :ivar get_parents: Function to retrieve parents in the local repo
     """
     """

+ 0 - 13
dulwich/tests/compat/server_utils.py

@@ -51,19 +51,6 @@ class ServerTests(object):
         tear_down_repo(self._old_repo)
         tear_down_repo(self._old_repo)
         tear_down_repo(self._new_repo)
         tear_down_repo(self._new_repo)
 
 
-    def assertReposEqual(self, repo1, repo2):
-        self.assertEqual(repo1.get_refs(), repo2.get_refs())
-        self.assertEqual(sorted(set(repo1.object_store)),
-                         sorted(set(repo2.object_store)))
-
-    def assertReposNotEqual(self, repo1, repo2):
-        refs1 = repo1.get_refs()
-        objs1 = set(repo1.object_store)
-        refs2 = repo2.get_refs()
-        objs2 = set(repo2.object_store)
-
-        self.assertFalse(refs1 == refs2 and objs1 == objs2)
-
     def test_push_to_dulwich(self):
     def test_push_to_dulwich(self):
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)

+ 150 - 0
dulwich/tests/compat/test_client.py

@@ -0,0 +1,150 @@
+# test_client.py -- Compatibilty tests for git client.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibilty tests between the Dulwich client and the cgit server."""
+
+import os
+import shutil
+import signal
+import tempfile
+
+from dulwich import client
+from dulwich import errors
+from dulwich import file
+from dulwich import index
+from dulwich import protocol
+from dulwich import object_store
+from dulwich import objects
+from dulwich import repo
+from dulwich.tests import (
+    TestSkipped,
+    )
+
+from utils import (
+    CompatTestCase,
+    check_for_daemon,
+    import_repo_to_dir,
+    run_git,
+    )
+
+class DulwichClientTest(CompatTestCase):
+    """Tests for client/server compatibility."""
+
+    def setUp(self):
+        if check_for_daemon(limit=1):
+            raise TestSkipped('git-daemon was already running on port %s' %
+                              protocol.TCP_GIT_PORT)
+        CompatTestCase.setUp(self)
+        fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
+                                            suffix=".pid")
+        os.fdopen(fd).close()
+        self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
+        dest = os.path.join(self.gitroot, 'dest')
+        file.ensure_dir_exists(dest)
+        run_git(['init', '--bare'], cwd=dest)
+        run_git(
+            ['daemon', '--verbose', '--export-all',
+             '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
+             '--detach', '--reuseaddr', '--enable=receive-pack',
+             '--listen=localhost', self.gitroot], cwd=self.gitroot)
+        if not check_for_daemon():
+            raise TestSkipped('git-daemon failed to start')
+
+    def tearDown(self):
+        CompatTestCase.tearDown(self)
+        try:
+            os.kill(int(open(self.pidfile).read().strip()), signal.SIGKILL)
+            os.unlink(self.pidfile)
+        except (OSError, IOError):
+            pass
+        shutil.rmtree(self.gitroot)
+
+    def test_send_pack(self):
+        c = client.TCPGitClient('localhost')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def test_send_without_report_status(self):
+        c = client.TCPGitClient('localhost')
+        c._send_capabilities.remove('report-status')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def disable_ff_and_make_dummy_commit(self):
+        # disable non-fast-forward pushes to the server
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        run_git(['config', 'receive.denyNonFastForwards', 'true'], cwd=dest.path)
+        b = objects.Blob.from_string('hi')
+        dest.object_store.add_object(b)
+        t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
+        c = objects.Commit()
+        c.author = c.committer = 'Foo Bar <foo@example.com>'
+        c.author_time = c.commit_time = 0
+        c.author_timezone = c.commit_timezone = 0
+        c.message = 'hi'
+        c.tree = t
+        dest.object_store.add_object(c)
+        return dest, c.id
+
+    def compute_send(self):
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        return sendrefs, src.object_store.generate_pack_contents
+
+    def test_send_pack_one_error(self):
+        dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
+        dest.refs['refs/heads/master'] = dummy_commit
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/master failed to update', str(e))
+            self.assertEqual({'refs/heads/branch': 'ok',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)
+
+    def test_send_pack_multiple_errors(self):
+        dest, dummy = self.disable_ff_and_make_dummy_commit()
+        # set up for two non-ff errors
+        dest.refs['refs/heads/branch'] = dest.refs['refs/heads/master'] = dummy
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/branch, refs/heads/master failed to '
+                             'update', str(e))
+            self.assertEqual({'refs/heads/branch': 'non-fast-forward',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)

+ 57 - 4
dulwich/tests/compat/utils.py

@@ -19,12 +19,16 @@
 
 
 """Utilities for interacting with cgit."""
 """Utilities for interacting with cgit."""
 
 
+import errno
 import os
 import os
+import socket
 import subprocess
 import subprocess
 import tempfile
 import tempfile
+import time
 import unittest
 import unittest
 
 
 from dulwich.repo import Repo
 from dulwich.repo import Repo
+from dulwich.protocol import TCP_GIT_PORT
 
 
 from dulwich.tests import (
 from dulwich.tests import (
     TestSkipped,
     TestSkipped,
@@ -108,15 +112,15 @@ def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
     return stdout
     return stdout
 
 
 
 
-def import_repo(name):
+def import_repo_to_dir(name):
     """Import a repo from a fast-export file in a temporary directory.
     """Import a repo from a fast-export file in a temporary directory.
 
 
     These are used rather than binary repos for compat tests because they are
     These are used rather than binary repos for compat tests because they are
     more compact an human-editable, and we already depend on git.
     more compact an human-editable, and we already depend on git.
 
 
     :param name: The name of the repository export file, relative to
     :param name: The name of the repository export file, relative to
-        dulwich/tests/data/repos
-    :returns: An initialized Repo object that lives in a temporary directory.
+        dulwich/tests/data/repos.
+    :returns: The path to the imported repository.
     """
     """
     temp_dir = tempfile.mkdtemp()
     temp_dir = tempfile.mkdtemp()
     export_path = os.path.join(os.path.dirname(__file__), os.pardir, 'data',
     export_path = os.path.join(os.path.dirname(__file__), os.pardir, 'data',
@@ -127,7 +131,44 @@ def import_repo(name):
     run_git_or_fail(['fast-import'], input=export_file.read(),
     run_git_or_fail(['fast-import'], input=export_file.read(),
                     cwd=temp_repo_dir)
                     cwd=temp_repo_dir)
     export_file.close()
     export_file.close()
-    return Repo(temp_repo_dir)
+    return temp_repo_dir
+
+def import_repo(name):
+    """Import a repo from a fast-export file in a temporary directory.
+
+    :param name: The name of the repository export file, relative to
+        dulwich/tests/data/repos.
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    return Repo(import_repo_to_dir(name))
+
+
+def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
+    """Check for a running TCP daemon.
+
+    Defaults to checking 10 times with a delay of 0.1 sec between tries.
+
+    :param limit: Number of attempts before deciding no daemon is running.
+    :param delay: Delay between connection attempts.
+    :param timeout: Socket timeout for connection attempts.
+    :param port: Port on which we expect the daemon to appear.
+    :returns: A boolean, true if a daemon is running on the specified port,
+        false if not.
+    """
+    for _ in xrange(limit):
+        time.sleep(delay)
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.settimeout(delay)
+        try:
+            s.connect(('localhost', port))
+            s.close()
+            return True
+        except socket.error, e:
+            if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
+                raise
+            elif e.args[0] != errno.ECONNREFUSED:
+                raise
+    return False
 
 
 
 
 class CompatTestCase(unittest.TestCase):
 class CompatTestCase(unittest.TestCase):
@@ -141,3 +182,15 @@ class CompatTestCase(unittest.TestCase):
 
 
     def setUp(self):
     def setUp(self):
         require_git_version(self.min_git_version)
         require_git_version(self.min_git_version)
+
+    def assertReposEqual(self, repo1, repo2):
+        self.assertEqual(repo1.get_refs(), repo2.get_refs())
+        self.assertEqual(sorted(set(repo1.object_store)),
+                         sorted(set(repo2.object_store)))
+
+    def assertReposNotEqual(self, repo1, repo2):
+        refs1 = repo1.get_refs()
+        objs1 = set(repo1.object_store)
+        refs2 = repo2.get_refs()
+        objs2 = set(repo2.object_store)
+        self.assertFalse(refs1 == refs2 and objs1 == objs2)

+ 11 - 5
dulwich/tests/test_client.py

@@ -1,16 +1,16 @@
 # test_client.py -- Tests for the git protocol, client side
 # test_client.py -- Tests for the git protocol, client side
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -23,16 +23,22 @@ from dulwich.client import (
     GitClient,
     GitClient,
     )
     )
 
 
+
+# TODO(durin42): add unit-level tests of GitClient
 class GitClientTests(TestCase):
 class GitClientTests(TestCase):
 
 
     def setUp(self):
     def setUp(self):
         self.rout = StringIO()
         self.rout = StringIO()
         self.rin = StringIO()
         self.rin = StringIO()
-        self.client = GitClient(lambda x: True, self.rin.read, 
+        self.client = GitClient(lambda x: True, self.rin.read,
             self.rout.write)
             self.rout.write)
 
 
     def test_caps(self):
     def test_caps(self):
-        self.assertEquals(['multi_ack', 'side-band-64k', 'ofs-delta', 'thin-pack'], self.client._capabilities)
+        self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
+                               'thin-pack']),
+                          set(self.client._fetch_capabilities))
+        self.assertEquals(set(['ofs-delta', 'report-status']),
+                          set(self.client._send_capabilities))
 
 
     def test_fetch_pack_none(self):
     def test_fetch_pack_none(self):
         self.rin.write(
         self.rin.write(

+ 2 - 2
dulwich/tests/test_pack.py

@@ -136,8 +136,8 @@ class TestPackDeltas(unittest.TestCase):
     test_string_big = 'Z' * 8192
     test_string_big = 'Z' * 8192
 
 
     def _test_roundtrip(self, base, target):
     def _test_roundtrip(self, base, target):
-        self.assertEquals([target],
-                          apply_delta(base, create_delta(base, target)))
+        self.assertEquals(target,
+                          ''.join(apply_delta(base, create_delta(base, target))))
 
 
     def test_nochange(self):
     def test_nochange(self):
         self._test_roundtrip(self.test_string1, self.test_string1)
         self._test_roundtrip(self.test_string1, self.test_string1)