Browse Source

New upstream release.

Jelmer Vernooij 13 năm trước cách đây
mục cha
commit
ae6468372b

+ 0 - 5
HACKING

@@ -23,8 +23,3 @@ will run the tests using unittest on Python 2.7 and higher, and using
 unittest2 (which you will need to have installed) on older versions of Python.
 unittest2 (which you will need to have installed) on older versions of Python.
 
 
  $ make check
  $ make check
-
-Alternatively, if you have testtools installed you can run the testsuite by
-overriding the test runner:
-
- $ make check TESTRUNNER=testtools.run

+ 5 - 0
Makefile

@@ -28,9 +28,14 @@ check:: build
 check-nocompat:: build
 check-nocompat:: build
 	$(RUNTEST) dulwich.tests.nocompat_test_suite
 	$(RUNTEST) dulwich.tests.nocompat_test_suite
 
 
+check-pypy:: clean
+	$(MAKE) check-noextensions PYTHON=pypy
+
 check-noextensions:: clean
 check-noextensions:: clean
 	$(RUNTEST) dulwich.tests.test_suite
 	$(RUNTEST) dulwich.tests.test_suite
 
 
+check-all: check check-pypy check-noextensions
+
 clean::
 clean::
 	$(SETUP) clean --all
 	$(SETUP) clean --all
 	rm -f dulwich/*.so
 	rm -f dulwich/*.so

+ 45 - 0
NEWS

@@ -1,3 +1,48 @@
+0.8.1	2011-10-31
+
+ FEATURES
+
+  * Repo.do_commit has a new argument 'ref'.
+
+  * Repo.do_commit has a new argument 'merge_heads'. (Jelmer Vernooij)
+
+  * New ``Repo.get_walker`` method. (Jelmer Vernooij)
+
+  * New ``Repo.clone`` method. (Jelmer Vernooij, #725369)
+
+  * ``GitClient.send_pack`` now supports the 'side-band-64k' capability.
+    (Jelmer Vernooij)
+
+  * ``HttpGitClient`` which supports the smart server protocol over
+    HTTP. "dumb" access is not yet supported. (Jelmer Vernooij, #373688)
+
+  * Add basic support for alternates. (Jelmer Vernooij, #810429)
+
+ CHANGES
+
+  * unittest2 or python >= 2.7 is now required for the testsuite.
+    testtools is no longer supported. (Jelmer Vernooij, #830713)
+
+ BUG FIXES
+
+  * Fix compilation with older versions of MSVC.  (Martin gz)
+
+  * Special case 'refs/stash' as a valid ref. (Jelmer Vernooij, #695577)
+
+  * Smart protocol clients can now change refs even if they are
+    not uploading new data. (Jelmer Vernooij, #855993)
+
+ * Don't compile C extensions when running in pypy.
+   (Ronny Pfannschmidt, #881546)
+
+ * Use different name for strnlen replacement function to avoid clashing
+   with system strnlen. (Jelmer Vernooij, #880362)
+
+ API CHANGES
+
+  * ``Repo.revision_history`` is now deprecated in favor of ``Repo.get_walker``.
+    (Jelmer Vernooij)
+
 0.8.0	2011-08-07
 0.8.0	2011-08-07
 
 
  FEATURES
  FEATURES

+ 6 - 0
debian/changelog

@@ -1,3 +1,9 @@
+dulwich (0.8.1-1) UNRELEASED; urgency=low
+
+  * New upstream release.
+
+ -- Jelmer Vernooij <jelmer@debian.org>  Mon, 31 Oct 2011 13:07:39 -0700
+
 dulwich (0.8.0-1) unstable; urgency=low
 dulwich (0.8.0-1) unstable; urgency=low
 
 
   * New upstream release.
   * New upstream release.

+ 1 - 1
dulwich/__init__.py

@@ -23,4 +23,4 @@
 
 
 from dulwich import (client, protocol, repo, server)
 from dulwich import (client, protocol, repo, server)
 
 
-__version__ = (0, 8, 0)
+__version__ = (0, 8, 1)

+ 4 - 6
dulwich/_objects.c

@@ -21,20 +21,18 @@
 #include <stdlib.h>
 #include <stdlib.h>
 #include <sys/stat.h>
 #include <sys/stat.h>
 
 
-#if defined(__APPLE__)
-#include <Availability.h>
-#endif
-
 #if (PY_VERSION_HEX < 0x02050000)
 #if (PY_VERSION_HEX < 0x02050000)
 typedef int Py_ssize_t;
 typedef int Py_ssize_t;
 #endif
 #endif
 
 
-#if defined(__MINGW32_VERSION) || (defined(__APPLE__) && __MAC_OS_X_VERSION_MIN_REQUIRED < 1070)
-size_t strnlen(char *text, size_t maxlen)
+#if defined(__MINGW32_VERSION) || defined(__APPLE__)
+size_t rep_strnlen(char *text, size_t maxlen);
+size_t rep_strnlen(char *text, size_t maxlen)
 {
 {
 	const char *last = memchr(text, '\0', maxlen);
 	const char *last = memchr(text, '\0', maxlen);
 	return last ? (size_t) (last - text) : maxlen;
 	return last ? (size_t) (last - text) : maxlen;
 }
 }
+#define strnlen rep_strnlen
 #endif
 #endif
 
 
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))

+ 410 - 92
dulwich/client.py

@@ -21,17 +21,21 @@
 
 
 __docformat__ = 'restructuredText'
 __docformat__ = 'restructuredText'
 
 
+from cStringIO import StringIO
 import select
 import select
 import socket
 import socket
 import subprocess
 import subprocess
+import urllib2
 import urlparse
 import urlparse
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     GitProtocolError,
     GitProtocolError,
+    NotGitRepository,
     SendPackError,
     SendPackError,
     UpdateRefsError,
     UpdateRefsError,
     )
     )
 from dulwich.protocol import (
 from dulwich.protocol import (
+    PktLineParser,
     Protocol,
     Protocol,
     TCP_GIT_PORT,
     TCP_GIT_PORT,
     ZERO_SHA,
     ZERO_SHA,
@@ -52,13 +56,72 @@ def _fileno_can_read(fileno):
     """Check if a file descriptor is readable."""
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
 
-COMMON_CAPABILITIES = ['ofs-delta']
-FETCH_CAPABILITIES = ['multi_ack', 'side-band-64k'] + COMMON_CAPABILITIES
+COMMON_CAPABILITIES = ['ofs-delta', 'side-band-64k']
+FETCH_CAPABILITIES = ['multi_ack', 'multi_ack_detailed'] + COMMON_CAPABILITIES
 SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 
 
+
+class ReportStatusParser(object):
+    """Handle status as reported by servers with the 'report-status' capability.
+    """
+
+    def __init__(self):
+        self._done = False
+        self._pack_status = None
+        self._ref_status_ok = True
+        self._ref_statuses = []
+
+    def check(self):
+        """Check if there were any errors and, if so, raise exceptions.
+
+        :raise SendPackError: Raised when the server could not unpack
+        :raise UpdateRefsError: Raised when refs could not be updated
+        """
+        if self._pack_status not in ('unpack ok', None):
+            raise SendPackError(self._pack_status)
+        if not self._ref_status_ok:
+            ref_status = {}
+            ok = set()
+            for status in self._ref_statuses:
+                if ' ' not in status:
+                    # malformed response, move on to the next one
+                    continue
+                status, ref = status.split(' ', 1)
+
+                if status == 'ng':
+                    if ' ' in ref:
+                        ref, status = ref.split(' ', 1)
+                else:
+                    ok.add(ref)
+                ref_status[ref] = status
+            raise UpdateRefsError('%s failed to update' %
+                                  ', '.join([ref for ref in ref_status
+                                             if ref not in ok]),
+                                  ref_status=ref_status)
+
+    def handle_packet(self, pkt):
+        """Handle a packet.
+
+        :raise GitProtocolError: Raised when packets are received after a
+            flush packet.
+        """
+        if self._done:
+            raise GitProtocolError("received more data after status report")
+        if pkt is None:
+            self._done = True
+            return
+        if self._pack_status is None:
+            self._pack_status = pkt.strip()
+        else:
+            ref_status = pkt.strip()
+            self._ref_statuses.append(ref_status)
+            if not ref_status.startswith('ok '):
+                self._ref_status_ok = False
+
+
 # TODO(durin42): this doesn't correctly degrade if the server doesn't
 # TODO(durin42): this doesn't correctly degrade if the server doesn't
 # support some capabilities. This should work properly with servers
 # support some capabilities. This should work properly with servers
-# that don't support side-band-64k and multi_ack.
+# that don't support multi_ack.
 class GitClient(object):
 class GitClient(object):
     """Git smart server client.
     """Git smart server client.
 
 
@@ -77,20 +140,6 @@ class GitClient(object):
         if thin_packs:
         if thin_packs:
             self._fetch_capabilities.append('thin-pack')
             self._fetch_capabilities.append('thin-pack')
 
 
-    def _connect(self, cmd, path):
-        """Create a connection to the server.
-
-        This method is abstract - concrete implementations should
-        implement their own variant which connects to the server and
-        returns an initialized Protocol object with the service ready
-        for use and a can_read function which may be used to see if
-        reads would block.
-
-        :param cmd: The git service name to which we should connect.
-        :param path: The path we should pass to the service.
-        """
-        raise NotImplementedError()
-
     def _read_refs(self, proto):
     def _read_refs(self, proto):
         server_capabilities = None
         server_capabilities = None
         refs = {}
         refs = {}
@@ -104,6 +153,51 @@ class GitClient(object):
             refs[ref] = sha
             refs[ref] = sha
         return refs, server_capabilities
         return refs, server_capabilities
 
 
+    def send_pack(self, path, determine_wants, generate_pack_contents,
+                  progress=None):
+        """Upload a pack to a remote repository.
+
+        :param path: Repository path
+        :param generate_pack_contents: Function that can return a sequence of the
+            shas of the objects to upload.
+        :param progress: Optional progress function
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
+        """
+        raise NotImplementedError(self.send_pack)
+
+    def fetch(self, path, target, determine_wants=None, progress=None):
+        """Fetch into a target repository.
+
+        :param path: Path to fetch from
+        :param target: Target repository to fetch into
+        :param determine_wants: Optional function to determine what refs
+            to fetch
+        :param progress: Optional progress function
+        :return: remote refs
+        """
+        if determine_wants is None:
+            determine_wants = target.object_store.determine_wants_all
+        f, commit = target.object_store.add_pack()
+        try:
+            return self.fetch_pack(path, determine_wants,
+                target.get_graph_walker(), f.write, progress)
+        finally:
+            commit()
+
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
+                   progress):
+        """Retrieve a pack from a git smart server.
+
+        :param determine_wants: Callback that returns list of commits to fetch
+        :param graph_walker: Object with next() and ack().
+        :param pack_data: Callback called for each bit of data in the pack
+        :param progress: Callback for progress reports (strings)
+        """
+        raise NotImplementedError(self.fetch_pack)
+
     def _parse_status_report(self, proto):
     def _parse_status_report(self, proto):
         unpack = proto.read_pkt_line().strip()
         unpack = proto.read_pkt_line().strip()
         if unpack != 'unpack ok':
         if unpack != 'unpack ok':
@@ -142,27 +236,35 @@ class GitClient(object):
                                              if ref not in ok]),
                                              if ref not in ok]),
                                   ref_status=ref_status)
                                   ref_status=ref_status)
 
 
+    def _read_side_band64k_data(self, proto, channel_callbacks):
+        """Read per-channel data.
 
 
-    # TODO(durin42): add side-band-64k capability support here and advertise it
-    def send_pack(self, path, determine_wants, generate_pack_contents):
-        """Upload a pack to a remote repository.
+        This requires the side-band-64k capability.
 
 
-        :param path: Repository path
-        :param generate_pack_contents: Function that can return a sequence of the
-            shas of the objects to upload.
+        :param proto: Protocol object to read from
+        :param channel_callbacks: Dictionary mapping channels to packet
+            handlers to use. None for a callback discards channel data.
+        """
+        for pkt in proto.read_pkt_seq():
+            channel = ord(pkt[0])
+            pkt = pkt[1:]
+            try:
+                cb = channel_callbacks[channel]
+            except KeyError:
+                raise AssertionError('Invalid sideband channel %d' % channel)
+            else:
+                if cb is not None:
+                    cb(pkt)
 
 
-        :raises SendPackError: if server rejects the pack data
-        :raises UpdateRefsError: if the server supports report-status
-                                 and rejects ref updates
+    def _handle_receive_pack_head(self, proto, capabilities, old_refs, new_refs):
+        """Handle the head of a 'git-receive-pack' request.
+
+        :param proto: Protocol object to read from
+        :param capabilities: List of negotiated capabilities
+        :param old_refs: Old refs, as received from the server
+        :param new_refs: New refs
+        :return: (have, want) tuple
         """
         """
-        proto, unused_can_read = self._connect('receive-pack', path)
-        old_refs, server_capabilities = self._read_refs(proto)
-        if 'report-status' not in server_capabilities:
-            self._send_capabilities.remove('report-status')
-        new_refs = determine_wants(old_refs)
-        if not new_refs:
-            proto.write_pkt_line(None)
-            return {}
         want = []
         want = []
         have = [x for x in old_refs.values() if not x == ZERO_SHA]
         have = [x for x in old_refs.values() if not x == ZERO_SHA]
         sent_capabilities = False
         sent_capabilities = False
@@ -176,61 +278,55 @@ class GitClient(object):
                 else:
                 else:
                     proto.write_pkt_line(
                     proto.write_pkt_line(
                       '%s %s %s\0%s' % (old_sha1, new_sha1, refname,
                       '%s %s %s\0%s' % (old_sha1, new_sha1, refname,
-                                        ' '.join(self._send_capabilities)))
+                                        ' '.join(capabilities)))
                     sent_capabilities = True
                     sent_capabilities = True
             if new_sha1 not in have and new_sha1 != ZERO_SHA:
             if new_sha1 not in have and new_sha1 != ZERO_SHA:
                 want.append(new_sha1)
                 want.append(new_sha1)
         proto.write_pkt_line(None)
         proto.write_pkt_line(None)
-        if not want:
-            return new_refs
-        objects = generate_pack_contents(have, want)
-        entries, sha = write_pack_objects(proto.write_file(), objects)
+        return (have, want)
+
+    def _handle_receive_pack_tail(self, proto, capabilities, progress):
+        """Handle the tail of a 'git-receive-pack' request.
 
 
-        if 'report-status' in self._send_capabilities:
-            self._parse_status_report(proto)
+        :param proto: Protocol object to read from
+        :param capabilities: List of negotiated capabilities
+        :param progress: Optional progress reporting function
+        """
+        if 'report-status' in capabilities:
+            report_status_parser = ReportStatusParser()
+        else:
+            report_status_parser = None
+        if "side-band-64k" in capabilities:
+            channel_callbacks = { 2: progress }
+            if 'report-status' in capabilities:
+                channel_callbacks[1] = PktLineParser(
+                    report_status_parser.handle_packet).parse
+            self._read_side_band64k_data(proto, channel_callbacks)
+        else:
+            if 'report-status':
+                for pkt in proto.read_pkt_seq():
+                    report_status_parser.handle_packet(pkt)
+        if report_status_parser is not None:
+            report_status_parser.check()
         # wait for EOF before returning
         # wait for EOF before returning
         data = proto.read()
         data = proto.read()
         if data:
         if data:
             raise SendPackError('Unexpected response %r' % data)
             raise SendPackError('Unexpected response %r' % data)
-        return new_refs
-
-    def fetch(self, path, target, determine_wants=None, progress=None):
-        """Fetch into a target repository.
-
-        :param path: Path to fetch from
-        :param target: Target repository to fetch into
-        :param determine_wants: Optional function to determine what refs
-            to fetch
-        :param progress: Optional progress function
-        :return: remote refs
-        """
-        if determine_wants is None:
-            determine_wants = target.object_store.determine_wants_all
-        f, commit = target.object_store.add_pack()
-        try:
-            return self.fetch_pack(path, determine_wants,
-                target.get_graph_walker(), f.write, progress)
-        finally:
-            commit()
 
 
-    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
-                   progress):
-        """Retrieve a pack from a git smart server.
+    def _handle_upload_pack_head(self, proto, capabilities, graph_walker,
+                                 wants, can_read):
+        """Handle the head of a 'git-upload-pack' request.
 
 
-        :param determine_wants: Callback that returns list of commits to fetch
-        :param graph_walker: Object with next() and ack().
-        :param pack_data: Callback called for each bit of data in the pack
-        :param progress: Callback for progress reports (strings)
+        :param proto: Protocol object to read from
+        :param capabilities: List of negotiated capabilities
+        :param graph_walker: GraphWalker instance to call .ack() on
+        :param wants: List of commits to fetch
+        :param can_read: function that returns a boolean that indicates
+            whether there is extra graph data to read on proto
         """
         """
-        proto, can_read = self._connect('upload-pack', path)
-        (refs, server_capabilities) = self._read_refs(proto)
-        wants = determine_wants(refs)
-        if not wants:
-            proto.write_pkt_line(None)
-            return refs
         assert isinstance(wants, list) and type(wants[0]) == str
         assert isinstance(wants, list) and type(wants[0]) == str
         proto.write_pkt_line('want %s %s\n' % (
         proto.write_pkt_line('want %s %s\n' % (
-            wants[0], ' '.join(self._fetch_capabilities)))
+            wants[0], ' '.join(capabilities)))
         for want in wants[1:]:
         for want in wants[1:]:
             proto.write_pkt_line('want %s\n' % want)
             proto.write_pkt_line('want %s\n' % want)
         proto.write_pkt_line(None)
         proto.write_pkt_line(None)
@@ -242,33 +338,122 @@ class GitClient(object):
                 parts = pkt.rstrip('\n').split(' ')
                 parts = pkt.rstrip('\n').split(' ')
                 if parts[0] == 'ACK':
                 if parts[0] == 'ACK':
                     graph_walker.ack(parts[1])
                     graph_walker.ack(parts[1])
-                    assert parts[2] == 'continue'
+                    if parts[2] in ('continue', 'common'):
+                        pass
+                    elif parts[2] == 'ready':
+                        break
+                    else:
+                        raise AssertionError(
+                            "%s not in ('continue', 'ready', 'common)" %
+                            parts[2])
             have = graph_walker.next()
             have = graph_walker.next()
         proto.write_pkt_line('done\n')
         proto.write_pkt_line('done\n')
+
+    def _handle_upload_pack_tail(self, proto, capabilities, graph_walker,
+                                 pack_data, progress):
+        """Handle the tail of a 'git-upload-pack' request.
+
+        :param proto: Protocol object to read from
+        :param capabilities: List of negotiated capabilities
+        :param graph_walker: GraphWalker instance to call .ack() on
+        :param pack_data: Function to call with pack data
+        :param progress: Optional progress reporting function
+        """
         pkt = proto.read_pkt_line()
         pkt = proto.read_pkt_line()
         while pkt:
         while pkt:
             parts = pkt.rstrip('\n').split(' ')
             parts = pkt.rstrip('\n').split(' ')
             if parts[0] == 'ACK':
             if parts[0] == 'ACK':
                 graph_walker.ack(pkt.split(' ')[1])
                 graph_walker.ack(pkt.split(' ')[1])
-            if len(parts) < 3 or parts[2] != 'continue':
+            if len(parts) < 3 or parts[2] not in (
+                    'ready', 'continue', 'common'):
                 break
                 break
             pkt = proto.read_pkt_line()
             pkt = proto.read_pkt_line()
-        # TODO(durin42): this is broken if the server didn't support the
-        # side-band-64k capability.
-        for pkt in proto.read_pkt_seq():
-            channel = ord(pkt[0])
-            pkt = pkt[1:]
-            if channel == 1:
-                pack_data(pkt)
-            elif channel == 2:
-                if progress is not None:
-                    progress(pkt)
-            else:
-                raise AssertionError('Invalid sideband channel %d' % channel)
+        if "side-band-64k" in capabilities:
+            self._read_side_band64k_data(proto, {1: pack_data, 2: progress})
+            # wait for EOF before returning
+            data = proto.read()
+            if data:
+                raise Exception('Unexpected response %r' % data)
+        else:
+            # FIXME: Buffering?
+            pack_data(self.read())
+
+
+
+class TraditionalGitClient(GitClient):
+    """Traditional Git client."""
+
+    def _connect(self, cmd, path):
+        """Create a connection to the server.
+
+        This method is abstract - concrete implementations should
+        implement their own variant which connects to the server and
+        returns an initialized Protocol object with the service ready
+        for use and a can_read function which may be used to see if
+        reads would block.
+
+        :param cmd: The git service name to which we should connect.
+        :param path: The path we should pass to the service.
+        """
+        raise NotImplementedError()
+
+    def send_pack(self, path, determine_wants, generate_pack_contents,
+                  progress=None):
+        """Upload a pack to a remote repository.
+
+        :param path: Repository path
+        :param generate_pack_contents: Function that can return a sequence of the
+            shas of the objects to upload.
+        :param progress: Optional callback called with progress updates
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
+        """
+        proto, unused_can_read = self._connect('receive-pack', path)
+        old_refs, server_capabilities = self._read_refs(proto)
+        negotiated_capabilities = list(self._send_capabilities)
+        if 'report-status' not in server_capabilities:
+            negotiated_capabilities.remove('report-status')
+        new_refs = determine_wants(old_refs)
+        if new_refs is None:
+            proto.write_pkt_line(None)
+            return old_refs
+        (have, want) = self._handle_receive_pack_head(proto,
+            negotiated_capabilities, old_refs, new_refs)
+        if not want and old_refs == new_refs:
+            return new_refs
+        objects = generate_pack_contents(have, want)
+        if len(objects) > 0:
+            entries, sha = write_pack_objects(proto.write_file(), objects)
+        self._handle_receive_pack_tail(proto, negotiated_capabilities,
+            progress)
+        return new_refs
+
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
+                   progress=None):
+        """Retrieve a pack from a git smart server.
+
+        :param determine_wants: Callback that returns list of commits to fetch
+        :param graph_walker: Object with next() and ack().
+        :param pack_data: Callback called for each bit of data in the pack
+        :param progress: Callback for progress reports (strings)
+        """
+        proto, can_read = self._connect('upload-pack', path)
+        (refs, server_capabilities) = self._read_refs(proto)
+        negotiated_capabilities = list(self._fetch_capabilities)
+        wants = determine_wants(refs)
+        if not wants:
+            proto.write_pkt_line(None)
+            return refs
+        self._handle_upload_pack_head(proto, negotiated_capabilities,
+            graph_walker, wants, can_read)
+        self._handle_upload_pack_tail(proto, negotiated_capabilities,
+            graph_walker, pack_data, progress)
         return refs
         return refs
 
 
 
 
-class TCPGitClient(GitClient):
+class TCPGitClient(TraditionalGitClient):
     """A Git Client that works over TCP directly (i.e. git://)."""
     """A Git Client that works over TCP directly (i.e. git://)."""
 
 
     def __init__(self, host, port=None, *args, **kwargs):
     def __init__(self, host, port=None, *args, **kwargs):
@@ -330,7 +515,7 @@ class SubprocessWrapper(object):
         self.proc.wait()
         self.proc.wait()
 
 
 
 
-class SubprocessGitClient(GitClient):
+class SubprocessGitClient(TraditionalGitClient):
     """Git client that talks to a server using a subprocess."""
     """Git client that talks to a server using a subprocess."""
 
 
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
@@ -367,7 +552,7 @@ class SSHVendor(object):
 get_ssh_vendor = SSHVendor
 get_ssh_vendor = SSHVendor
 
 
 
 
-class SSHGitClient(GitClient):
+class SSHGitClient(TraditionalGitClient):
 
 
     def __init__(self, host, port=None, username=None, *args, **kwargs):
     def __init__(self, host, port=None, username=None, *args, **kwargs):
         self.host = host
         self.host = host
@@ -387,6 +572,136 @@ class SSHGitClient(GitClient):
                 con.can_read)
                 con.can_read)
 
 
 
 
+class HttpGitClient(GitClient):
+
+    def __init__(self, base_url, dumb=None, *args, **kwargs):
+        self.base_url = base_url.rstrip("/") + "/"
+        self.dumb = dumb
+        GitClient.__init__(self, *args, **kwargs)
+
+    def _get_url(self, path):
+        return urlparse.urljoin(self.base_url, path).rstrip("/") + "/"
+
+    def _perform(self, req):
+        """Perform a HTTP request.
+
+        This is provided so subclasses can provide their own version.
+
+        :param req: urllib2.Request instance
+        :return: matching response
+        """
+        return urllib2.urlopen(req)
+
+    def _discover_references(self, service, url):
+        assert url[-1] == "/"
+        url = urlparse.urljoin(url, "info/refs")
+        headers = {}
+        if self.dumb != False:
+            url += "?service=%s" % service
+            headers["Content-Type"] = "application/x-%s-request" % service
+        req = urllib2.Request(url, headers=headers)
+        resp = self._perform(req)
+        if resp.getcode() == 404:
+            raise NotGitRepository()
+        if resp.getcode() != 200:
+            raise GitProtocolError("unexpected http response %d" %
+                resp.getcode())
+        self.dumb = (not resp.info().gettype().startswith("application/x-git-"))
+        proto = Protocol(resp.read, None)
+        if not self.dumb:
+            # The first line should mention the service
+            pkts = list(proto.read_pkt_seq())
+            if pkts != [('# service=%s\n' % service)]:
+                raise GitProtocolError(
+                    "unexpected first line %r from smart server" % pkts)
+        return self._read_refs(proto)
+
+    def _smart_request(self, service, url, data):
+        assert url[-1] == "/"
+        url = urlparse.urljoin(url, service)
+        req = urllib2.Request(url,
+            headers={"Content-Type": "application/x-%s-request" % service},
+            data=data)
+        resp = self._perform(req)
+        if resp.getcode() == 404:
+            raise NotGitRepository()
+        if resp.getcode() != 200:
+            raise GitProtocolError("Invalid HTTP response from server: %d"
+                % resp.getcode())
+        if resp.info().gettype() != ("application/x-%s-result" % service):
+            raise GitProtocolError("Invalid content-type from server: %s"
+                % resp.info().gettype())
+        return resp
+
+    def send_pack(self, path, determine_wants, generate_pack_contents,
+                  progress=None):
+        """Upload a pack to a remote repository.
+
+        :param path: Repository path
+        :param generate_pack_contents: Function that can return a sequence of the
+            shas of the objects to upload.
+        :param progress: Optional progress function
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
+        """
+        url = self._get_url(path)
+        old_refs, server_capabilities = self._discover_references(
+            "git-receive-pack", url)
+        negotiated_capabilities = list(self._send_capabilities)
+        new_refs = determine_wants(old_refs)
+        if new_refs is None:
+            return old_refs
+        if self.dumb:
+            raise NotImplementedError(self.fetch_pack)
+        req_data = StringIO()
+        req_proto = Protocol(None, req_data.write)
+        (have, want) = self._handle_receive_pack_head(
+            req_proto, negotiated_capabilities, old_refs, new_refs)
+        if not want and old_refs == new_refs:
+            return new_refs
+        objects = generate_pack_contents(have, want)
+        if len(objects) > 0:
+            entries, sha = write_pack_objects(req_proto.write_file(), objects)
+        resp = self._smart_request("git-receive-pack", url,
+            data=req_data.getvalue())
+        resp_proto = Protocol(resp.read, None)
+        self._handle_receive_pack_tail(resp_proto, negotiated_capabilities,
+            progress)
+        return new_refs
+
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
+                   progress):
+        """Retrieve a pack from a git smart server.
+
+        :param determine_wants: Callback that returns list of commits to fetch
+        :param graph_walker: Object with next() and ack().
+        :param pack_data: Callback called for each bit of data in the pack
+        :param progress: Callback for progress reports (strings)
+        """
+        url = self._get_url(path)
+        refs, server_capabilities = self._discover_references(
+            "git-upload-pack", url)
+        negotiated_capabilities = list(server_capabilities)
+        wants = determine_wants(refs)
+        if not wants:
+            return refs
+        if self.dumb:
+            raise NotImplementedError(self.send_pack)
+        req_data = StringIO()
+        req_proto = Protocol(None, req_data.write)
+        self._handle_upload_pack_head(req_proto,
+            negotiated_capabilities, graph_walker, wants,
+            lambda: False)
+        resp = self._smart_request("git-upload-pack", url,
+            data=req_data.getvalue())
+        resp_proto = Protocol(resp.read, None)
+        self._handle_upload_pack_tail(resp_proto, negotiated_capabilities,
+            graph_walker, pack_data, progress)
+        return refs
+
+
 def get_transport_and_path(uri):
 def get_transport_and_path(uri):
     """Obtain a git client from a URI or path.
     """Obtain a git client from a URI or path.
 
 
@@ -399,6 +714,9 @@ def get_transport_and_path(uri):
     elif parsed.scheme == 'git+ssh':
     elif parsed.scheme == 'git+ssh':
         return SSHGitClient(parsed.hostname, port=parsed.port,
         return SSHGitClient(parsed.hostname, port=parsed.port,
                             username=parsed.username), parsed.path
                             username=parsed.username), parsed.path
+    elif parsed.scheme in ('http', 'https'):
+        return HttpGitClient(urlparse.urlunparse(
+            parsed.scheme, parsed.netloc, path='/'))
 
 
     if parsed.scheme and not parsed.netloc:
     if parsed.scheme and not parsed.netloc:
         # SSH with no user@, zero or one leading slash.
         # SSH with no user@, zero or one leading slash.

+ 19 - 16
dulwich/diff_tree.py

@@ -405,26 +405,29 @@ class RenameDetector(object):
         new_obj = self._store[change.new.sha]
         new_obj = self._store[change.new.sha]
         return _similarity_score(old_obj, new_obj) < self._rewrite_threshold
         return _similarity_score(old_obj, new_obj) < self._rewrite_threshold
 
 
+    def _add_change(self, change):
+        if change.type == CHANGE_ADD:
+            self._adds.append(change)
+        elif change.type == CHANGE_DELETE:
+            self._deletes.append(change)
+        elif self._should_split(change):
+            self._deletes.append(TreeChange.delete(change.old))
+            self._adds.append(TreeChange.add(change.new))
+        elif ((self._find_copies_harder and change.type == CHANGE_UNCHANGED)
+              or change.type == CHANGE_MODIFY):
+            # Treat all modifies as potential deletes for rename detection,
+            # but don't split them (to avoid spurious renames). Setting
+            # find_copies_harder means we treat unchanged the same as
+            # modified.
+            self._deletes.append(change)
+        else:
+            self._changes.append(change)
+
     def _collect_changes(self, tree1_id, tree2_id):
     def _collect_changes(self, tree1_id, tree2_id):
         want_unchanged = self._find_copies_harder or self._want_unchanged
         want_unchanged = self._find_copies_harder or self._want_unchanged
         for change in tree_changes(self._store, tree1_id, tree2_id,
         for change in tree_changes(self._store, tree1_id, tree2_id,
                                    want_unchanged=want_unchanged):
                                    want_unchanged=want_unchanged):
-            if change.type == CHANGE_ADD:
-                self._adds.append(change)
-            elif change.type == CHANGE_DELETE:
-                self._deletes.append(change)
-            elif self._should_split(change):
-                self._deletes.append(TreeChange.delete(change.old))
-                self._adds.append(TreeChange.add(change.new))
-            elif ((self._find_copies_harder and change.type == CHANGE_UNCHANGED)
-                  or change.type == CHANGE_MODIFY):
-                # Treat all modifies as potential deletes for rename detection,
-                # but don't split them (to avoid spurious renames). Setting
-                # find_copies_harder means we treat unchanged the same as
-                # modified.
-                self._deletes.append(change)
-            else:
-                self._changes.append(change)
+            self._add_change(change)
 
 
     def _prune(self, add_paths, delete_paths):
     def _prune(self, add_paths, delete_paths):
         self._adds = [a for a in self._adds if a.new.path not in add_paths]
         self._adds = [a for a in self._adds if a.new.path not in add_paths]

+ 66 - 0
dulwich/object_store.py

@@ -226,6 +226,10 @@ class PackBasedObjectStore(BaseObjectStore):
     def __init__(self):
     def __init__(self):
         self._pack_cache = None
         self._pack_cache = None
 
 
+    @property
+    def alternates(self):
+        return []
+
     def contains_packed(self, sha):
     def contains_packed(self, sha):
         """Check if a particular object is present by SHA1 and is packed."""
         """Check if a particular object is present by SHA1 and is packed."""
         for pack in self.packs:
         for pack in self.packs:
@@ -310,6 +314,11 @@ class PackBasedObjectStore(BaseObjectStore):
         ret = self._get_loose_object(hexsha)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
         if ret is not None:
             return ret.type_num, ret.as_raw_string()
             return ret.type_num, ret.as_raw_string()
+        for alternate in self.alternates:
+            try:
+                return alternate.get_raw(hexsha)
+            except KeyError:
+                pass
         raise KeyError(hexsha)
         raise KeyError(hexsha)
 
 
     def add_objects(self, objects):
     def add_objects(self, objects):
@@ -338,6 +347,63 @@ class DiskObjectStore(PackBasedObjectStore):
         self.path = path
         self.path = path
         self.pack_dir = os.path.join(self.path, PACKDIR)
         self.pack_dir = os.path.join(self.path, PACKDIR)
         self._pack_cache_time = 0
         self._pack_cache_time = 0
+        self._alternates = None
+
+    @property
+    def alternates(self):
+        if self._alternates is not None:
+            return self._alternates
+        self._alternates = []
+        for path in self._read_alternate_paths():
+            self._alternates.append(DiskObjectStore(path))
+        return self._alternates
+
+    def _read_alternate_paths(self):
+        try:
+            f = GitFile(os.path.join(self.path, "info", "alternates"),
+                    'rb')
+        except (OSError, IOError), e:
+            if e.errno == errno.ENOENT:
+                return []
+            raise
+        ret = []
+        try:
+            for l in f.readlines():
+                l = l.rstrip("\n")
+                if l[0] == "#":
+                    continue
+                if not os.path.isabs(l):
+                    continue
+                ret.append(l)
+            return ret
+        finally:
+            f.close()
+
+    def add_alternate_path(self, path):
+        """Add an alternate path to this object store.
+        """
+        try:
+            os.mkdir(os.path.join(self.path, "info"))
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
+        alternates_path = os.path.join(self.path, "info/alternates")
+        f = GitFile(alternates_path, 'wb')
+        try:
+            try:
+                orig_f = open(alternates_path, 'rb')
+            except (OSError, IOError), e:
+                if e.errno != errno.ENOENT:
+                    raise
+            else:
+                try:
+                    f.write(orig_f.read())
+                finally:
+                    orig_f.close()
+            f.write("%s\n" % path)
+        finally:
+            f.close()
+        self.alternates.append(DiskObjectStore(path))
 
 
     def _load_packs(self):
     def _load_packs(self):
         pack_files = []
         pack_files = []

+ 4 - 0
dulwich/objects.py

@@ -693,6 +693,8 @@ class TreeEntry(namedtuple('TreeEntry', ['path', 'mode', 'sha'])):
 
 
     def in_path(self, path):
     def in_path(self, path):
         """Return a copy of this entry with the given path prepended."""
         """Return a copy of this entry with the given path prepended."""
+        if type(self.path) != str:
+            raise TypeError
         return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha)
         return TreeEntry(posixpath.join(path, self.path), self.mode, self.sha)
 
 
 
 
@@ -747,6 +749,8 @@ def sorted_tree_items(entries, name_order):
     for name, entry in sorted(entries.iteritems(), cmp=cmp_func):
     for name, entry in sorted(entries.iteritems(), cmp=cmp_func):
         mode, hexsha = entry
         mode, hexsha = entry
         # Stricter type checks than normal to mirror checks in the C version.
         # Stricter type checks than normal to mirror checks in the C version.
+        if not isinstance(mode, int) and not isinstance(mode, long):
+            raise TypeError('Expected integer/long for mode, got %r' % mode)
         mode = int(mode)
         mode = int(mode)
         if not isinstance(hexsha, str):
         if not isinstance(hexsha, str):
             raise TypeError('Expected a string for SHA, got %r' % hexsha)
             raise TypeError('Expected a string for SHA, got %r' % hexsha)

+ 33 - 0
dulwich/protocol.py

@@ -406,3 +406,36 @@ class BufferedPktLineWriter(object):
             self._write(data)
             self._write(data)
         self._len = 0
         self._len = 0
         self._wbuf = StringIO()
         self._wbuf = StringIO()
+
+
+class PktLineParser(object):
+    """Packet line parser that hands completed packets off to a callback.
+    """
+
+    def __init__(self, handle_pkt):
+        self.handle_pkt = handle_pkt
+        self._readahead = StringIO()
+
+    def parse(self, data):
+        """Parse a fragment of data and call back for any completed packets.
+        """
+        self._readahead.write(data)
+        buf = self._readahead.getvalue()
+        if len(buf) < 4:
+            return
+        while len(buf) >= 4:
+            size = int(buf[:4], 16)
+            if size == 0:
+                self.handle_pkt(None)
+                buf = buf[4:]
+            elif size <= len(buf):
+                self.handle_pkt(buf[4:size])
+                buf = buf[size:]
+            else:
+                break
+        self._readahead = StringIO()
+        self._readahead.write(buf)
+
+    def get_tail(self):
+        """Read back any unused data."""
+        return self._readahead.getvalue()

+ 79 - 16
dulwich/repo.py

@@ -26,7 +26,6 @@ import errno
 import os
 import os
 
 
 from dulwich.errors import (
 from dulwich.errors import (
-    MissingCommitError,
     NoIndexPresent,
     NoIndexPresent,
     NotBlobError,
     NotBlobError,
     NotCommitError,
     NotCommitError,
@@ -53,9 +52,6 @@ from dulwich.objects import (
     Tree,
     Tree,
     hex_to_sha,
     hex_to_sha,
     )
     )
-from dulwich.walk import (
-    Walker,
-    )
 import warnings
 import warnings
 
 
 
 
@@ -214,7 +210,7 @@ class RefsContainer(object):
         :param name: The name of the reference.
         :param name: The name of the reference.
         :raises KeyError: if a refname is not HEAD or is otherwise not valid.
         :raises KeyError: if a refname is not HEAD or is otherwise not valid.
         """
         """
-        if name == 'HEAD':
+        if name in ('HEAD', 'refs/stash'):
             return
             return
         if not name.startswith('refs/') or not check_ref_format(name[5:]):
         if not name.startswith('refs/') or not check_ref_format(name[5:]):
             raise RefFormatError(name)
             raise RefFormatError(name)
@@ -954,6 +950,35 @@ class BaseRepo(object):
             return cached
             return cached
         return self.object_store.peel_sha(self.refs[ref]).id
         return self.object_store.peel_sha(self.refs[ref]).id
 
 
+    def get_walker(self, include=None, *args, **kwargs):
+        """Obtain a walker for this repository.
+
+        :param include: Iterable of SHAs of commits to include along with their
+            ancestors. Defaults to [HEAD]
+        :param exclude: Iterable of SHAs of commits to exclude along with their
+            ancestors, overriding includes.
+        :param order: ORDER_* constant specifying the order of results. Anything
+            other than ORDER_DATE may result in O(n) memory usage.
+        :param reverse: If True, reverse the order of output, requiring O(n)
+            memory.
+        :param max_entries: The maximum number of entries to yield, or None for
+            no limit.
+        :param paths: Iterable of file or subtree paths to show entries for.
+        :param rename_detector: diff.RenameDetector object for detecting
+            renames.
+        :param follow: If True, follow path across renames/copies. Forces a
+            default rename_detector.
+        :param since: Timestamp to list commits after.
+        :param until: Timestamp to list commits before.
+        :param queue_cls: A class to use for a queue of commits, supporting the
+            iterator protocol. The constructor takes a single argument, the
+            Walker.
+        """
+        from dulwich.walk import Walker
+        if include is None:
+            include = [self.head()]
+        return Walker(self.object_store, include, *args, **kwargs)
+
     def revision_history(self, head):
     def revision_history(self, head):
         """Returns a list of the commits reachable from head.
         """Returns a list of the commits reachable from head.
 
 
@@ -963,9 +988,10 @@ class BaseRepo(object):
         :raise MissingCommitError: if any missing commits are referenced,
         :raise MissingCommitError: if any missing commits are referenced,
             including if the head parameter isn't the SHA of a commit.
             including if the head parameter isn't the SHA of a commit.
         """
         """
-        # TODO(dborowitz): Expose more of the Walker functionality here or in a
-        # separate Repo/BaseObjectStore method.
-        return [e.commit for e in Walker(self.object_store, [head])]
+        warnings.warn("Repo.revision_history() is deprecated."
+            "Use dulwich.walker.Walker(repo) instead.",
+            category=DeprecationWarning, stacklevel=2)
+        return [e.commit for e in self.get_walker(include=[head])]
 
 
     def __getitem__(self, name):
     def __getitem__(self, name):
         if len(name) in (20, 40):
         if len(name) in (20, 40):
@@ -978,6 +1004,9 @@ class BaseRepo(object):
         except RefFormatError:
         except RefFormatError:
             raise KeyError(name)
             raise KeyError(name)
 
 
+    def __iter__(self):
+        raise NotImplementedError(self.__iter__)
+
     def __contains__(self, name):
     def __contains__(self, name):
         if len(name) in (20, 40):
         if len(name) in (20, 40):
             return name in self.object_store or name in self.refs
             return name in self.object_store or name in self.refs
@@ -1001,10 +1030,11 @@ class BaseRepo(object):
         else:
         else:
             raise ValueError(name)
             raise ValueError(name)
 
 
-    def do_commit(self, message, committer=None,
+    def do_commit(self, message=None, committer=None,
                   author=None, commit_timestamp=None,
                   author=None, commit_timestamp=None,
                   commit_timezone=None, author_timestamp=None,
                   commit_timezone=None, author_timestamp=None,
-                  author_timezone=None, tree=None, encoding=None):
+                  author_timezone=None, tree=None, encoding=None,
+                  ref='HEAD', merge_heads=None):
         """Create a new commit.
         """Create a new commit.
 
 
         :param message: Commit message
         :param message: Commit message
@@ -1018,6 +1048,8 @@ class BaseRepo(object):
         :param tree: SHA1 of the tree root to use (if not specified the
         :param tree: SHA1 of the tree root to use (if not specified the
             current index will be committed).
             current index will be committed).
         :param encoding: Encoding
         :param encoding: Encoding
+        :param ref: Optional ref to commit to (defaults to current branch)
+        :param merge_heads: Merge heads (defaults to .git/MERGE_HEADS)
         :return: New commit SHA1
         :return: New commit SHA1
         """
         """
         import time
         import time
@@ -1029,6 +1061,9 @@ class BaseRepo(object):
             if len(tree) != 40:
             if len(tree) != 40:
                 raise ValueError("tree must be a 40-byte hex sha string")
                 raise ValueError("tree must be a 40-byte hex sha string")
             c.tree = tree
             c.tree = tree
+        if merge_heads is None:
+            # FIXME: Read merge heads from .git/MERGE_HEADS
+            merge_heads = []
         # TODO: Allow username to be missing, and get it from .git/config
         # TODO: Allow username to be missing, and get it from .git/config
         if committer is None:
         if committer is None:
             raise ValueError("committer not set")
             raise ValueError("committer not set")
@@ -1051,20 +1086,23 @@ class BaseRepo(object):
         c.author_timezone = author_timezone
         c.author_timezone = author_timezone
         if encoding is not None:
         if encoding is not None:
             c.encoding = encoding
             c.encoding = encoding
+        if message is None:
+            # FIXME: Try to read commit message from .git/MERGE_MSG
+            raise ValueError("No commit message specified")
         c.message = message
         c.message = message
         try:
         try:
-            old_head = self.refs["HEAD"]
-            c.parents = [old_head]
+            old_head = self.refs[ref]
+            c.parents = [old_head] + merge_heads
             self.object_store.add_object(c)
             self.object_store.add_object(c)
-            ok = self.refs.set_if_equals("HEAD", old_head, c.id)
+            ok = self.refs.set_if_equals(ref, old_head, c.id)
         except KeyError:
         except KeyError:
-            c.parents = []
+            c.parents = merge_heads
             self.object_store.add_object(c)
             self.object_store.add_object(c)
-            ok = self.refs.add_if_new("HEAD", c.id)
+            ok = self.refs.add_if_new(ref, c.id)
         if not ok:
         if not ok:
             # Fail if the atomic compare-and-swap failed, leaving the commit and
             # Fail if the atomic compare-and-swap failed, leaving the commit and
             # all its objects as garbage.
             # all its objects as garbage.
-            raise CommitError("HEAD changed during commit")
+            raise CommitError("%s changed during commit" % (ref,))
 
 
         return c.id
         return c.id
 
 
@@ -1173,6 +1211,31 @@ class Repo(BaseRepo):
                     blob.id, 0)
                     blob.id, 0)
         index.write()
         index.write()
 
 
+    def clone(self, target_path, mkdir=True, bare=False, origin="origin"):
+        """Clone this repository.
+
+        :param target_path: Target path
+        :param mkdir: Create the target directory
+        :param bare: Whether to create a bare repository
+        :return: Created repository
+        """
+        if not bare:
+            target = self.init(target_path, mkdir=mkdir)
+        else:
+            target = self.init_bare(target_path)
+        self.fetch(target)
+        target.refs.import_refs(
+            'refs/remotes/'+origin, self.refs.as_dict('refs/heads'))
+        target.refs.import_refs(
+            'refs/tags', self.refs.as_dict('refs/tags'))
+        try:
+            target.refs.add_if_new(
+                'refs/heads/master',
+                self.refs['refs/heads/master'])
+        except KeyError:
+            pass
+        return target
+
     def __repr__(self):
     def __repr__(self):
         return "<Repo at %r>" % self.path
         return "<Repo at %r>" % self.path
 
 

+ 3 - 11
dulwich/tests/__init__.py

@@ -30,18 +30,10 @@ import tempfile
 if sys.version_info >= (2, 7):
 if sys.version_info >= (2, 7):
     # If Python itself provides an exception, use that
     # If Python itself provides an exception, use that
     import unittest
     import unittest
-    from unittest import SkipTest as TestSkipped
-    from unittest import TestCase
+    from unittest import SkipTest, TestCase
 else:
 else:
-    try:
-        import unittest2 as unittest
-        from unittest2 import SkipTest as TestSkipped
-        from unittest2 import TestCase
-    except ImportError:
-        import unittest
-        from testtools.testcase import TestSkipped
-        from testtools.testcase import TestCase
-        TestCase.skipException = TestSkipped
+    import unittest2 as unittest
+    from unittest2 import SkipTest, TestCase
 
 
 
 
 class BlackboxTestCase(TestCase):
 class BlackboxTestCase(TestCase):

+ 2 - 14
dulwich/tests/compat/server_utils.py

@@ -46,23 +46,11 @@ class ServerTests(object):
     Does not inherit from TestCase so tests are not automatically run.
     Does not inherit from TestCase so tests are not automatically run.
     """
     """
 
 
-    def setUp(self):
-        self._old_repo = None
-        self._new_repo = None
-        self._server = None
-
-    def tearDown(self):
-        if self._server is not None:
-            self._server.shutdown()
-            self._server = None
-        if self._old_repo is not None:
-            tear_down_repo(self._old_repo)
-        if self._new_repo is not None:
-            tear_down_repo(self._new_repo)
-
     def import_repos(self):
     def import_repos(self):
         self._old_repo = import_repo('server_old.export')
         self._old_repo = import_repo('server_old.export')
+        self.addCleanup(tear_down_repo, self._old_repo)
         self._new_repo = import_repo('server_new.export')
         self._new_repo = import_repo('server_new.export')
+        self.addCleanup(tear_down_repo, self._new_repo)
 
 
     def url(self, port):
     def url(self, port):
         return '%s://localhost:%s/' % (self.protocol, port)
         return '%s://localhost:%s/' % (self.protocol, port)

+ 209 - 12
dulwich/tests/compat/test_client.py

@@ -19,11 +19,17 @@
 
 
 """Compatibilty tests between the Dulwich client and the cgit server."""
 """Compatibilty tests between the Dulwich client and the cgit server."""
 
 
+import BaseHTTPServer
+import SimpleHTTPServer
+import copy
 import os
 import os
+import select
 import shutil
 import shutil
 import signal
 import signal
 import subprocess
 import subprocess
 import tempfile
 import tempfile
+import threading
+import urllib
 
 
 from dulwich import (
 from dulwich import (
     client,
     client,
@@ -35,7 +41,7 @@ from dulwich import (
     repo,
     repo,
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
-    TestSkipped,
+    SkipTest,
     )
     )
 
 
 from dulwich.tests.compat.utils import (
 from dulwich.tests.compat.utils import (
@@ -44,6 +50,9 @@ from dulwich.tests.compat.utils import (
     import_repo_to_dir,
     import_repo_to_dir,
     run_git_or_fail,
     run_git_or_fail,
     )
     )
+from dulwich.tests.compat.server_utils import (
+    ShutdownServerMixIn,
+    )
 
 
 
 
 class DulwichClientTestBase(object):
 class DulwichClientTestBase(object):
@@ -51,9 +60,9 @@ class DulwichClientTestBase(object):
 
 
     def setUp(self):
     def setUp(self):
         self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
         self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
-        dest = os.path.join(self.gitroot, 'dest')
-        file.ensure_dir_exists(dest)
-        run_git_or_fail(['init', '--quiet', '--bare'], cwd=dest)
+        self.dest = os.path.join(self.gitroot, 'dest')
+        file.ensure_dir_exists(self.dest)
+        run_git_or_fail(['init', '--quiet', '--bare'], cwd=self.dest)
 
 
     def tearDown(self):
     def tearDown(self):
         shutil.rmtree(self.gitroot)
         shutil.rmtree(self.gitroot)
@@ -99,11 +108,7 @@ class DulwichClientTestBase(object):
                     src.object_store.generate_pack_contents)
                     src.object_store.generate_pack_contents)
         self.assertDestEqualsSrc()
         self.assertDestEqualsSrc()
 
 
-    def disable_ff_and_make_dummy_commit(self):
-        # disable non-fast-forward pushes to the server
-        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        run_git_or_fail(['config', 'receive.denyNonFastForwards', 'true'],
-                        cwd=dest.path)
+    def make_dummy_commit(self, dest):
         b = objects.Blob.from_string('hi')
         b = objects.Blob.from_string('hi')
         dest.object_store.add_object(b)
         dest.object_store.add_object(b)
         t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
         t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
@@ -114,7 +119,15 @@ class DulwichClientTestBase(object):
         c.message = 'hi'
         c.message = 'hi'
         c.tree = t
         c.tree = t
         dest.object_store.add_object(c)
         dest.object_store.add_object(c)
-        return dest, c.id
+        return c.id
+
+    def disable_ff_and_make_dummy_commit(self):
+        # disable non-fast-forward pushes to the server
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        run_git_or_fail(['config', 'receive.denyNonFastForwards', 'true'],
+                        cwd=dest.path)
+        commit_id = self.make_dummy_commit(dest)
+        return dest, commit_id
 
 
     def compute_send(self):
     def compute_send(self):
         srcpath = os.path.join(self.gitroot, 'server_new.export')
         srcpath = os.path.join(self.gitroot, 'server_new.export')
@@ -168,6 +181,20 @@ class DulwichClientTestBase(object):
         map(lambda r: dest.refs.set_if_equals(r[0], None, r[1]), refs.items())
         map(lambda r: dest.refs.set_if_equals(r[0], None, r[1]), refs.items())
         self.assertDestEqualsSrc()
         self.assertDestEqualsSrc()
 
 
+    def test_send_remove_branch(self):
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        dummy_commit = self.make_dummy_commit(dest)
+        dest.refs['refs/heads/master'] = dummy_commit
+        dest.refs['refs/heads/abranch'] = dummy_commit
+        sendrefs = dict(dest.refs)
+        sendrefs['refs/heads/abranch'] = "00" * 20
+        del sendrefs['HEAD']
+        gen_pack = lambda have, want: []
+        c = self._client()
+        self.assertEquals(dest.refs["refs/heads/abranch"], dummy_commit)
+        c.send_pack(self._build_path('/dest'), lambda _: sendrefs, gen_pack)
+        self.assertFalse("refs/heads/abranch" in dest.refs)
+
 
 
 class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
 class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
 
 
@@ -175,7 +202,7 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
         CompatTestCase.setUp(self)
         CompatTestCase.setUp(self)
         DulwichClientTestBase.setUp(self)
         DulwichClientTestBase.setUp(self)
         if check_for_daemon(limit=1):
         if check_for_daemon(limit=1):
-            raise TestSkipped('git-daemon was already running on port %s' %
+            raise SkipTest('git-daemon was already running on port %s' %
                               protocol.TCP_GIT_PORT)
                               protocol.TCP_GIT_PORT)
         fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
         fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
                                             suffix=".pid")
                                             suffix=".pid")
@@ -186,7 +213,7 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
              '--detach', '--reuseaddr', '--enable=receive-pack',
              '--detach', '--reuseaddr', '--enable=receive-pack',
              '--listen=localhost', self.gitroot], cwd=self.gitroot)
              '--listen=localhost', self.gitroot], cwd=self.gitroot)
         if not check_for_daemon():
         if not check_for_daemon():
-            raise TestSkipped('git-daemon failed to start')
+            raise SkipTest('git-daemon failed to start')
 
 
     def tearDown(self):
     def tearDown(self):
         try:
         try:
@@ -249,3 +276,173 @@ class DulwichSubprocessClientTest(CompatTestCase, DulwichClientTestBase):
 
 
     def _build_path(self, path):
     def _build_path(self, path):
         return self.gitroot + path
         return self.gitroot + path
+
+
+class GitHTTPRequestHandler(SimpleHTTPServer.SimpleHTTPRequestHandler):
+    """HTTP Request handler that calls out to 'git http-backend'."""
+
+    # Make rfile unbuffered -- we need to read one line and then pass
+    # the rest to a subprocess, so we can't use buffered input.
+    rbufsize = 0
+
+    def do_POST(self):
+        self.run_backend()
+
+    def do_GET(self):
+        self.run_backend()
+
+    def send_head(self):
+        return self.run_backend()
+
+    def log_request(self, code='-', size='-'):
+        # Let's be quiet, the test suite is noisy enough already
+        pass
+
+    def run_backend(self):
+        """Call out to git http-backend."""
+        # Based on CGIHTTPServer.CGIHTTPRequestHandler.run_cgi:
+        # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+        # Licensed under the Python Software Foundation License.
+        rest = self.path
+        # find an explicit query string, if present.
+        i = rest.rfind('?')
+        if i >= 0:
+            rest, query = rest[:i], rest[i+1:]
+        else:
+            query = ''
+
+        env = copy.deepcopy(os.environ)
+        env['SERVER_SOFTWARE'] = self.version_string()
+        env['SERVER_NAME'] = self.server.server_name
+        env['GATEWAY_INTERFACE'] = 'CGI/1.1'
+        env['SERVER_PROTOCOL'] = self.protocol_version
+        env['SERVER_PORT'] = str(self.server.server_port)
+        env['GIT_PROJECT_ROOT'] = self.server.root_path
+        env["GIT_HTTP_EXPORT_ALL"] = "1"
+        env['REQUEST_METHOD'] = self.command
+        uqrest = urllib.unquote(rest)
+        env['PATH_INFO'] = uqrest
+        env['SCRIPT_NAME'] = "/"
+        if query:
+            env['QUERY_STRING'] = query
+        host = self.address_string()
+        if host != self.client_address[0]:
+            env['REMOTE_HOST'] = host
+        env['REMOTE_ADDR'] = self.client_address[0]
+        authorization = self.headers.getheader("authorization")
+        if authorization:
+            authorization = authorization.split()
+            if len(authorization) == 2:
+                import base64, binascii
+                env['AUTH_TYPE'] = authorization[0]
+                if authorization[0].lower() == "basic":
+                    try:
+                        authorization = base64.decodestring(authorization[1])
+                    except binascii.Error:
+                        pass
+                    else:
+                        authorization = authorization.split(':')
+                        if len(authorization) == 2:
+                            env['REMOTE_USER'] = authorization[0]
+        # XXX REMOTE_IDENT
+        if self.headers.typeheader is None:
+            env['CONTENT_TYPE'] = self.headers.type
+        else:
+            env['CONTENT_TYPE'] = self.headers.typeheader
+        length = self.headers.getheader('content-length')
+        if length:
+            env['CONTENT_LENGTH'] = length
+        referer = self.headers.getheader('referer')
+        if referer:
+            env['HTTP_REFERER'] = referer
+        accept = []
+        for line in self.headers.getallmatchingheaders('accept'):
+            if line[:1] in "\t\n\r ":
+                accept.append(line.strip())
+            else:
+                accept = accept + line[7:].split(',')
+        env['HTTP_ACCEPT'] = ','.join(accept)
+        ua = self.headers.getheader('user-agent')
+        if ua:
+            env['HTTP_USER_AGENT'] = ua
+        co = filter(None, self.headers.getheaders('cookie'))
+        if co:
+            env['HTTP_COOKIE'] = ', '.join(co)
+        # XXX Other HTTP_* headers
+        # Since we're setting the env in the parent, provide empty
+        # values to override previously set values
+        for k in ('QUERY_STRING', 'REMOTE_HOST', 'CONTENT_LENGTH',
+                  'HTTP_USER_AGENT', 'HTTP_COOKIE', 'HTTP_REFERER'):
+            env.setdefault(k, "")
+
+        self.send_response(200, "Script output follows")
+
+        decoded_query = query.replace('+', ' ')
+
+        try:
+            nbytes = int(length)
+        except (TypeError, ValueError):
+            nbytes = 0
+        if self.command.lower() == "post" and nbytes > 0:
+            data = self.rfile.read(nbytes)
+        else:
+            data = None
+        # throw away additional data [see bug #427345]
+        while select.select([self.rfile._sock], [], [], 0)[0]:
+            if not self.rfile._sock.recv(1):
+                break
+        args = ['http-backend']
+        if '=' not in decoded_query:
+            args.append(decoded_query)
+        stdout = run_git_or_fail(args, input=data, env=env, stderr=subprocess.PIPE)
+        self.wfile.write(stdout)
+
+
+class HTTPGitServer(BaseHTTPServer.HTTPServer):
+
+    allow_reuse_address = True
+
+    def __init__(self, server_address, root_path):
+        BaseHTTPServer.HTTPServer.__init__(self, server_address, GitHTTPRequestHandler)
+        self.root_path = root_path
+
+    def get_url(self):
+        return 'http://%s:%s/' % (self.server_name, self.server_port)
+
+
+if not getattr(HTTPGitServer, 'shutdown', None):
+    _HTTPGitServer = HTTPGitServer
+
+    class TCPGitServer(ShutdownServerMixIn, HTTPGitServer):
+        """Subclass of HTTPGitServer that can be shut down."""
+
+        def __init__(self, *args, **kwargs):
+            # BaseServer is old-style so we have to call both __init__s
+            ShutdownServerMixIn.__init__(self)
+            _HTTPGitServer.__init__(self, *args, **kwargs)
+
+
+class DulwichHttpClientTest(CompatTestCase, DulwichClientTestBase):
+
+    min_git_version = (1, 7, 0, 2)
+
+    def setUp(self):
+        CompatTestCase.setUp(self)
+        DulwichClientTestBase.setUp(self)
+        self._httpd = HTTPGitServer(("localhost", 0), self.gitroot)
+        self.addCleanup(self._httpd.shutdown)
+        threading.Thread(target=self._httpd.serve_forever).start()
+        run_git_or_fail(['config', 'http.uploadpack', 'true'],
+                        cwd=self.dest)
+        run_git_or_fail(['config', 'http.receivepack', 'true'],
+                        cwd=self.dest)
+
+    def tearDown(self):
+        DulwichClientTestBase.tearDown(self)
+        CompatTestCase.tearDown(self)
+
+    def _client(self):
+        return client.HttpGitClient(self._httpd.get_url())
+
+    def _build_path(self, path):
+        return path

+ 1 - 8
dulwich/tests/compat/test_server.py

@@ -62,14 +62,6 @@ class GitServerTestCase(ServerTests, CompatTestCase):
 
 
     protocol = 'git'
     protocol = 'git'
 
 
-    def setUp(self):
-        ServerTests.setUp(self)
-        CompatTestCase.setUp(self)
-
-    def tearDown(self):
-        ServerTests.tearDown(self)
-        CompatTestCase.tearDown(self)
-
     def _handlers(self):
     def _handlers(self):
         return {'git-receive-pack': NoSideBand64kReceivePackHandler}
         return {'git-receive-pack': NoSideBand64kReceivePackHandler}
 
 
@@ -83,6 +75,7 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         dul_server = TCPGitServer(backend, 'localhost', 0,
         dul_server = TCPGitServer(backend, 'localhost', 0,
                                   handlers=self._handlers())
                                   handlers=self._handlers())
         self._check_server(dul_server)
         self._check_server(dul_server)
+        self.addCleanup(dul_server.shutdown)
         threading.Thread(target=dul_server.serve).start()
         threading.Thread(target=dul_server.serve).start()
         self._server = dul_server
         self._server = dul_server
         _, port = self._server.socket.getsockname()
         _, port = self._server.socket.getsockname()

+ 5 - 5
dulwich/tests/compat/test_utils.py

@@ -20,8 +20,8 @@
 """Tests for git compatibility utilities."""
 """Tests for git compatibility utilities."""
 
 
 from dulwich.tests import (
 from dulwich.tests import (
+    SkipTest,
     TestCase,
     TestCase,
-    TestSkipped,
     )
     )
 from dulwich.tests.compat import utils
 from dulwich.tests.compat import utils
 
 
@@ -61,11 +61,11 @@ class GitVersionTests(TestCase):
     def assertRequireSucceeds(self, required_version):
     def assertRequireSucceeds(self, required_version):
         try:
         try:
             utils.require_git_version(required_version)
             utils.require_git_version(required_version)
-        except TestSkipped:
+        except SkipTest:
             self.fail()
             self.fail()
 
 
     def assertRequireFails(self, required_version):
     def assertRequireFails(self, required_version):
-        self.assertRaises(TestSkipped, utils.require_git_version,
+        self.assertRaises(SkipTest, utils.require_git_version,
                           required_version)
                           required_version)
 
 
     def test_require_git_version(self):
     def test_require_git_version(self):
@@ -87,6 +87,6 @@ class GitVersionTests(TestCase):
             self.assertRequireSucceeds((1, 7, 0, 2))
             self.assertRequireSucceeds((1, 7, 0, 2))
             self.assertRequireFails((1, 7, 0, 3))
             self.assertRequireFails((1, 7, 0, 3))
             self.assertRequireFails((1, 7, 1))
             self.assertRequireFails((1, 7, 1))
-        except TestSkipped, e:
-            # This test is designed to catch all TestSkipped exceptions.
+        except SkipTest, e:
+            # This test is designed to catch all SkipTest exceptions.
             self.fail('Test unexpectedly skipped: %s' % e)
             self.fail('Test unexpectedly skipped: %s' % e)

+ 3 - 18
dulwich/tests/compat/test_web.py

@@ -31,7 +31,7 @@ from dulwich.server import (
     DictBackend,
     DictBackend,
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
-    TestSkipped,
+    SkipTest,
     )
     )
 from dulwich.web import (
 from dulwich.web import (
     HTTPGitApplication,
     HTTPGitApplication,
@@ -77,6 +77,7 @@ class WebTests(ServerTests):
         dul_server = simple_server.make_server(
         dul_server = simple_server.make_server(
           'localhost', 0, app, server_class=WSGIServer,
           'localhost', 0, app, server_class=WSGIServer,
           handler_class=HTTPGitRequestHandler)
           handler_class=HTTPGitRequestHandler)
+        self.addCleanup(dul_server.shutdown)
         threading.Thread(target=dul_server.serve_forever).start()
         threading.Thread(target=dul_server.serve_forever).start()
         self._server = dul_server
         self._server = dul_server
         _, port = dul_server.socket.getsockname()
         _, port = dul_server.socket.getsockname()
@@ -91,14 +92,6 @@ class SmartWebTestCase(WebTests, CompatTestCase):
 
 
     min_git_version = (1, 6, 6)
     min_git_version = (1, 6, 6)
 
 
-    def setUp(self):
-        WebTests.setUp(self)
-        CompatTestCase.setUp(self)
-
-    def tearDown(self):
-        WebTests.tearDown(self)
-        CompatTestCase.tearDown(self)
-
     def _handlers(self):
     def _handlers(self):
         return {'git-receive-pack': NoSideBand64kReceivePackHandler}
         return {'git-receive-pack': NoSideBand64kReceivePackHandler}
 
 
@@ -131,17 +124,9 @@ class SmartWebSideBand64kTestCase(SmartWebTestCase):
 class DumbWebTestCase(WebTests, CompatTestCase):
 class DumbWebTestCase(WebTests, CompatTestCase):
     """Test cases for dumb HTTP server."""
     """Test cases for dumb HTTP server."""
 
 
-    def setUp(self):
-        WebTests.setUp(self)
-        CompatTestCase.setUp(self)
-
-    def tearDown(self):
-        WebTests.tearDown(self)
-        CompatTestCase.tearDown(self)
-
     def _make_app(self, backend):
     def _make_app(self, backend):
         return HTTPGitApplication(backend, dumb=True)
         return HTTPGitApplication(backend, dumb=True)
 
 
     def test_push_to_dulwich(self):
     def test_push_to_dulwich(self):
         # Note: remove this if dumb pushing is supported
         # Note: remove this if dumb pushing is supported
-        raise TestSkipped('Dumb web pushing not supported.')
+        raise SkipTest('Dumb web pushing not supported.')

+ 6 - 6
dulwich/tests/compat/utils.py

@@ -30,8 +30,8 @@ from dulwich.repo import Repo
 from dulwich.protocol import TCP_GIT_PORT
 from dulwich.protocol import TCP_GIT_PORT
 
 
 from dulwich.tests import (
 from dulwich.tests import (
+    SkipTest,
     TestCase,
     TestCase,
-    TestSkipped,
     )
     )
 
 
 _DEFAULT_GIT = 'git'
 _DEFAULT_GIT = 'git'
@@ -77,12 +77,12 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     :param git_path: Path to the git executable; defaults to the version in
     :param git_path: Path to the git executable; defaults to the version in
         the system path.
         the system path.
     :raise ValueError: if the required version tuple has too many parts.
     :raise ValueError: if the required version tuple has too many parts.
-    :raise TestSkipped: if no suitable git version was found at the given path.
+    :raise SkipTest: if no suitable git version was found at the given path.
     """
     """
     found_version = git_version(git_path=git_path)
     found_version = git_version(git_path=git_path)
     if found_version is None:
     if found_version is None:
-        raise TestSkipped('Test requires git >= %s, but c git not found' %
-                         (required_version, ))
+        raise SkipTest('Test requires git >= %s, but c git not found' %
+                       (required_version, ))
 
 
     if len(required_version) > _VERSION_LEN:
     if len(required_version) > _VERSION_LEN:
         raise ValueError('Invalid version tuple %s, expected %i parts' %
         raise ValueError('Invalid version tuple %s, expected %i parts' %
@@ -96,8 +96,8 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     if found_version < required_version:
     if found_version < required_version:
         required_version = '.'.join(map(str, required_version))
         required_version = '.'.join(map(str, required_version))
         found_version = '.'.join(map(str, found_version))
         found_version = '.'.join(map(str, found_version))
-        raise TestSkipped('Test requires git >= %s, found %s' %
-                         (required_version, found_version))
+        raise SkipTest('Test requires git >= %s, found %s' %
+                       (required_version, found_version))
 
 
 
 
 def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
 def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,

+ 31 - 5
dulwich/tests/test_client.py

@@ -19,10 +19,13 @@
 from cStringIO import StringIO
 from cStringIO import StringIO
 
 
 from dulwich.client import (
 from dulwich.client import (
-    GitClient,
+    TraditionalGitClient,
     TCPGitClient,
     TCPGitClient,
     SubprocessGitClient,
     SubprocessGitClient,
     SSHGitClient,
     SSHGitClient,
+    ReportStatusParser,
+    SendPackError,
+    UpdateRefsError,
     get_transport_and_path,
     get_transport_and_path,
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
@@ -34,13 +37,13 @@ from dulwich.protocol import (
     )
     )
 
 
 
 
-class DummyClient(GitClient):
+class DummyClient(TraditionalGitClient):
 
 
     def __init__(self, can_read, read, write):
     def __init__(self, can_read, read, write):
         self.can_read = can_read
         self.can_read = can_read
         self.read = read
         self.read = read
         self.write = write
         self.write = write
-        GitClient.__init__(self)
+        TraditionalGitClient.__init__(self)
 
 
     def _connect(self, service, path):
     def _connect(self, service, path):
         return Protocol(self.read, self.write), self.can_read
         return Protocol(self.read, self.write), self.can_read
@@ -58,9 +61,9 @@ class GitClientTests(TestCase):
 
 
     def test_caps(self):
     def test_caps(self):
         self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
         self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
-                               'thin-pack']),
+                               'thin-pack', 'multi_ack_detailed']),
                           set(self.client._fetch_capabilities))
                           set(self.client._fetch_capabilities))
-        self.assertEquals(set(['ofs-delta', 'report-status']),
+        self.assertEquals(set(['ofs-delta', 'report-status', 'side-band-64k']),
                           set(self.client._send_capabilities))
                           set(self.client._send_capabilities))
 
 
     def test_fetch_pack_none(self):
     def test_fetch_pack_none(self):
@@ -151,3 +154,26 @@ class SSHGitClientTests(TestCase):
         self.assertEquals('/usr/lib/git/git-upload-pack',
         self.assertEquals('/usr/lib/git/git-upload-pack',
             self.client._get_cmd_path('upload-pack'))
             self.client._get_cmd_path('upload-pack'))
 
 
+
+class ReportStatusParserTests(TestCase):
+
+    def test_invalid_pack(self):
+        parser = ReportStatusParser()
+        parser.handle_packet("unpack error - foo bar")
+        parser.handle_packet("ok refs/foo/bar")
+        parser.handle_packet(None)
+        self.assertRaises(SendPackError, parser.check)
+
+    def test_update_refs_error(self):
+        parser = ReportStatusParser()
+        parser.handle_packet("unpack ok")
+        parser.handle_packet("ng refs/foo/bar need to pull")
+        parser.handle_packet(None)
+        self.assertRaises(UpdateRefsError, parser.check)
+
+    def test_ok(self):
+        parser = ReportStatusParser()
+        parser.handle_packet("unpack ok")
+        parser.handle_packet("ok refs/foo/bar")
+        parser.handle_packet(None)
+        parser.check()

+ 3 - 3
dulwich/tests/test_fastexport.py

@@ -32,8 +32,8 @@ from dulwich.repo import (
     MemoryRepo,
     MemoryRepo,
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
+    SkipTest,
     TestCase,
     TestCase,
-    TestSkipped,
     )
     )
 
 
 
 
@@ -47,7 +47,7 @@ class GitFastExporterTests(TestCase):
         try:
         try:
             from dulwich.fastexport import GitFastExporter
             from dulwich.fastexport import GitFastExporter
         except ImportError:
         except ImportError:
-            raise TestSkipped("python-fastimport not available")
+            raise SkipTest("python-fastimport not available")
         self.fastexporter = GitFastExporter(self.stream, self.store)
         self.fastexporter = GitFastExporter(self.stream, self.store)
 
 
     def test_emit_blob(self):
     def test_emit_blob(self):
@@ -93,7 +93,7 @@ class GitImportProcessorTests(TestCase):
         try:
         try:
             from dulwich.fastexport import GitImportProcessor
             from dulwich.fastexport import GitImportProcessor
         except ImportError:
         except ImportError:
-            raise TestSkipped("python-fastimport not available")
+            raise SkipTest("python-fastimport not available")
         self.processor = GitImportProcessor(self.repo)
         self.processor = GitImportProcessor(self.repo)
 
 
     def test_commit_handler(self):
     def test_commit_handler(self):

+ 2 - 2
dulwich/tests/test_file.py

@@ -24,8 +24,8 @@ import tempfile
 
 
 from dulwich.file import GitFile, fancy_rename
 from dulwich.file import GitFile, fancy_rename
 from dulwich.tests import (
 from dulwich.tests import (
+    SkipTest,
     TestCase,
     TestCase,
-    TestSkipped,
     )
     )
 
 
 
 
@@ -70,7 +70,7 @@ class FancyRenameTests(TestCase):
 
 
     def test_dest_opened(self):
     def test_dest_opened(self):
         if sys.platform != "win32":
         if sys.platform != "win32":
-            raise TestSkipped("platform allows overwriting open files")
+            raise SkipTest("platform allows overwriting open files")
         self.create(self.bar, 'bar contents')
         self.create(self.bar, 'bar contents')
         dest_f = open(self.bar, 'rb')
         dest_f = open(self.bar, 'rb')
         self.assertRaises(OSError, fancy_rename, self.foo, self.bar)
         self.assertRaises(OSError, fancy_rename, self.foo, self.bar)

+ 22 - 1
dulwich/tests/test_object_store.py

@@ -221,12 +221,33 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.store_dir = tempfile.mkdtemp()
         self.store_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, self.store_dir)
         self.store = DiskObjectStore.init(self.store_dir)
         self.store = DiskObjectStore.init(self.store_dir)
 
 
     def tearDown(self):
     def tearDown(self):
         TestCase.tearDown(self)
         TestCase.tearDown(self)
         PackBasedObjectStoreTests.tearDown(self)
         PackBasedObjectStoreTests.tearDown(self)
-        shutil.rmtree(self.store_dir)
+
+    def test_alternates(self):
+        alternate_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, alternate_dir)
+        alternate_store = DiskObjectStore(alternate_dir)
+        b2 = make_object(Blob, data="yummy data")
+        alternate_store.add_object(b2)
+        store = DiskObjectStore(self.store_dir)
+        self.assertRaises(KeyError, store.__getitem__, b2.id)
+        store.add_alternate_path(alternate_dir)
+        self.assertEquals(b2, store[b2.id])
+
+    def test_add_alternate_path(self):
+        store = DiskObjectStore(self.store_dir)
+        self.assertEquals([], store._read_alternate_paths())
+        store.add_alternate_path("/foo/path")
+        self.assertEquals(["/foo/path"], store._read_alternate_paths())
+        store.add_alternate_path("/bar/path")
+        self.assertEquals(
+            ["/foo/path", "/bar/path"],
+            store._read_alternate_paths())
 
 
     def test_pack_dir(self):
     def test_pack_dir(self):
         o = DiskObjectStore(self.store_dir)
         o = DiskObjectStore(self.store_dir)

+ 2 - 4
dulwich/tests/test_pack.py

@@ -91,10 +91,7 @@ class PackTests(TestCase):
     def setUp(self):
     def setUp(self):
         super(PackTests, self).setUp()
         super(PackTests, self).setUp()
         self.tempdir = tempfile.mkdtemp()
         self.tempdir = tempfile.mkdtemp()
-
-    def tearDown(self):
-        shutil.rmtree(self.tempdir)
-        super(PackTests, self).tearDown()
+        self.addCleanup(shutil.rmtree, self.tempdir)
 
 
     datadir = os.path.abspath(os.path.join(os.path.dirname(__file__),
     datadir = os.path.abspath(os.path.join(os.path.dirname(__file__),
         'data/packs'))
         'data/packs'))
@@ -736,6 +733,7 @@ class TestPackIterator(DeltaChainIterator):
 class DeltaChainIteratorTests(TestCase):
 class DeltaChainIteratorTests(TestCase):
 
 
     def setUp(self):
     def setUp(self):
+        super(DeltaChainIteratorTests, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
         self.fetched = set()
         self.fetched = set()
 
 

+ 2 - 2
dulwich/tests/test_patch.py

@@ -37,8 +37,8 @@ from dulwich.patch import (
     write_tree_diff,
     write_tree_diff,
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
+    SkipTest,
     TestCase,
     TestCase,
-    TestSkipped,
     )
     )
 
 
 
 
@@ -164,7 +164,7 @@ From: Jelmer Vernooy <jelmer@debian.org>
         self.assertEquals(None, version)
         self.assertEquals(None, version)
 
 
     def test_extract_mercurial(self):
     def test_extract_mercurial(self):
-        raise TestSkipped("git_am_patch_split doesn't handle Mercurial patches properly yet")
+        raise SkipTest("git_am_patch_split doesn't handle Mercurial patches properly yet")
         expected_diff = """diff --git a/dulwich/tests/test_patch.py b/dulwich/tests/test_patch.py
         expected_diff = """diff --git a/dulwich/tests/test_patch.py b/dulwich/tests/test_patch.py
 --- a/dulwich/tests/test_patch.py
 --- a/dulwich/tests/test_patch.py
 +++ b/dulwich/tests/test_patch.py
 +++ b/dulwich/tests/test_patch.py

+ 27 - 0
dulwich/tests/test_protocol.py

@@ -25,6 +25,7 @@ from dulwich.errors import (
     HangupException,
     HangupException,
     )
     )
 from dulwich.protocol import (
 from dulwich.protocol import (
+    PktLineParser,
     Protocol,
     Protocol,
     ReceivableProtocol,
     ReceivableProtocol,
     extract_capabilities,
     extract_capabilities,
@@ -280,3 +281,29 @@ class BufferedPktLineWriterTests(TestCase):
         self._writer.write('z')
         self._writer.write('z')
         self._writer.flush()
         self._writer.flush()
         self.assertOutputEquals('0005z')
         self.assertOutputEquals('0005z')
+
+
+class PktLineParserTests(TestCase):
+
+    def test_none(self):
+        pktlines = []
+        parser = PktLineParser(pktlines.append)
+        parser.parse("0000")
+        self.assertEquals(pktlines, [None])
+        self.assertEquals("", parser.get_tail())
+
+    def test_small_fragments(self):
+        pktlines = []
+        parser = PktLineParser(pktlines.append)
+        parser.parse("00")
+        parser.parse("05")
+        parser.parse("z0000")
+        self.assertEquals(pktlines, ["z", None])
+        self.assertEquals("", parser.get_tail())
+
+    def test_multiple_packets(self):
+        pktlines = []
+        parser = PktLineParser(pktlines.append)
+        parser.parse("0005z0006aba")
+        self.assertEquals(pktlines, ["z", "ab"])
+        self.assertEquals("a", parser.get_tail())

+ 82 - 4
dulwich/tests/test_repository.py

@@ -113,6 +113,10 @@ class RepositoryTests(TestCase):
         self.assertEqual(r.ref('refs/heads/master'),
         self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
 
 
+    def test_iter(self):
+        r = self._repo = open_repo('a.git')
+        self.assertRaises(NotImplementedError, r.__iter__)
+
     def test_setitem(self):
     def test_setitem(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
         r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
@@ -247,17 +251,45 @@ class RepositoryTests(TestCase):
         self.addCleanup(warnings.resetwarnings)
         self.addCleanup(warnings.resetwarnings)
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
 
 
+    def test_get_walker(self):
+        r = self._repo = open_repo('a.git')
+        # include defaults to [r.head()]
+        self.assertEqual([e.commit.id for e in r.get_walker()],
+                         [r.head(), '2a72d929692c41d8554c07f6301757ba18a65d91'])
+        self.assertEqual(
+            [e.commit.id for e in r.get_walker(['2a72d929692c41d8554c07f6301757ba18a65d91'])],
+            ['2a72d929692c41d8554c07f6301757ba18a65d91'])
+
     def test_linear_history(self):
     def test_linear_history(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
+        warnings.simplefilter("ignore", DeprecationWarning)
+        self.addCleanup(warnings.resetwarnings)
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
 
 
+    def test_clone(self):
+        r = self._repo = open_repo('a.git')
+        tmp_dir = tempfile.mkdtemp()
+        self.addCleanup(shutil.rmtree, tmp_dir)
+        t = r.clone(tmp_dir, mkdir=False)
+        self.assertEqual({
+            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            'refs/remotes/origin/master':
+                'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+            'refs/tags/mytag-packed':
+                'b0931cadc54336e78a1d980420e3268903b57a50',
+            }, t.refs.as_dict())
+        shas = [e.commit.id for e in r.get_walker()]
+        self.assertEqual(shas, [t.head(),
+                         '2a72d929692c41d8554c07f6301757ba18a65d91'])
+
     def test_merge_history(self):
     def test_merge_history(self):
         r = self._repo = open_repo('simple_merge.git')
         r = self._repo = open_repo('simple_merge.git')
-        history = r.revision_history(r.head())
-        shas = [c.sha().hexdigest() for c in history]
+        shas = [e.commit.id for e in r.get_walker()]
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
                                 'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
                                 'ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
                                 '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                 '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
@@ -266,14 +298,15 @@ class RepositoryTests(TestCase):
 
 
     def test_revision_history_missing_commit(self):
     def test_revision_history_missing_commit(self):
         r = self._repo = open_repo('simple_merge.git')
         r = self._repo = open_repo('simple_merge.git')
+        warnings.simplefilter("ignore", DeprecationWarning)
+        self.addCleanup(warnings.resetwarnings)
         self.assertRaises(errors.MissingCommitError, r.revision_history,
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
                           missing_sha)
 
 
     def test_out_of_order_merge(self):
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
         """Test that revision history is ordered by date, not parent order."""
         r = self._repo = open_repo('ooo_merge.git')
         r = self._repo = open_repo('ooo_merge.git')
-        history = r.revision_history(r.head())
-        shas = [c.sha().hexdigest() for c in history]
+        shas = [e.commit.id for e in r.get_walker()]
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
                                 'f507291b64138b875c28e03469025b1ea20bc614',
                                 'f507291b64138b875c28e03469025b1ea20bc614',
                                 'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                 'fb5b0425c7ce46959bec94d54b9a157645e114f5',
@@ -445,6 +478,50 @@ class BuildRepoTests(TestCase):
         self.assertEqual(r[self._root_commit].tree, new_commit.tree)
         self.assertEqual(r[self._root_commit].tree, new_commit.tree)
         self.assertEqual('failed commit', new_commit.message)
         self.assertEqual('failed commit', new_commit.message)
 
 
+    def test_commit_branch(self):
+        r = self._repo
+
+        commit_sha = r.do_commit('commit to branch',
+             committer='Test Committer <test@nodomain.com>',
+             author='Test Author <test@nodomain.com>',
+             commit_timestamp=12395, commit_timezone=0,
+             author_timestamp=12395, author_timezone=0,
+             ref="refs/heads/new_branch")
+        self.assertEqual(self._root_commit, r["HEAD"].id)
+        self.assertEqual(commit_sha, r["refs/heads/new_branch"].id)
+        self.assertEqual([], r[commit_sha].parents)
+        self.assertTrue("refs/heads/new_branch" in r)
+
+        new_branch_head = commit_sha
+
+        commit_sha = r.do_commit('commit to branch 2',
+             committer='Test Committer <test@nodomain.com>',
+             author='Test Author <test@nodomain.com>',
+             commit_timestamp=12395, commit_timezone=0,
+             author_timestamp=12395, author_timezone=0,
+             ref="refs/heads/new_branch")
+        self.assertEqual(self._root_commit, r["HEAD"].id)
+        self.assertEqual(commit_sha, r["refs/heads/new_branch"].id)
+        self.assertEqual([new_branch_head], r[commit_sha].parents)
+
+    def test_commit_merge_heads(self):
+        r = self._repo
+        merge_1 = r.do_commit('commit to branch 2',
+             committer='Test Committer <test@nodomain.com>',
+             author='Test Author <test@nodomain.com>',
+             commit_timestamp=12395, commit_timezone=0,
+             author_timestamp=12395, author_timezone=0,
+             ref="refs/heads/new_branch")
+        commit_sha = r.do_commit('commit with merge',
+             committer='Test Committer <test@nodomain.com>',
+             author='Test Author <test@nodomain.com>',
+             commit_timestamp=12395, commit_timezone=0,
+             author_timestamp=12395, author_timezone=0,
+             merge_heads=[merge_1])
+        self.assertEquals(
+            [self._root_commit, merge_1],
+            r[commit_sha].parents)
+
     def test_stage_deleted(self):
     def test_stage_deleted(self):
         r = self._repo
         r = self._repo
         os.remove(os.path.join(r.path, 'a'))
         os.remove(os.path.join(r.path, 'a'))
@@ -611,6 +688,7 @@ class RefsContainerTests(object):
 
 
     def test_check_refname(self):
     def test_check_refname(self):
         self._refs._check_refname('HEAD')
         self._refs._check_refname('HEAD')
+        self._refs._check_refname('refs/stash')
         self._refs._check_refname('refs/heads/foo')
         self._refs._check_refname('refs/heads/foo')
 
 
         self.assertRaises(errors.RefFormatError, self._refs._check_refname,
         self.assertRaises(errors.RefFormatError, self._refs._check_refname,

+ 1 - 0
dulwich/tests/test_utils.py

@@ -37,6 +37,7 @@ from utils import (
 class BuildCommitGraphTest(TestCase):
 class BuildCommitGraphTest(TestCase):
 
 
     def setUp(self):
     def setUp(self):
+        super(BuildCommitGraphTest, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
 
 
     def test_linear(self):
     def test_linear(self):

+ 1 - 0
dulwich/tests/test_walk.py

@@ -74,6 +74,7 @@ class TestWalkEntry(object):
 class WalkerTest(TestCase):
 class WalkerTest(TestCase):
 
 
     def setUp(self):
     def setUp(self):
+        super(WalkerTest, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
 
 
     def make_commits(self, commit_spec, **kwargs):
     def make_commits(self, commit_spec, **kwargs):

+ 3 - 4
dulwich/tests/utils.py

@@ -20,7 +20,6 @@
 """Utility functions common to Dulwich tests."""
 """Utility functions common to Dulwich tests."""
 
 
 
 
-from cStringIO import StringIO
 import datetime
 import datetime
 import os
 import os
 import shutil
 import shutil
@@ -47,7 +46,7 @@ from dulwich.pack import (
     )
     )
 from dulwich.repo import Repo
 from dulwich.repo import Repo
 from dulwich.tests import (
 from dulwich.tests import (
-    TestSkipped,
+    SkipTest,
     )
     )
 
 
 # Plain files are very frequently used in tests, so let the mode be very short.
 # Plain files are very frequently used in tests, so let the mode be very short.
@@ -142,7 +141,7 @@ def ext_functest_builder(method, func):
 
 
     This is intended to generate test methods that test both a pure-Python
     This is intended to generate test methods that test both a pure-Python
     version and an extension version using common test code. The extension test
     version and an extension version using common test code. The extension test
-    will raise TestSkipped if the extension is not found.
+    will raise SkipTest if the extension is not found.
 
 
     Sample usage:
     Sample usage:
 
 
@@ -160,7 +159,7 @@ def ext_functest_builder(method, func):
 
 
     def do_test(self):
     def do_test(self):
         if not isinstance(func, types.BuiltinFunctionType):
         if not isinstance(func, types.BuiltinFunctionType):
-            raise TestSkipped("%s extension not found", func.func_name)
+            raise SkipTest("%s extension not found", func.func_name)
         method(self, func)
         method(self, func)
 
 
     return do_test
     return do_test

+ 2 - 0
dulwich/walk.py

@@ -221,6 +221,8 @@ class Walker(object):
             iterator protocol. The constructor takes a single argument, the
             iterator protocol. The constructor takes a single argument, the
             Walker.
             Walker.
         """
         """
+        # Note: when adding arguments to this method, please also update
+        # dulwich.repo.BaseRepo.get_walker
         if order not in ALL_ORDERS:
         if order not in ALL_ORDERS:
             raise ValueError('Unknown walk order %s' % order)
             raise ValueError('Unknown walk order %s' % order)
         self.store = store
         self.store = store

+ 4 - 4
setup.py

@@ -1,6 +1,6 @@
 #!/usr/bin/python
 #!/usr/bin/python
 # Setup file for dulwich
 # Setup file for dulwich
-# Copyright (C) 2008-2010 Jelmer Vernooij <jelmer@samba.org>
+# Copyright (C) 2008-2011 Jelmer Vernooij <jelmer@samba.org>
 
 
 try:
 try:
     from setuptools import setup, Extension
     from setuptools import setup, Extension
@@ -10,7 +10,7 @@ except ImportError:
     has_setuptools = False
     has_setuptools = False
 from distutils.core import Distribution
 from distutils.core import Distribution
 
 
-dulwich_version_string = '0.8.0'
+dulwich_version_string = '0.8.1'
 
 
 include_dirs = []
 include_dirs = []
 # Windows MSVC support
 # Windows MSVC support
@@ -27,11 +27,11 @@ class DulwichDistribution(Distribution):
             return True
             return True
 
 
     def has_ext_modules(self):
     def has_ext_modules(self):
-        return not self.pure
+        return not self.pure and not '__pypy__' in sys.modules
 
 
     global_options = Distribution.global_options + [
     global_options = Distribution.global_options + [
         ('pure', None, 
         ('pure', None, 
-            "use pure (slower) Python code instead of C extensions")]
+            "use pure Python code instead of C extensions (slower on CPython)")]
 
 
     pure = False
     pure = False