浏览代码

Import upstream version 0.6.1+bzr685

Jelmer Vernooij 14 年之前
父节点
当前提交
6643b18378

+ 0 - 7
.bzrignore

@@ -1,7 +0,0 @@
-_trial_temp
-build
-MANIFEST
-dist
-apidocs
-*,cover
-.testrepository

+ 94 - 0
NEWS

@@ -1,3 +1,85 @@
+0.6.2	UNRELEASED
+
+ BUG FIXES
+
+  * HTTP server correctly handles empty CONTENT_LENGTH. (Dave Borowitz)
+
+  * Don't error when creating GitFiles with the default mode. (Dave Borowitz)
+
+  * ThinPackData.from_file now works with resolve_ext_ref callback.
+    (Dave Borowitz)
+
+  * Provide strnlen() on mingw32 which doesn't have it. (Hans Kolek)
+
+ FEATURES
+
+  * Use slots for core objects to save memory. (Jelmer Vernooij)
+
+  * Web server supports streaming progress/pack output. (Dave Borowitz)
+
+  * New public function dulwich.pack.write_pack_header. (Dave Borowitz)
+
+  * Distinguish between missing files and read errors in HTTP server.
+    (Dave Borowitz)
+
+  * Initial work on support for fastimport using python-fastimport.
+    (Jelmer Vernooij)
+
+  * New dulwich.pack.MemoryPackIndex class. (Jelmer Vernooij)
+
+  * Delegate SHA peeling to the object store.  (Dave Borowitz)
+
+ TESTS
+
+  * Use GitFile when modifying packed-refs in tests. (Dave Borowitz)
+
+  * New tests in test_web with better coverage and fewer ad-hoc mocks.
+    (Dave Borowitz)
+
+  * Standardize quote delimiters in test_protocol. (Dave Borowitz)
+
+  * Fix use when testtools is installed. (Jelmer Vernooij)
+
+  * Add trivial test for write_pack_header. (Jelmer Vernooij)
+
+  * Refactor some of dulwich.tests.compat.server_utils. (Dave Borowitz)
+
+  * Allow overwriting id property of objects in test utils. (Dave Borowitz)
+
+  * Use real in-memory objects rather than stubs for server tests.
+    (Dave Borowitz)
+
+  * Clean up MissingObjectFinder. (Dave Borowitz)
+
+ API CHANGES
+
+  * ObjectStore.iter_tree_contents now walks contents in depth-first, sorted
+    order. (Dave Borowitz)
+
+  * ObjectStore.iter_tree_contents can optionally yield tree objects as well.
+    (Dave Borowitz)
+
+  * Add side-band-64k support to ReceivePackHandler. (Dave Borowitz)
+
+  * Change server capabilities methods to classmethods. (Dave Borowitz)
+
+  * Tweak server handler injection. (Dave Borowitz)
+
+  * PackIndex1 and PackIndex2 now subclass FilePackIndex, which is 
+    itself a subclass of PackIndex. (Jelmer Vernooij)
+
+ DOCUMENTATION
+
+  * Add docstrings for various functions in dulwich.objects. (Jelmer Vernooij)
+
+  * Clean up docstrings in dulwich.protocol. (Dave Borowitz)
+
+  * Explicitly specify allowed protocol commands to
+    ProtocolGraphWalker.read_proto_line.  (Dave Borowitz)
+
+  * Add utility functions to DictRefsContainer. (Dave Borowitz)
+
+
 0.6.1	2010-07-22
 
  BUG FIXES
@@ -31,10 +113,18 @@
 
   * Quiet logging output from web tests. (Dave Borowitz)
 
+  * More flexible version checking for compat tests. (Dave Borowitz)
+
+  * Compat tests for servers with and without side-band-64k. (Dave Borowitz)
+
  CLEANUP
 
   * Clean up file headers. (Dave Borowitz)
 
+ TESTS
+
+  * Use GitFile when modifying packed-refs in tests. (Dave Borowitz)
+
  API CHANGES
 
   * dulwich.pack.write_pack_index_v{1,2} now take a file-like object
@@ -45,6 +135,10 @@
 
   * Move reference WSGI handler to web.py. (Dave Borowitz)
 
+  * Factor out _report_status in ReceivePackHandler. (Dave Borowitz)
+
+  * Factor out a function to convert a line to a pkt-line. (Dave Borowitz)
+
 
 0.6.0	2010-05-22
 

+ 1 - 1
dulwich/__init__.py

@@ -27,4 +27,4 @@ import protocol
 import repo
 import server
 
-__version__ = (0, 6, 1)
+__version__ = (0, 6, 2)

+ 8 - 0
dulwich/_objects.c

@@ -25,6 +25,14 @@
 typedef int Py_ssize_t;
 #endif
 
+#if defined(__MINGW32_VERSION) || defined(__APPLE__)
+size_t strnlen(char *text, size_t maxlen)
+{
+	const char *last = memchr(text, '\0', maxlen);
+	return last ? (size_t) (last - text) : maxlen;
+}
+#endif
+
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
 
 static PyObject *sha_to_pyhex(const unsigned char *sha)

+ 11 - 0
dulwich/errors.py

@@ -137,6 +137,17 @@ class HangupException(GitProtocolError):
             "The remote server unexpectedly closed the connection.")
 
 
+class UnexpectedCommandError(GitProtocolError):
+    """Unexpected command received in a proto line."""
+
+    def __init__(self, command):
+        if command is None:
+            command = 'flush-pkt'
+        else:
+            command = 'command %s' % command
+        GitProtocolError.__init__(self, 'Protocol got unexpected %s' % command)
+
+
 class FileFormatException(Exception):
     """Base class for exceptions relating to reading git file formats."""
 

+ 226 - 39
dulwich/fastexport.py

@@ -20,13 +20,30 @@
 
 """Fast export/import functionality."""
 
+from dulwich.index import (
+    commit_tree,
+    )
 from dulwich.objects import (
-    format_timezone,
+    Blob,
+    Commit,
+    Tag,
+    parse_timezone,
+    )
+from fastimport import (
+    commands,
+    errors as fastimport_errors,
+    processor,
     )
 
 import stat
 
-class FastExporter(object):
+
+def split_email(text):
+    (name, email) = text.rsplit(" <", 1)
+    return (name, email.rstrip(">"))
+
+
+class GitFastExporter(object):
     """Generate a fast-export output stream for Git objects."""
 
     def __init__(self, outf, store):
@@ -35,47 +52,217 @@ class FastExporter(object):
         self.markers = {}
         self._marker_idx = 0
 
+    def print_cmd(self, cmd):
+        self.outf.write("%r\n" % cmd)
+
     def _allocate_marker(self):
         self._marker_idx+=1
-        return self._marker_idx
-
-    def _dump_blob(self, blob, marker):
-        self.outf.write("blob\nmark :%s\n" % marker)
-        self.outf.write("data %s\n" % blob.raw_length())
-        for chunk in blob.as_raw_chunks():
-            self.outf.write(chunk)
-        self.outf.write("\n")
-
-    def export_blob(self, blob):
-        i = self._allocate_marker()
-        self.markers[i] = blob.id
-        self._dump_blob(blob, i)
-        return i
-
-    def _dump_commit(self, commit, marker, ref, file_changes):
-        self.outf.write("commit %s\n" % ref)
-        self.outf.write("mark :%s\n" % marker)
-        self.outf.write("author %s %s %s\n" % (commit.author,
-            commit.author_time, format_timezone(commit.author_timezone)))
-        self.outf.write("committer %s %s %s\n" % (commit.committer,
-            commit.commit_time, format_timezone(commit.commit_timezone)))
-        self.outf.write("data %s\n" % len(commit.message))
-        self.outf.write(commit.message)
-        self.outf.write("\n")
-        self.outf.write('\n'.join(file_changes))
-        self.outf.write("\n\n")
-
-    def export_commit(self, commit, ref, base_tree=None):
-        file_changes = []
+        return str(self._marker_idx)
+
+    def _export_blob(self, blob):
+        marker = self._allocate_marker()
+        self.markers[marker] = blob.id
+        return (commands.BlobCommand(marker, blob.data), marker)
+
+    def emit_blob(self, blob):
+        (cmd, marker) = self._export_blob(blob)
+        self.print_cmd(cmd)
+        return marker
+
+    def _iter_files(self, base_tree, new_tree):
         for (old_path, new_path), (old_mode, new_mode), (old_hexsha, new_hexsha) in \
-                self.store.tree_changes(base_tree, commit.tree):
+                self.store.tree_changes(base_tree, new_tree):
             if new_path is None:
-                file_changes.append("D %s" % old_path)
+                yield commands.FileDeleteCommand(old_path)
                 continue
             if not stat.S_ISDIR(new_mode):
-                marker = self.export_blob(self.store[new_hexsha])
-            file_changes.append("M %o :%s %s" % (new_mode, marker, new_path))
+                blob = self.store[new_hexsha]
+                marker = self.emit_blob(blob)
+            if old_path != new_path and old_path is not None:
+                yield commands.FileRenameCommand(old_path, new_path)
+            if old_mode != new_mode or old_hexsha != new_hexsha:
+                yield commands.FileModifyCommand(new_path, new_mode, marker, None)
+
+    def _export_commit(self, commit, ref, base_tree=None):
+        file_cmds = list(self._iter_files(base_tree, commit.tree))
+        marker = self._allocate_marker()
+        if commit.parents:
+            from_ = commit.parents[0]
+            merges = commit.parents[1:]
+        else:
+            from_ = None
+            merges = []
+        author, author_email = split_email(commit.author)
+        committer, committer_email = split_email(commit.committer)
+        cmd = commands.CommitCommand(ref, marker,
+            (author, author_email, commit.author_time, commit.author_timezone),
+            (committer, committer_email, commit.commit_time, commit.commit_timezone),
+            commit.message, from_, merges, file_cmds)
+        return (cmd, marker)
+
+    def emit_commit(self, commit, ref, base_tree=None):
+        cmd, marker = self._export_commit(commit, ref, base_tree)
+        self.print_cmd(cmd)
+        return marker
+
+
+class FastImporter(object):
+    """Class for importing fastimport streams.
+
+    Please note that this is mostly a stub implementation at the moment,
+    doing the bare minimum.
+    """
+
+    def __init__(self, repo):
+        self.repo = repo
+
+    def _parse_person(self, line):
+        (name, timestr, timezonestr) = line.rsplit(" ", 2)
+        return name, int(timestr), parse_timezone(timezonestr)[0]
+
+    def _read_blob(self, stream):
+        line = stream.readline()
+        if line.startswith("mark :"):
+            mark = line[len("mark :"):-1]
+            line = stream.readline()
+        else:
+            mark = None
+        if not line.startswith("data "):
+            raise ValueError("Blob without valid data line: %s" % line)
+        size = int(line[len("data "):])
+        o = Blob()
+        o.data = stream.read(size)
+        stream.readline()
+        self.repo.object_store.add_object(o)
+        return mark, o.id
+
+    def _read_commit(self, stream, contents, marks):
+        line = stream.readline()
+        if line.startswith("mark :"):
+            mark = line[len("mark :"):-1]
+            line = stream.readline()
+        else:
+            mark = None
+        o = Commit()
+        o.author = None
+        o.author_time = None
+        while line.startswith("author "):
+            (o.author, o.author_time, o.author_timezone) = \
+                    self._parse_person(line[len("author "):-1])
+            line = stream.readline()
+        while line.startswith("committer "):
+            (o.committer, o.commit_time, o.commit_timezone) = \
+                    self._parse_person(line[len("committer "):-1])
+            line = stream.readline()
+        if o.author is None:
+            o.author = o.committer
+        if o.author_time is None:
+            o.author_time = o.commit_time
+            o.author_timezone = o.commit_timezone
+        if not line.startswith("data "):
+            raise ValueError("Blob without valid data line: %s" % line)
+        size = int(line[len("data "):])
+        o.message = stream.read(size)
+        stream.readline()
+        line = stream.readline()[:-1]
+        while line:
+            if line.startswith("M "):
+                (kind, modestr, val, path) = line.split(" ")
+                if val[0] == ":":
+                    val = marks[val[1:]]
+                contents[path] = (int(modestr, 8), val)
+            else:
+                raise ValueError(line)
+            line = stream.readline()[:-1]
+        try:
+            o.parents = (self.repo.head(),)
+        except KeyError:
+            o.parents = ()
+        o.tree = commit_tree(self.repo.object_store,
+            ((path, hexsha, mode) for (path, (mode, hexsha)) in
+                contents.iteritems()))
+        self.repo.object_store.add_object(o)
+        return mark, o.id
+
+    def import_stream(self, stream):
+        """Import from a file-like object.
+
+        :param stream: File-like object to read a fastimport stream from.
+        :return: Dictionary with marks
+        """
+        contents = {}
+        marks = {}
+        while True:
+            line = stream.readline()
+            if not line:
+                break
+            line = line[:-1]
+            if line == "" or line[0] == "#":
+                continue
+            if line.startswith("blob"):
+                mark, hexsha = self._read_blob(stream)
+                if mark is not None:
+                    marks[mark] = hexsha
+            elif line.startswith("commit "):
+                ref = line[len("commit "):-1]
+                mark, hexsha = self._read_commit(stream, contents, marks)
+                if mark is not None:
+                    marks[mark] = hexsha
+                self.repo.refs["HEAD"] = self.repo.refs[ref] = hexsha
+            else:
+                raise ValueError("invalid command '%s'" % line)
+        return marks
+
+
+class GitImportProcessor(processor.ImportProcessor):
+    """An import processor that imports into a Git repository using Dulwich.
+
+    """
+
+    def __init__(self, repo, params=None, verbose=False, outf=None):
+        processor.ImportProcessor.__init__(self, params, verbose)
+        self.repo = repo
+        self.last_commit = None
+
+    def blob_handler(self, cmd):
+        """Process a BlobCommand."""
+        self.repo.object_store.add_object(Blob.from_string(cmd.data))
+
+    def checkpoint_handler(self, cmd):
+        """Process a CheckpointCommand."""
+        pass
+
+    def commit_handler(self, cmd):
+        """Process a CommitCommand."""
+        commit = Commit()
+        commit.author = cmd.author
+        commit.committer = cmd.committer
+        commit.message = cmd.message
+        commit.parents = []
+        if self.last_commit is not None:
+            commit.parents.append(self.last_commit)
+        commit.parents += cmd.merges
+        self.repo[cmd.ref] = commit.id
+        self.last_commit = commit.id
+
+    def progress_handler(self, cmd):
+        """Process a ProgressCommand."""
+        pass
+
+    def reset_handler(self, cmd):
+        """Process a ResetCommand."""
+        self.last_commit = cmd.from_
+        self.rep.refs[cmd.from_] = cmd.id
+
+    def tag_handler(self, cmd):
+        """Process a TagCommand."""
+        tag = Tag()
+        tag.tagger = cmd.tagger
+        tag.message = cmd.message
+        tag.name = cmd.tag
+        self.repo.add_object(tag)
+        self.repo.refs["refs/tags/" + tag.name] = tag.id
 
-        i = self._allocate_marker()
-        self._dump_commit(commit, i, ref, file_changes)
-        return i
+    def feature_handler(self, cmd):
+        """Process a FeatureCommand."""
+        raise fastimport_errors.UnknownFeature(cmd.feature_name)

+ 1 - 1
dulwich/file.py

@@ -60,7 +60,7 @@ def fancy_rename(oldname, newname):
     os.remove(tmpfile)
 
 
-def GitFile(filename, mode='r', bufsize=-1):
+def GitFile(filename, mode='rb', bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
 
     :return: a builtin file object or a _GitFile object

+ 43 - 17
dulwich/object_store.py

@@ -41,6 +41,7 @@ from dulwich.objects import (
     sha_to_hex,
     hex_to_filename,
     S_ISGITLINK,
+    object_class,
     )
 from dulwich.pack import (
     Pack,
@@ -175,21 +176,26 @@ class BaseObjectStore(object):
                     else:
                         todo.add((None, newhexsha, childpath))
 
-    def iter_tree_contents(self, tree):
-        """Yield (path, mode, hexsha) tuples for all non-Tree objects in a tree.
+    def iter_tree_contents(self, tree_id, include_trees=False):
+        """Iterate the contents of a tree and all subtrees.
 
-        :param tree: SHA1 of the root of the tree
+        Iteration is depth-first pre-order, as in e.g. os.walk.
+
+        :param tree_id: SHA1 of the tree.
+        :param include_trees: If True, include tree objects in the iteration.
+        :yield: Tuples of (path, mode, hexsha) for objects in a tree.
         """
-        todo = set([(tree, "")])
+        todo = [('', stat.S_IFDIR, tree_id)]
         while todo:
-            (tid, tpath) = todo.pop()
-            tree = self[tid]
-            for name, mode, hexsha in tree.iteritems():
-                path = posixpath.join(tpath, name)
-                if stat.S_ISDIR(mode):
-                    todo.add((hexsha, path))
-                else:
-                    yield path, mode, hexsha
+            path, mode, hexsha = todo.pop()
+            is_subtree = stat.S_ISDIR(mode)
+            if not is_subtree or include_trees:
+                yield path, mode, hexsha
+            if is_subtree:
+                entries = reversed(list(self[hexsha].iteritems()))
+                for name, entry_mode, entry_hexsha in entries:
+                    entry_path = posixpath.join(path, name)
+                    todo.append((entry_path, entry_mode, entry_hexsha))
 
     def find_missing_objects(self, haves, wants, progress=None,
                              get_tagged=None):
@@ -238,6 +244,21 @@ class BaseObjectStore(object):
         """
         return self.iter_shas(self.find_missing_objects(have, want, progress))
 
+    def peel_sha(self, sha):
+        """Peel all tags from a SHA.
+
+        :param sha: The object SHA to peel.
+        :return: The fully-peeled SHA1 of a tag object, after peeling all
+            intermediate tags; if the original ref does not point to a tag, this
+            will equal the original SHA1.
+        """
+        obj = self[sha]
+        obj_class = object_class(obj.type_name)
+        while obj_class is Tag:
+            obj_class, sha = obj.object
+            obj = self[sha]
+        return obj
+
 
 class PackBasedObjectStore(BaseObjectStore):
 
@@ -588,7 +609,7 @@ class ObjectImporter(object):
         raise NotImplementedError(self.add_object)
 
     def finish(self, object):
-        """Finish the imoprt and write objects to disk."""
+        """Finish the import and write objects to disk."""
         raise NotImplementedError(self.finish)
 
 
@@ -690,8 +711,10 @@ class MissingObjectFinder(object):
 
     def __init__(self, object_store, haves, wants, progress=None,
                  get_tagged=None):
-        self.sha_done = set(haves)
-        self.objects_to_send = set([(w, None, False) for w in wants if w not in haves])
+        haves = set(haves)
+        self.sha_done = haves
+        self.objects_to_send = set([(w, None, False) for w in wants
+                                    if w not in haves])
         self.object_store = object_store
         if progress is None:
             self.progress = lambda x: None
@@ -700,10 +723,13 @@ class MissingObjectFinder(object):
         self._tagged = get_tagged and get_tagged() or {}
 
     def add_todo(self, entries):
-        self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
+        self.objects_to_send.update([e for e in entries
+                                     if not e[0] in self.sha_done])
 
     def parse_tree(self, tree):
-        self.add_todo([(sha, name, not stat.S_ISDIR(mode)) for (mode, name, sha) in tree.entries() if not S_ISGITLINK(mode)])
+        self.add_todo([(sha, name, not stat.S_ISDIR(mode))
+                       for mode, name, sha in tree.entries()
+                       if not S_ISGITLINK(mode)])
 
     def parse_commit(self, commit):
         self.add_todo([(commit.tree, "", False)])

+ 32 - 0
dulwich/objects.py

@@ -157,6 +157,8 @@ def check_identity(identity, error_msg):
 class FixedSha(object):
     """SHA object that behaves like hashlib's but is given a fixed value."""
 
+    __slots__ = ('_hexsha', '_sha')
+
     def __init__(self, hexsha):
         self._hexsha = hexsha
         self._sha = hex_to_sha(hexsha)
@@ -171,6 +173,9 @@ class FixedSha(object):
 class ShaFile(object):
     """A git SHA file."""
 
+    __slots__ = ('_needs_parsing', '_chunked_text', '_file', '_path', 
+                 '_sha', '_needs_serialization', '_magic')
+
     @staticmethod
     def _parse_legacy_object_header(magic, f):
         """Parse a legacy object, creating it but not reading the file."""
@@ -474,6 +479,8 @@ class ShaFile(object):
 class Blob(ShaFile):
     """A Git Blob object."""
 
+    __slots__ = ()
+
     type_name = 'blob'
     type_num = 3
 
@@ -555,6 +562,10 @@ class Tag(ShaFile):
     type_name = 'tag'
     type_num = 4
 
+    __slots__ = ('_tag_timezone_neg_utc', '_name', '_object_sha', 
+                 '_object_class', '_tag_time', '_tag_timezone',
+                 '_tagger', '_message')
+
     def __init__(self):
         super(Tag, self).__init__()
         self._tag_timezone_neg_utc = False
@@ -740,6 +751,8 @@ class Tree(ShaFile):
     type_name = 'tree'
     type_num = 2
 
+    __slots__ = ('_entries')
+
     def __init__(self):
         super(Tree, self).__init__()
         self._entries = {}
@@ -865,6 +878,13 @@ class Tree(ShaFile):
 
 
 def parse_timezone(text):
+    """Parse a timezone text fragment (e.g. '+0100').
+
+    :param text: Text to parse.
+    :return: Tuple with timezone as seconds difference to UTC 
+        and a boolean indicating whether this was a UTC timezone
+        prefixed with a negative sign (-0000).
+    """
     offset = int(text)
     negative_utc = (offset == 0 and text[0] == '-')
     signum = (offset < 0) and -1 or 1
@@ -875,6 +895,12 @@ def parse_timezone(text):
 
 
 def format_timezone(offset, negative_utc=False):
+    """Format a timezone for Git serialization.
+
+    :param offset: Timezone offset as seconds difference to UTC
+    :param negative_utc: Whether to use a minus sign for UTC
+        (-0000 rather than +0000).
+    """
     if offset % 60 != 0:
         raise ValueError("Unable to handle non-minute offset.")
     if offset < 0 or (offset == 0 and negative_utc):
@@ -895,6 +921,12 @@ class Commit(ShaFile):
     type_name = 'commit'
     type_num = 1
 
+    __slots__ = ('_parents', '_encoding', '_extra', '_author_timezone_neg_utc',
+                 '_commit_timezone_neg_utc', '_commit_time',
+                 '_author_time', '_author_timezone', '_commit_timezone',
+                 '_author', '_committer', '_parents', '_extra',
+                 '_encoding', '_tree', '_message')
+
     def __init__(self):
         super(Commit, self).__init__()
         self._parents = []

+ 133 - 46
dulwich/pack.py

@@ -222,6 +222,111 @@ class PackIndex(object):
 
     Given a sha id of an object a pack index can tell you the location in the
     packfile of that object if it has it.
+    """
+
+    def __eq__(self, other):
+        if not isinstance(other, PackIndex):
+            return False
+
+        for (name1, _, _), (name2, _, _) in izip(self.iterentries(),
+                                                 other.iterentries()):
+            if name1 != name2:
+                return False
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __len__(self):
+        """Return the number of entries in this pack index."""
+        raise NotImplementedError(self.__len__)
+
+    def __iter__(self):
+        """Iterate over the SHAs in this pack."""
+        raise NotImplementedError(self.__iter__)
+
+    def iterentries(self):
+        """Iterate over the entries in this pack index.
+
+        :return: iterator over tuples with object name, offset in packfile and
+            crc32 checksum.
+        """
+        raise NotImplementedError(self.iterentries)
+
+    def get_pack_checksum(self):
+        """Return the SHA1 checksum stored for the corresponding packfile.
+
+        :return: 20-byte binary digest
+        """
+        raise NotImplementedError(self.get_pack_checksum)
+
+    def object_index(self, sha):
+        """Return the index in to the corresponding packfile for the object.
+
+        Given the name of an object it will return the offset that object
+        lives at within the corresponding pack file. If the pack file doesn't
+        have the object then None will be returned.
+        """
+        if len(sha) == 40:
+            sha = hex_to_sha(sha)
+        return self._object_index(sha)
+
+    def _object_index(self, sha):
+        """See object_index.
+
+        :param sha: A *binary* SHA string (20 characters long).
+        """
+        raise NotImplementedError(self._object_index)
+
+    def __iter__(self):
+        """Iterate over the SHAs in this pack."""
+        return imap(sha_to_hex, self._itersha())
+
+    def objects_sha1(self):
+        """Return the hex SHA1 over all the shas of all objects in this pack.
+
+        :note: This is used for the filename of the pack.
+        """
+        return iter_sha1(self._itersha())
+
+    def _itersha(self):
+        """Yield all the SHA1's of the objects in the index, sorted."""
+        raise NotImplementedError(self._itersha)
+
+
+class MemoryPackIndex(PackIndex):
+    """Pack index that is stored entirely in memory."""
+
+    def __init__(self, entries, pack_checksum=None):
+        """Create a new MemoryPackIndex.
+
+        :param entries: Sequence of name, idx, crc32 (sorted)
+        :param pack_checksum: Optional pack checksum
+        """
+        self._by_sha = {}
+        for name, idx, crc32 in entries:
+            self._by_sha[name] = idx
+        self._entries = entries
+        self._pack_checksum = pack_checksum
+
+    def get_pack_checksum(self):
+        return self._pack_checksum
+
+    def __len__(self):
+        return len(self._entries)
+
+    def _object_index(self, sha):
+        return self._by_sha[sha][0]
+
+    def _itersha(self):
+        return iter(self._by_sha)
+
+    def iterentries(self):
+        return iter(self._entries)
+
+
+class FilePackIndex(PackIndex):
+    """Pack index that is based on a file.
 
     To do the loop it opens the file, and indexes first 256 4 byte groups
     with the first byte of the sha id. The value in the four byte group indexed
@@ -250,20 +355,12 @@ class PackIndex(object):
             self._contents, self._size = (contents, size)
 
     def __eq__(self, other):
-        if not isinstance(other, PackIndex):
-            return False
-
-        if self._fan_out_table != other._fan_out_table:
+        # Quick optimization:
+        if (isinstance(other, FilePackIndex) and
+            self._fan_out_table != other._fan_out_table):
             return False
 
-        for (name1, _, _), (name2, _, _) in izip(self.iterentries(),
-                                                 other.iterentries()):
-            if name1 != name2:
-                return False
-        return True
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
+        return super(FilePackIndex, self).__eq__(other)
 
     def close(self):
         self._file.close()
@@ -292,21 +389,10 @@ class PackIndex(object):
         """Unpack the crc32 checksum for the i-th object from the index file."""
         raise NotImplementedError(self._unpack_crc32_checksum)
 
-    def __iter__(self):
-        """Iterate over the SHAs in this pack."""
-        return imap(sha_to_hex, self._itersha())
-
     def _itersha(self):
         for i in range(len(self)):
             yield self._unpack_name(i)
 
-    def objects_sha1(self):
-        """Return the hex SHA1 over all the shas of all objects in this pack.
-
-        :note: This is used for the filename of the pack.
-        """
-        return iter_sha1(self._itersha())
-
     def iterentries(self):
         """Iterate over the entries in this pack index.
 
@@ -351,17 +437,6 @@ class PackIndex(object):
         """
         return str(self._contents[-20:])
 
-    def object_index(self, sha):
-        """Return the index in to the corresponding packfile for the object.
-
-        Given the name of an object it will return the offset that object
-        lives at within the corresponding pack file. If the pack file doesn't
-        have the object then None will be returned.
-        """
-        if len(sha) == 40:
-            sha = hex_to_sha(sha)
-        return self._object_index(sha)
-
     def _object_index(self, sha):
         """See object_index.
 
@@ -380,11 +455,11 @@ class PackIndex(object):
         return self._unpack_offset(i)
 
 
-class PackIndex1(PackIndex):
-    """Version 1 Pack Index."""
+class PackIndex1(FilePackIndex):
+    """Version 1 Pack Index file."""
 
     def __init__(self, filename, file=None, contents=None, size=None):
-        PackIndex.__init__(self, filename, file, contents, size)
+        super(PackIndex1, self).__init__(filename, file, contents, size)
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
 
@@ -406,11 +481,11 @@ class PackIndex1(PackIndex):
         return None
 
 
-class PackIndex2(PackIndex):
-    """Version 2 Pack Index."""
+class PackIndex2(FilePackIndex):
+    """Version 2 Pack Index file."""
 
     def __init__(self, filename, file=None, contents=None, size=None):
-        PackIndex.__init__(self, filename, file, contents, size)
+        super(PackIndex2, self).__init__(filename, file, contents, size)
         assert self._contents[:4] == '\377tOc', "Not a v2 pack index file"
         (self.version, ) = unpack_from(">L", self._contents, 4)
         assert self.version == 2, "Version was %d" % self.version
@@ -888,6 +963,10 @@ class ThinPackData(PackData):
         super(ThinPackData, self).__init__(*args, **kwargs)
         self.resolve_ext_ref = resolve_ext_ref
 
+    @classmethod
+    def from_file(cls, resolve_ext_ref, file, size):
+        return cls(resolve_ext_ref, str(file), file=file, size=size)
+
     def get_ref(self, sha):
         """Resolve a reference looking in both this pack and the store."""
         try:
@@ -1061,11 +1140,21 @@ def write_pack(filename, objects, num_objects):
         f.close()
 
 
+def write_pack_header(f, num_objects):
+    """Write a pack header for the given number of objects."""
+    f.write('PACK')                          # Pack header
+    f.write(struct.pack('>L', 2))            # Pack version
+    f.write(struct.pack('>L', num_objects))  # Number of objects in pack
+
+
 def write_pack_data(f, objects, num_objects, window=10):
-    """Write a new pack file.
+    """Write a new pack data file.
 
-    :param filename: The filename of the new pack file.
-    :param objects: List of objects to write (tuples with object and path)
+    :param f: File to write to
+    :param objects: Iterable over (object, path) tuples to write
+    :param num_objects: Number of objects to write
+    :param window: Sliding window size for searching for deltas; currently
+                   unimplemented
     :return: List with (name, offset, crc32 checksum) entries, pack checksum
     """
     recency = list(objects)
@@ -1085,9 +1174,7 @@ def write_pack_data(f, objects, num_objects, window=10):
     # Write the pack
     entries = []
     f = SHA1Writer(f)
-    f.write("PACK")               # Pack header
-    f.write(struct.pack(">L", 2)) # Pack version
-    f.write(struct.pack(">L", num_objects)) # Number of objects in pack
+    write_pack_header(f, num_objects)
     for o, path in recency:
         sha1 = o.sha().digest()
         orig_t = o.type_num

+ 11 - 4
dulwich/patch.py

@@ -154,14 +154,21 @@ def git_am_patch_split(f):
     c = Commit()
     c.author = msg["from"]
     c.committer = msg["from"]
-    if msg["subject"].startswith("[PATCH"):
-        subject = msg["subject"].split("]", 1)[1][1:]
-    else:
+    try:
+        patch_tag_start = msg["subject"].index("[PATCH")
+    except ValueError:
         subject = msg["subject"]
-    c.message = subject
+    else:
+        close = msg["subject"].index("] ", patch_tag_start)
+        subject = msg["subject"][close+2:]
+    c.message = subject.replace("\n", "") + "\n"
+    first = True
     for l in f:
         if l == "---\n":
             break
+        if first:
+            c.message += "\n"
+            first = False
         c.message += l
     diff = ""
     for l in f:

+ 92 - 31
dulwich/protocol.py

@@ -41,10 +41,7 @@ MULTI_ACK_DETAILED = 2
 
 
 class ProtocolFile(object):
-    """
-    Some network ops are like file ops. The file ops expect to operate on
-    file objects, so provide them with a dummy file.
-    """
+    """A dummy file for network ops that expect file-like objects."""
 
     def __init__(self, read, write):
         self.read = read
@@ -57,7 +54,29 @@ class ProtocolFile(object):
         pass
 
 
+def pkt_line(data):
+    """Wrap data in a pkt-line.
+
+    :param data: The data to wrap, as a str or None.
+    :return: The data prefixed with its length in pkt-line format; if data was
+        None, returns the flush-pkt ('0000').
+    """
+    if data is None:
+        return '0000'
+    return '%04x%s' % (len(data) + 4, data)
+
+
 class Protocol(object):
+    """Class for interacting with a remote git process over the wire.
+
+    Parts of the git wire protocol use 'pkt-lines' to communicate. A pkt-line
+    consists of the length of the line as a 4-byte hex string, followed by the
+    payload data. The length includes the 4-byte header. The special line '0000'
+    indicates the end of a section of input and is called a 'flush-pkt'.
+
+    For details on the pkt-line format, see the cgit distribution:
+        Documentation/technical/protocol-common.txt
+    """
 
     def __init__(self, read, write, report_activity=None):
         self.read = read
@@ -65,10 +84,10 @@ class Protocol(object):
         self.report_activity = report_activity
 
     def read_pkt_line(self):
-        """
-        Reads a 'pkt line' from the remote git process
+        """Reads a pkt-line from the remote git process.
 
-        :return: The next string from the stream
+        :return: The next string from the stream, without the length prefix, or
+            None for a flush-pkt ('0000').
         """
         try:
             sizestr = self.read(4)
@@ -86,30 +105,32 @@ class Protocol(object):
             raise GitProtocolError(e)
 
     def read_pkt_seq(self):
+        """Read a sequence of pkt-lines from the remote git process.
+
+        :yield: Each line of data up to but not including the next flush-pkt.
+        """
         pkt = self.read_pkt_line()
         while pkt:
             yield pkt
             pkt = self.read_pkt_line()
 
     def write_pkt_line(self, line):
-        """
-        Sends a 'pkt line' to the remote git process
+        """Sends a pkt-line to the remote git process.
 
-        :param line: A string containing the data to send
+        :param line: A string containing the data to send, without the length
+            prefix.
         """
         try:
-            if line is None:
-                self.write("0000")
-                if self.report_activity:
-                    self.report_activity(4, 'write')
-            else:
-                self.write("%04x%s" % (len(line)+4, line))
-                if self.report_activity:
-                    self.report_activity(4+len(line), 'write')
+            line = pkt_line(line)
+            self.write(line)
+            if self.report_activity:
+                self.report_activity(len(line), 'write')
         except socket.error, e:
             raise GitProtocolError(e)
 
     def write_file(self):
+        """Return a writable file-like object for this protocol."""
+
         class ProtocolFile(object):
 
             def __init__(self, proto):
@@ -129,11 +150,10 @@ class Protocol(object):
         return ProtocolFile(self)
 
     def write_sideband(self, channel, blob):
-        """
-        Write data to the sideband (a git multiplexing method)
+        """Write multiplexed data to the sideband.
 
-        :param channel: int specifying which channel to write to
-        :param blob: a blob of data (as a string) to send on this channel
+        :param channel: An int specifying the channel to write to.
+        :param blob: A blob of data (as a string) to send on this channel.
         """
         # a pktline can be a max of 65520. a sideband line can therefore be
         # 65520-5 = 65515
@@ -143,23 +163,21 @@ class Protocol(object):
             blob = blob[65515:]
 
     def send_cmd(self, cmd, *args):
-        """
-        Send a command and some arguments to a git server
+        """Send a command and some arguments to a git server.
 
-        Only used for git://
+        Only used for the TCP git protocol (git://).
 
-        :param cmd: The remote service to access
-        :param args: List of arguments to send to remove service
+        :param cmd: The remote service to access.
+        :param args: List of arguments to send to remote service.
         """
         self.write_pkt_line("%s %s" % (cmd, "".join(["%s\0" % a for a in args])))
 
     def read_cmd(self):
-        """
-        Read a command and some arguments from the git client
+        """Read a command and some arguments from the git client.
 
-        Only used for git://
+        Only used for the TCP git protocol (git://).
 
-        :return: A tuple of (command, [list of arguments])
+        :return: A tuple of (command, [list of arguments]).
         """
         line = self.read_pkt_line()
         splice_at = line.find(" ")
@@ -310,3 +328,46 @@ def ack_type(capabilities):
     elif 'multi_ack' in capabilities:
         return MULTI_ACK
     return SINGLE_ACK
+
+
+class BufferedPktLineWriter(object):
+    """Writer that wraps its data in pkt-lines and has an independent buffer.
+
+    Each call to write() wraps the data in a pkt-line and buffers it; once the
+    buffered lines' total length (including length prefixes) reaches the
+    buffer size, the buffer is passed to the underlying write callback.
+    """
+
+    def __init__(self, write, bufsize=65515):
+        """Initialize the BufferedPktLineWriter.
+
+        :param write: A write callback for the underlying writer.
+        :param bufsize: The internal buffer size, including length prefixes.
+        """
+        self._write = write
+        self._bufsize = bufsize
+        self._wbuf = StringIO()
+        self._buflen = 0
+
+    def write(self, data):
+        """Write data, wrapping it in a pkt-line."""
+        line = pkt_line(data)
+        line_len = len(line)
+        over = self._buflen + line_len - self._bufsize
+        if over >= 0:
+            start = line_len - over  # bytes that still fit in the buffer
+            self._wbuf.write(line[:start])
+            self.flush()
+        else:
+            start = 0
+        saved = line[start:]  # remainder (or whole line) stays buffered
+        self._wbuf.write(saved)
+        self._buflen += len(saved)
+
+    def flush(self):
+        """Flush all data from the buffer."""
+        data = self._wbuf.getvalue()
+        if data:
+            self._write(data)
+        self._buflen = 0  # reset the counter that write() tests against
+        self._wbuf = StringIO()

+ 17 - 8
dulwich/repo.py

@@ -342,6 +342,7 @@ class DictRefsContainer(RefsContainer):
 
     def __init__(self, refs):
         self._refs = refs
+        self._peeled = {}
 
     def allkeys(self):
         return self._refs.keys()
@@ -374,6 +375,19 @@ class DictRefsContainer(RefsContainer):
         del self._refs[name]
         return True
 
+    def get_peeled(self, name):
+        return self._peeled.get(name)
+
+    def _update(self, refs):
+        """Update multiple refs; intended only for testing."""
+        # TODO(dborowitz): replace this with a public function that uses
+        # set_if_equal.
+        self._refs.update(refs)
+
+    def _update_peeled(self, peeled):
+        """Update cached peeled refs; intended only for testing."""
+        self._peeled.update(peeled)
+
 
 class DiskRefsContainer(RefsContainer):
     """Refs container that reads refs from disk."""
@@ -924,20 +938,15 @@ class BaseRepo(object):
     def get_peeled(self, ref):
         """Get the peeled value of a ref.
 
-        :param ref: the refname to peel
-        :return: the fully-peeled SHA1 of a tag object, after peeling all
+        :param ref: The refname to peel.
+        :return: The fully-peeled SHA1 of a tag object, after peeling all
             intermediate tags; if the original ref does not point to a tag, this
             will equal the original SHA1.
         """
         cached = self.refs.get_peeled(ref)
         if cached is not None:
             return cached
-        obj = self[ref]
-        obj_class = object_class(obj.type_name)
-        while obj_class is Tag:
-            obj_class, sha = obj.object
-            obj = self.get_object(sha)
-        return obj.id
+        return self.object_store.peel_sha(self.refs[ref]).id
 
     def revision_history(self, head):
         """Returns a list of the commits reachable from head.

+ 96 - 53
dulwich/server.py

@@ -36,6 +36,7 @@ from dulwich.errors import (
     ApplyDeltaError,
     ChecksumMismatch,
     GitProtocolError,
+    UnexpectedCommandError,
     ObjectFormatException,
     )
 from dulwich import log_utils
@@ -57,6 +58,7 @@ from dulwich.protocol import (
     ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
+    BufferedPktLineWriter,
     )
 from dulwich.repo import (
     Repo,
@@ -161,16 +163,20 @@ class Handler(object):
         self.proto = proto
         self._client_capabilities = None
 
-    def capability_line(self):
-        return " ".join(self.capabilities())
+    @classmethod
+    def capability_line(cls):
+        return " ".join(cls.capabilities())
 
-    def capabilities(self):
-        raise NotImplementedError(self.capabilities)
+    @classmethod
+    def capabilities(cls):
+        raise NotImplementedError(cls.capabilities)
 
-    def innocuous_capabilities(self):
+    @classmethod
+    def innocuous_capabilities(cls):
         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
 
-    def required_capabilities(self):
+    @classmethod
+    def required_capabilities(cls):
         """Return a list of capabilities that we require the client to have."""
         return []
 
@@ -206,11 +212,13 @@ class UploadPackHandler(Handler):
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def capabilities(self):
+    @classmethod
+    def capabilities(cls):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
                 "ofs-delta", "no-progress", "include-tag")
 
-    def required_capabilities(self):
+    @classmethod
+    def required_capabilities(cls):
         return ("side-band-64k", "thin-pack", "ofs-delta")
 
     def progress(self, message):
@@ -269,6 +277,41 @@ class UploadPackHandler(Handler):
         self.proto.write("0000")
 
 
+def _split_proto_line(line, allowed):
+    """Split a line read from the wire.
+
+    :param line: The line read from the wire.
+    :param allowed: An iterable of command names that should be allowed.
+        Command names not listed below as possible return values will be
+        ignored.  If None, any commands from the possible return values are
+        allowed.
+    :return: a tuple having one of the following forms:
+        ('want', obj_id)
+        ('have', obj_id)
+        ('done', None)
+        (None, None)  (for a flush-pkt)
+
+    :raise UnexpectedCommandError: if allowed is given and excludes the command.
+    :raise GitProtocolError: if the line is otherwise malformed.
+    """
+    if not line:
+        fields = [None]
+    else:
+        fields = line.rstrip('\n').split(' ', 1)
+    command = fields[0]
+    if allowed is not None and command not in allowed:
+        raise UnexpectedCommandError(command)
+    try:
+        if len(fields) == 1 and command in ('done', None):
+            return (command, None)
+        elif len(fields) == 2 and command in ('want', 'have'):
+            hex_to_sha(fields[1])
+            return tuple(fields)
+    except (TypeError, AssertionError), e:
+        raise GitProtocolError(e)
+    raise GitProtocolError('Received invalid line from client: %s' % line)
+
+
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
 
@@ -333,18 +376,16 @@ class ProtocolGraphWalker(object):
         line, caps = extract_want_line_capabilities(want)
         self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
-        command, sha = self._split_proto_line(line)
+        allowed = ('want', None)
+        command, sha = _split_proto_line(line, allowed)
 
         want_revs = []
         while command != None:
-            if command != 'want':
-                raise GitProtocolError(
-                  'Protocol got unexpected command %s' % command)
             if sha not in values:
                 raise GitProtocolError(
                   'Client wants invalid object %s' % sha)
             want_revs.append(sha)
-            command, sha = self.read_proto_line()
+            command, sha = self.read_proto_line(allowed)
 
         self.set_wants(want_revs)
         return want_revs
@@ -366,34 +407,14 @@ class ProtocolGraphWalker(object):
             return None
         return self._cache[self._cache_index]
 
-    def _split_proto_line(self, line):
-        fields = line.rstrip('\n').split(' ', 1)
-        if len(fields) == 1 and fields[0] == 'done':
-            return ('done', None)
-        elif len(fields) == 2 and fields[0] in ('want', 'have'):
-            try:
-                hex_to_sha(fields[1])
-                return tuple(fields)
-            except (TypeError, AssertionError), e:
-                raise GitProtocolError(e)
-        raise GitProtocolError('Received invalid line from client:\n%s' % line)
-
-    def read_proto_line(self):
+    def read_proto_line(self, allowed):
         """Read a line from the wire.
 
-        :return: a tuple having one of the following forms:
-            ('want', obj_id)
-            ('have', obj_id)
-            ('done', None)
-            (None, None)  (for a flush-pkt)
-
-        :raise GitProtocolError: if the line cannot be parsed into one of the
-            possible return values.
+        :param allowed: An iterable of command names that should be allowed.
+        :return: A tuple of (command, value); see _split_proto_line.
+        :raise GitProtocolError: If an error occurred reading the line.
         """
-        line = self.proto.read_pkt_line()
-        if not line:
-            return (None, None)
-        return self._split_proto_line(line)
+        return _split_proto_line(self.proto.read_pkt_line(), allowed)
 
     def send_ack(self, sha, ack_type=''):
         if ack_type:
@@ -457,6 +478,9 @@ class ProtocolGraphWalker(object):
         self._impl = impl_classes[ack_type](self)
 
 
+_GRAPH_WALKER_COMMANDS = ('have', 'done', None)
+
+
 class SingleAckGraphWalkerImpl(object):
     """Graph walker implementation that speaks the single-ack protocol."""
 
@@ -470,7 +494,7 @@ class SingleAckGraphWalkerImpl(object):
             self._sent_ack = True
 
     def next(self):
-        command, sha = self.walker.read_proto_line()
+        command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
         if command in (None, 'done'):
             if not self._sent_ack:
                 self.walker.send_nak()
@@ -497,7 +521,7 @@ class MultiAckGraphWalkerImpl(object):
 
     def next(self):
         while True:
-            command, sha = self.walker.read_proto_line()
+            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
                 self.walker.send_nak()
                 # in multi-ack mode, a flush-pkt indicates the client wants to
@@ -537,7 +561,7 @@ class MultiAckDetailedGraphWalkerImpl(object):
 
     def next(self):
         while True:
-            command, sha = self.walker.read_proto_line()
+            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
                 self.walker.send_nak()
                 if self.walker.stateless_rpc:
@@ -569,8 +593,9 @@ class ReceivePackHandler(Handler):
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def capabilities(self):
-        return ("report-status", "delete-refs")
+    @classmethod
+    def capabilities(cls):
+        return ("report-status", "delete-refs", "side-band-64k")
 
     def _apply_pack(self, refs):
         f, commit = self.repo.object_store.add_thin_pack()
@@ -614,6 +639,29 @@ class ReceivePackHandler(Handler):
 
         return status
 
+    def _report_status(self, status):
+        if self.has_capability('side-band-64k'):
+            writer = BufferedPktLineWriter(
+              lambda d: self.proto.write_sideband(1, d))
+            write = writer.write
+
+            def flush():
+                writer.flush()
+                self.proto.write_pkt_line(None)
+        else:
+            write = self.proto.write_pkt_line
+            flush = lambda: None
+
+        for name, msg in status:
+            if name == 'unpack':
+                write('unpack %s\n' % msg)
+            elif msg == 'ok':
+                write('ok %s\n' % name)
+            else:
+                write('ng %s %s\n' % (name, msg))
+        write(None)
+        flush()
+
     def handle(self):
         refs = self.repo.get_refs().items()
 
@@ -654,14 +702,7 @@ class ReceivePackHandler(Handler):
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
         if self.has_capability('report-status'):
-            for name, msg in status:
-                if name == 'unpack':
-                    self.proto.write_pkt_line('unpack %s\n' % msg)
-                elif msg == 'ok':
-                    self.proto.write_pkt_line('ok %s\n' % name)
-                else:
-                    self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
-            self.proto.write_pkt_line(None)
+            self._report_status(status)
 
 
 # Default handler classes for git services.
@@ -674,7 +715,7 @@ DEFAULT_HANDLERS = {
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 
     def __init__(self, handlers, *args, **kwargs):
-        self.handlers = handlers and handlers or DEFAULT_HANDLERS
+        self.handlers = handlers
         SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
 
     def handle(self):
@@ -698,8 +739,10 @@ class TCPGitServer(SocketServer.TCPServer):
         return TCPGitRequestHandler(self.handlers, *args, **kwargs)
 
     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
+        self.handlers = dict(DEFAULT_HANDLERS)
+        if handlers is not None:
+            self.handlers.update(handlers)
         self.backend = backend
-        self.handlers = handlers
         logger.info('Listening for TCP connections on %s:%d', listen_addr, port)
         SocketServer.TCPServer.__init__(self, (listen_addr, port),
                                         self._make_handler)

+ 40 - 17
dulwich/tests/compat/server_utils.py

@@ -24,12 +24,15 @@ import select
 import socket
 import threading
 
+from dulwich.server import (
+    ReceivePackHandler,
+    )
 from dulwich.tests.utils import (
     tear_down_repo,
     )
 from utils import (
     import_repo,
-    run_git,
+    run_git_or_fail,
     )
 
 
@@ -40,41 +43,49 @@ class ServerTests(object):
     """
 
     def setUp(self):
-        self._old_repo = import_repo('server_old.export')
-        self._new_repo = import_repo('server_new.export')
+        self._old_repo = None
+        self._new_repo = None
         self._server = None
 
     def tearDown(self):
         if self._server is not None:
             self._server.shutdown()
             self._server = None
-        tear_down_repo(self._old_repo)
-        tear_down_repo(self._new_repo)
+        if self._old_repo is not None:
+            tear_down_repo(self._old_repo)
+        if self._new_repo is not None:
+            tear_down_repo(self._new_repo)
+
+    def import_repos(self):
+        self._old_repo = import_repo('server_old.export')
+        self._new_repo = import_repo('server_new.export')
+
+    def url(self, port):
+        return '%s://localhost:%s/' % (self.protocol, port)
+
+    def branch_args(self, branches=None):
+        if branches is None:
+            branches = ['master', 'branch']
+        return ['%s:%s' % (b, b) for b in branches]
 
     def test_push_to_dulwich(self):
+        self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
 
-        all_branches = ['master', 'branch']
-        branch_args = ['%s:%s' % (b, b) for b in all_branches]
-        url = '%s://localhost:%s/' % (self.protocol, port)
-        returncode, _ = run_git(['push', url] + branch_args,
-                                cwd=self._new_repo.path)
-        self.assertEqual(0, returncode)
+        run_git_or_fail(['push', self.url(port)] + self.branch_args(),
+                        cwd=self._new_repo.path)
         self.assertReposEqual(self._old_repo, self._new_repo)
 
     def test_fetch_from_dulwich(self):
+        self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
 
-        all_branches = ['master', 'branch']
-        branch_args = ['%s:%s' % (b, b) for b in all_branches]
-        url = '%s://localhost:%s/' % (self.protocol, port)
-        returncode, _ = run_git(['fetch', url] + branch_args,
-                                cwd=self._old_repo.path)
+        run_git_or_fail(['fetch', self.url(port)] + self.branch_args(),
+                        cwd=self._old_repo.path)
         # flush the pack cache so any new packs are picked up
         self._old_repo.object_store._pack_cache = None
-        self.assertEqual(0, returncode)
         self.assertReposEqual(self._old_repo, self._new_repo)
 
 
@@ -155,3 +166,15 @@ class ShutdownServerMixIn:
             except:
                 self.handle_error(request, client_address)
                 self.close_request(request)
+
+
+# TODO(dborowitz): Come up with a better way of testing various permutations of
+# capabilities. The only reason it is the way it is now is that side-band-64k
+# was only recently introduced into git-receive-pack.
+class NoSideBand64kReceivePackHandler(ReceivePackHandler):
+    """ReceivePackHandler that does not support side-band-64k."""
+
+    @classmethod
+    def capabilities(cls):
+        return tuple(c for c in ReceivePackHandler.capabilities()
+                     if c != 'side-band-64k')

+ 5 - 5
dulwich/tests/compat/test_client.py

@@ -40,7 +40,7 @@ from utils import (
     CompatTestCase,
     check_for_daemon,
     import_repo_to_dir,
-    run_git,
+    run_git_or_fail,
     )
 
 class DulwichClientTestBase(object):
@@ -50,7 +50,7 @@ class DulwichClientTestBase(object):
         self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
         dest = os.path.join(self.gitroot, 'dest')
         file.ensure_dir_exists(dest)
-        run_git(['init', '--quiet', '--bare'], cwd=dest)
+        run_git_or_fail(['init', '--quiet', '--bare'], cwd=dest)
 
     def tearDown(self):
         shutil.rmtree(self.gitroot)
@@ -99,8 +99,8 @@ class DulwichClientTestBase(object):
     def disable_ff_and_make_dummy_commit(self):
         # disable non-fast-forward pushes to the server
         dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        run_git(['config', 'receive.denyNonFastForwards', 'true'],
-                cwd=dest.path)
+        run_git_or_fail(['config', 'receive.denyNonFastForwards', 'true'],
+                        cwd=dest.path)
         b = objects.Blob.from_string('hi')
         dest.object_store.add_object(b)
         t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
@@ -176,7 +176,7 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
         fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
                                             suffix=".pid")
         os.fdopen(fd).close()
-        run_git(
+        run_git_or_fail(
             ['daemon', '--verbose', '--export-all',
              '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
              '--detach', '--reuseaddr', '--enable=receive-pack',

+ 2 - 5
dulwich/tests/compat/test_pack.py

@@ -34,7 +34,7 @@ from dulwich.tests.test_pack import (
     )
 from utils import (
     require_git_version,
-    run_git,
+    run_git_or_fail,
     )
 
 
@@ -56,10 +56,7 @@ class TestPack(PackTests):
         pack_path = os.path.join(self._tempdir, "Elch")
         write_pack(pack_path, [(x, "") for x in origpack.iterobjects()],
                    len(origpack))
-
-        returncode, output = run_git(['verify-pack', '-v', pack_path],
-                                     capture_stdout=True)
-        self.assertEquals(0, returncode)
+        output = run_git_or_fail(['verify-pack', '-v', pack_path])
 
         pack_shas = set()
         for line in output.splitlines():

+ 2 - 5
dulwich/tests/compat/test_repository.py

@@ -35,7 +35,7 @@ from dulwich.tests.utils import (
     )
 
 from utils import (
-    run_git,
+    run_git_or_fail,
     import_repo,
     CompatTestCase,
     )
@@ -53,10 +53,7 @@ class ObjectStoreTestCase(CompatTestCase):
         tear_down_repo(self._repo)
 
     def _run_git(self, args):
-        returncode, output = run_git(args, capture_stdout=True,
-                                     cwd=self._repo.path)
-        self.assertEqual(0, returncode)
-        return output
+        return run_git_or_fail(args, cwd=self._repo.path)
 
     def _parse_refs(self, output):
         refs = {}

+ 32 - 2
dulwich/tests/compat/test_server.py

@@ -29,10 +29,12 @@ import threading
 from dulwich.server import (
     DictBackend,
     TCPGitServer,
+    ReceivePackHandler,
     )
 from server_utils import (
     ServerTests,
     ShutdownServerMixIn,
+    NoSideBand64kReceivePackHandler,
     )
 from utils import (
     CompatTestCase,
@@ -54,7 +56,10 @@ if not getattr(TCPGitServer, 'shutdown', None):
 
 
 class GitServerTestCase(ServerTests, CompatTestCase):
-    """Tests for client/server compatibility."""
+    """Tests for client/server compatibility.
+
+    This server test case does not use side-band-64k in git-receive-pack.
+    """
 
     protocol = 'git'
 
@@ -66,10 +71,35 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         ServerTests.tearDown(self)
         CompatTestCase.tearDown(self)
 
+    def _handlers(self):
+        return {'git-receive-pack': NoSideBand64kReceivePackHandler}
+
+    def _check_server(self, dul_server):
+        receive_pack_handler_cls = dul_server.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertFalse('side-band-64k' in caps)
+
     def _start_server(self, repo):
         backend = DictBackend({'/': repo})
-        dul_server = TCPGitServer(backend, 'localhost', 0)
+        dul_server = TCPGitServer(backend, 'localhost', 0,
+                                  handlers=self._handlers())
+        self._check_server(dul_server)
         threading.Thread(target=dul_server.serve).start()
         self._server = dul_server
         _, port = self._server.socket.getsockname()
         return port
+
+
+class GitServerSideBand64kTestCase(GitServerTestCase):
+    """Tests for client/server compatibility with side-band-64k support."""
+
+    # side-band-64k in git-receive-pack was introduced in git 1.7.0.2
+    min_git_version = (1, 7, 0, 2)
+
+    def _handlers(self):
+        return None  # default handlers include side-band-64k
+
+    def _check_server(self, server):
+        receive_pack_handler_cls = server.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertTrue('side-band-64k' in caps)

+ 91 - 0
dulwich/tests/compat/test_utils.py

@@ -0,0 +1,91 @@
+# test_utils.py -- Tests for git compatibility utilities
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+
+# as published by the Free Software Foundation; either version 2
+# of the License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor,
+# Boston, MA  02110-1301, USA.
+
+"""Tests for git compatibility utilities."""
+
+from unittest import TestCase
+
+from dulwich.tests import (
+    TestSkipped,
+    )
+import utils
+
+
+class GitVersionTests(TestCase):
+
+    def setUp(self):
+        self._orig_run_git = utils.run_git
+        self._version_str = None  # tests can override to set stub version
+
+        def run_git(args, **unused_kwargs):
+            self.assertEqual(['--version'], args)
+            return 0, self._version_str
+        utils.run_git = run_git
+
+    def tearDown(self):
+        utils.run_git = self._orig_run_git
+
+    def test_git_version_none(self):
+        self._version_str = 'not a git version'
+        self.assertEqual(None, utils.git_version())
+
+    def test_git_version_3(self):
+        self._version_str = 'git version 1.6.6'
+        self.assertEqual((1, 6, 6, 0), utils.git_version())
+
+    def test_git_version_4(self):
+        self._version_str = 'git version 1.7.0.2'
+        self.assertEqual((1, 7, 0, 2), utils.git_version())
+
+    def test_git_version_extra(self):
+        self._version_str = 'git version 1.7.0.3.295.gd8fa2'
+        self.assertEqual((1, 7, 0, 3), utils.git_version())
+
+    def assertRequireSucceeds(self, required_version):
+        try:
+            utils.require_git_version(required_version)
+        except TestSkipped:
+            self.fail()
+
+    def assertRequireFails(self, required_version):
+        self.assertRaises(TestSkipped, utils.require_git_version,
+                          required_version)
+
+    def test_require_git_version(self):
+        try:
+            self._version_str = 'git version 1.6.6'
+            self.assertRequireSucceeds((1, 6, 6))
+            self.assertRequireSucceeds((1, 6, 6, 0))
+            self.assertRequireSucceeds((1, 6, 5))
+            self.assertRequireSucceeds((1, 6, 5, 99))
+            self.assertRequireFails((1, 7, 0))
+            self.assertRequireFails((1, 7, 0, 2))
+            self.assertRaises(ValueError, utils.require_git_version,
+                              (1, 6, 6, 0, 0))
+
+            self._version_str = 'git version 1.7.0.2'
+            self.assertRequireSucceeds((1, 6, 6))
+            self.assertRequireSucceeds((1, 6, 6, 0))
+            self.assertRequireSucceeds((1, 7, 0))
+            self.assertRequireSucceeds((1, 7, 0, 2))
+            self.assertRequireFails((1, 7, 0, 3))
+            self.assertRequireFails((1, 7, 1))
+        except TestSkipped, e:
+            # This test is designed to catch all TestSkipped exceptions.
+            self.fail('Test unexpectedly skipped: %s' % e)

+ 31 - 2
dulwich/tests/compat/test_web.py

@@ -42,6 +42,7 @@ from dulwich.web import (
 from server_utils import (
     ServerTests,
     ShutdownServerMixIn,
+    NoSideBand64kReceivePackHandler,
     )
 from utils import (
     CompatTestCase,
@@ -84,7 +85,10 @@ class WebTests(ServerTests):
 
 
 class SmartWebTestCase(WebTests, CompatTestCase):
-    """Test cases for smart HTTP server."""
+    """Test cases for smart HTTP server.
+
+    This server test case does not use side-band-64k in git-receive-pack.
+    """
 
     min_git_version = (1, 6, 6)
 
@@ -96,8 +100,33 @@ class SmartWebTestCase(WebTests, CompatTestCase):
         WebTests.tearDown(self)
         CompatTestCase.tearDown(self)
 
+    def _handlers(self):
+        return {'git-receive-pack': NoSideBand64kReceivePackHandler}
+
+    def _check_app(self, app):
+        receive_pack_handler_cls = app.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertFalse('side-band-64k' in caps)
+
     def _make_app(self, backend):
-        return HTTPGitApplication(backend)
+        app = HTTPGitApplication(backend, handlers=self._handlers())
+        self._check_app(app)
+        return app
+
+
+class SmartWebSideBand64kTestCase(SmartWebTestCase):
+    """Test cases for smart HTTP server with side-band-64k support."""
+
+    # side-band-64k in git-receive-pack was introduced in git 1.7.0.2
+    min_git_version = (1, 7, 0, 2)
+
+    def _handlers(self):
+        return None  # default handlers include side-band-64k
+
+    def _check_app(self, app):
+        receive_pack_handler_cls = app.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertTrue('side-band-64k' in caps)
 
 
 class DumbWebTestCase(WebTests, CompatTestCase):

+ 34 - 15
dulwich/tests/compat/utils.py

@@ -35,6 +35,7 @@ from dulwich.tests import (
     )
 
 _DEFAULT_GIT = 'git'
+_VERSION_LEN = 4
 
 
 def git_version(git_path=_DEFAULT_GIT):
@@ -42,32 +43,50 @@ def git_version(git_path=_DEFAULT_GIT):
 
     :param git_path: Path to the git executable; defaults to the version in
         the system path.
-    :return: A tuple of ints of the form (major, minor, point), or None if no
-        git installation was found.
+    :return: A tuple of ints of the form (major, minor, point, sub-point), or
+        None if no git installation was found.
     """
     try:
-        _, output = run_git(['--version'], git_path=git_path,
-                            capture_stdout=True)
+        output = run_git_or_fail(['--version'], git_path=git_path)
     except OSError:
         return None
     version_prefix = 'git version '
     if not output.startswith(version_prefix):
         return None
-    output = output[len(version_prefix):]
-    nums = output.split('.')
-    if len(nums) == 2:
-        nums.add('0')
-    else:
-        nums = nums[:3]
-    try:
-        return tuple(int(x) for x in nums)
-    except ValueError:
-        return None
+
+    parts = output[len(version_prefix):].split('.')
+    nums = []
+    for part in parts:
+        try:
+            nums.append(int(part))
+        except ValueError:
+            break
+
+    while len(nums) < _VERSION_LEN:
+        nums.append(0)
+    return tuple(nums[:_VERSION_LEN])
 
 
 def require_git_version(required_version, git_path=_DEFAULT_GIT):
-    """Require git version >= version, or skip the calling test."""
+    """Require git version >= version, or skip the calling test.
+
+    :param required_version: A tuple of ints of the form (major, minor, point,
+        sub-point); omitted components default to 0.
+    :param git_path: Path to the git executable; defaults to the version in
+        the system path.
+    :raise ValueError: if the required version tuple has too many parts.
+    :raise TestSkipped: if no suitable git version was found at the given path.
+    """
     found_version = git_version(git_path=git_path)
+    if len(required_version) > _VERSION_LEN:
+        raise ValueError('Invalid version tuple %s, expected %i parts' %
+                         (required_version, _VERSION_LEN))
+
+    required_version = list(required_version)
+    while len(found_version) < len(required_version):
+        required_version.append(0)
+    required_version = tuple(required_version)
+
     if found_version < required_version:
         required_version = '.'.join(map(str, required_version))
         found_version = '.'.join(map(str, found_version))

+ 16 - 16
dulwich/tests/test_fastexport.py

@@ -20,9 +20,6 @@
 from cStringIO import StringIO
 import stat
 
-from dulwich.fastexport import (
-    FastExporter,
-    )
 from dulwich.object_store import (
     MemoryObjectStore,
     )
@@ -33,48 +30,51 @@ from dulwich.objects import (
     )
 from dulwich.tests import (
     TestCase,
+    TestSkipped,
     )
 
 
-class FastExporterTests(TestCase):
+class GitFastExporterTests(TestCase):
 
     def setUp(self):
-        super(FastExporterTests, self).setUp()
+        super(GitFastExporterTests, self).setUp()
         self.store = MemoryObjectStore()
         self.stream = StringIO()
-        self.fastexporter = FastExporter(self.stream, self.store)
+        try:
+            from dulwich.fastexport import GitFastExporter
+        except ImportError:
+            raise TestSkipped("python-fastimport not available")
+        self.fastexporter = GitFastExporter(self.stream, self.store)
 
-    def test_export_blob(self):
+    def test_emit_blob(self):
         b = Blob()
         b.data = "fooBAR"
-        self.assertEquals(1, self.fastexporter.export_blob(b))
+        self.fastexporter.emit_blob(b)
         self.assertEquals('blob\nmark :1\ndata 6\nfooBAR\n',
             self.stream.getvalue())
 
-    def test_export_commit(self):
+    def test_emit_commit(self):
         b = Blob()
         b.data = "FOO"
         t = Tree()
         t.add(stat.S_IFREG | 0644, "foo", b.id)
         c = Commit()
         c.committer = c.author = "Jelmer <jelmer@host>"
-        c.author_time = c.commit_time = 1271345553.47
+        c.author_time = c.commit_time = 1271345553
         c.author_timezone = c.commit_timezone = 0
         c.message = "msg"
         c.tree = t.id
         self.store.add_objects([(b, None), (t, None), (c, None)])
-        self.assertEquals(2,
-                self.fastexporter.export_commit(c, "refs/heads/master"))
+        self.fastexporter.emit_commit(c, "refs/heads/master")
         self.assertEquals("""blob
 mark :1
 data 3
 FOO
 commit refs/heads/master
 mark :2
-author Jelmer <jelmer@host> 1271345553.47 +0000
-committer Jelmer <jelmer@host> 1271345553.47 +0000
+author Jelmer <jelmer@host> 1271345553 +0000
+committer Jelmer <jelmer@host> 1271345553 +0000
 data 3
 msg
-M 100644 :1 foo
-
+M 644 1 foo
 """, self.stream.getvalue())

+ 5 - 0
dulwich/tests/test_file.py

@@ -119,6 +119,11 @@ class GitFileTests(TestCase):
         self.assertEquals('contents', f.read())
         f.close()
 
+    def test_default_mode(self):
+        f = GitFile(self.path('foo'))
+        self.assertEquals('foo contents', f.read())
+        f.close()
+
     def test_write(self):
         foo = self.path('foo')
         foo_lock = '%s.lock' % foo

+ 132 - 0
dulwich/tests/test_object_store.py

@@ -23,12 +23,26 @@ import os
 import shutil
 import tempfile
 
+from dulwich.index import (
+    commit_tree,
+    )
+from dulwich.errors import (
+    NotTreeError,
+    )
 from dulwich.objects import (
+    object_class,
     Blob,
+    ShaFile,
+    Tag,
+    Tree,
     )
 from dulwich.object_store import (
     DiskObjectStore,
     MemoryObjectStore,
+    tree_lookup_path,
+    )
+from dulwich.pack import (
+    write_pack_data,
     )
 from dulwich.tests import (
     TestCase,
@@ -75,6 +89,68 @@ class ObjectStoreTests(object):
         r = self.store[testobject.id]
         self.assertEquals(r, testobject)
 
+    def test_iter_tree_contents(self):
+        blob_a = make_object(Blob, data='a')
+        blob_b = make_object(Blob, data='b')
+        blob_c = make_object(Blob, data='c')
+        for blob in [blob_a, blob_b, blob_c]:
+            self.store.add_object(blob)
+
+        blobs = [
+          ('a', blob_a.id, 0100644),
+          ('ad/b', blob_b.id, 0100644),
+          ('ad/bd/c', blob_c.id, 0100755),
+          ('ad/c', blob_c.id, 0100644),
+          ('c', blob_c.id, 0100644),
+          ]
+        tree_id = commit_tree(self.store, blobs)
+        self.assertEquals([(p, m, h) for (p, h, m) in blobs],
+                          list(self.store.iter_tree_contents(tree_id)))
+
+    def test_iter_tree_contents_include_trees(self):
+        blob_a = make_object(Blob, data='a')
+        blob_b = make_object(Blob, data='b')
+        blob_c = make_object(Blob, data='c')
+        for blob in [blob_a, blob_b, blob_c]:
+            self.store.add_object(blob)
+
+        blobs = [
+          ('a', blob_a.id, 0100644),
+          ('ad/b', blob_b.id, 0100644),
+          ('ad/bd/c', blob_c.id, 0100755),
+          ]
+        tree_id = commit_tree(self.store, blobs)
+        tree = self.store[tree_id]
+        tree_ad = self.store[tree['ad'][1]]
+        tree_bd = self.store[tree_ad['bd'][1]]
+
+        expected = [
+          ('', 0040000, tree_id),
+          ('a', 0100644, blob_a.id),
+          ('ad', 0040000, tree_ad.id),
+          ('ad/b', 0100644, blob_b.id),
+          ('ad/bd', 0040000, tree_bd.id),
+          ('ad/bd/c', 0100755, blob_c.id),
+          ]
+        actual = self.store.iter_tree_contents(tree_id, include_trees=True)
+        self.assertEquals(expected, list(actual))
+
+    def make_tag(self, name, obj):
+        tag = make_object(Tag, name=name, message='',
+                          tag_time=12345, tag_timezone=0,
+                          tagger='Test Tagger <test@example.com>',
+                          object=(object_class(obj.type_name), obj.id))
+        self.store.add_object(tag)
+        return tag
+
+    def test_peel_sha(self):
+        self.store.add_object(testobject)
+        tag1 = self.make_tag('1', testobject)
+        tag2 = self.make_tag('2', testobject)
+        tag3 = self.make_tag('3', testobject)
+        for obj in [testobject, tag1, tag2, tag3]:
+            self.assertEqual(testobject, self.store.peel_sha(obj.id))
+
 
 class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
@@ -114,4 +190,60 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         o = DiskObjectStore(self.store_dir)
         self.assertEquals(os.path.join(self.store_dir, "pack"), o.pack_dir)
 
+    def test_add_pack(self):
+        o = DiskObjectStore(self.store_dir)
+        f, commit = o.add_pack()
+        b = make_object(Blob, data="more yummy data")
+        write_pack_data(f, [(b, None)], 1)
+        commit()
+
+    def test_add_thin_pack(self):
+        o = DiskObjectStore(self.store_dir)
+        f, commit = o.add_thin_pack()
+        b = make_object(Blob, data="more yummy data")
+        write_pack_data(f, [(b, None)], 1)
+        commit()
+
+
+class TreeLookupPathTests(TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.store = MemoryObjectStore()
+        blob_a = make_object(Blob, data='a')
+        blob_b = make_object(Blob, data='b')
+        blob_c = make_object(Blob, data='c')
+        for blob in [blob_a, blob_b, blob_c]:
+            self.store.add_object(blob)
+
+        blobs = [
+          ('a', blob_a.id, 0100644),
+          ('ad/b', blob_b.id, 0100644),
+          ('ad/bd/c', blob_c.id, 0100755),
+          ('ad/c', blob_c.id, 0100644),
+          ('c', blob_c.id, 0100644),
+          ]
+        self.tree_id = commit_tree(self.store, blobs)
+
+    def get_object(self, sha):
+        return self.store[sha]
+
+    def test_lookup_blob(self):
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'a')[1]
+        self.assertTrue(isinstance(self.store[o_id], Blob))
+
+    def test_lookup_tree(self):
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'ad')[1]
+        self.assertTrue(isinstance(self.store[o_id], Tree))
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'ad/bd')[1]
+        self.assertTrue(isinstance(self.store[o_id], Tree))
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'ad/bd/')[1]
+        self.assertTrue(isinstance(self.store[o_id], Tree))
+
+    def test_lookup_nonexistent(self):
+        self.assertRaises(KeyError, tree_lookup_path, self.get_object, self.tree_id, 'j')
+
+    def test_lookup_not_tree(self):
+        self.assertRaises(NotTreeError, tree_lookup_path, self.get_object, self.tree_id, 'ad/b/j')
+
 # TODO: MissingObjectFinderTests

+ 7 - 1
dulwich/tests/test_objects.py

@@ -151,7 +151,6 @@ class BlobReadTests(TestCase):
     def test_legacy_from_file(self):
         b1 = Blob.from_string("foo")
         b_raw = b1.as_legacy_object()
-        open('x', 'w+').write(b_raw)
         b2 = b1.from_file(StringIO(b_raw))
         self.assertEquals(b1, b2)
 
@@ -235,6 +234,13 @@ class BlobReadTests(TestCase):
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
 
+    def test_stub_sha(self):
+        sha = '5' * 40
+        c = make_commit(id=sha, message='foo')
+        self.assertTrue(isinstance(c, Commit))
+        self.assertEqual(sha, c.id)
+        self.assertNotEqual(sha, c._make_sha())
+
 
 class ShaFileCheckTests(TestCase):
 

+ 79 - 28
dulwich/tests/test_pack.py

@@ -38,12 +38,15 @@ from dulwich.objects import (
     Tree,
     )
 from dulwich.pack import (
+    MemoryPackIndex,
     Pack,
     PackData,
+    ThinPackData,
     apply_delta,
     create_delta,
     load_pack_index,
     read_zlib_chunks,
+    write_pack_header,
     write_pack_index_v1,
     write_pack_index_v2,
     write_pack,
@@ -162,6 +165,25 @@ class TestPackData(PackTests):
     def test_create_pack(self):
         p = self.get_pack_data(pack1_sha)
 
+    def test_from_file(self):
+        path = os.path.join(self.datadir, 'pack-%s.pack' % pack1_sha)
+        PackData.from_file(open(path), os.path.getsize(path))
+
+    # TODO: more ThinPackData tests.
+    def test_thin_from_file(self):
+        test_sha = '1' * 40
+
+        def resolve(sha):
+            self.assertEqual(test_sha, sha)
+            return 3, 'data'
+
+        path = os.path.join(self.datadir, 'pack-%s.pack' % pack1_sha)
+        data = ThinPackData.from_file(resolve, open(path),
+                                      os.path.getsize(path))
+        idx = self.get_pack_index(pack1_sha)
+        Pack.from_objects(data, idx)
+        self.assertEqual((None, 3, 'data'), data.get_ref(test_sha))
+
     def test_pack_len(self):
         p = self.get_pack_data(pack1_sha)
         self.assertEquals(3, len(p))
@@ -277,16 +299,19 @@ class TestPack(PackTests):
         self.assertEquals(pack1_sha, p.name())
 
 
-pack_checksum = hex_to_sha('721980e866af9a5f93ad674144e1459b8ba3e7b7')
+class WritePackHeaderTests(TestCase):
 
+    def test_simple(self):
+        f = StringIO()
+        write_pack_header(f, 42)
+        self.assertEquals('PACK\x00\x00\x00\x02\x00\x00\x00*',
+                f.getvalue())
 
-class BaseTestPackIndexWriting(object):
 
-    def setUp(self):
-        self.tempdir = tempfile.mkdtemp()
+pack_checksum = hex_to_sha('721980e866af9a5f93ad674144e1459b8ba3e7b7')
 
-    def tearDown(self):
-        shutil.rmtree(self.tempdir)
+
+class BaseTestPackIndexWriting(object):
 
     def assertSucceeds(self, func, *args, **kwargs):
         try:
@@ -294,30 +319,18 @@ class BaseTestPackIndexWriting(object):
         except ChecksumMismatch, e:
             self.fail(e)
 
-    def writeIndex(self, filename, entries, pack_checksum):
-        # FIXME: Write to StringIO instead rather than hitting disk ?
-        f = GitFile(filename, "wb")
-        try:
-            self._write_fn(f, entries, pack_checksum)
-        finally:
-            f.close()
+    def index(self, filename, entries, pack_checksum):
+        raise NotImplementedError(self.index)
 
     def test_empty(self):
-        filename = os.path.join(self.tempdir, 'empty.idx')
-        self.writeIndex(filename, [], pack_checksum)
-        idx = load_pack_index(filename)
-        self.assertSucceeds(idx.check)
+        idx = self.index('empty.idx', [], pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(0, len(idx))
 
     def test_single(self):
         entry_sha = hex_to_sha('6f670c0fb53f9463760b7295fbb814e965fb20c8')
         my_entries = [(entry_sha, 178, 42)]
-        filename = os.path.join(self.tempdir, 'single.idx')
-        self.writeIndex(filename, my_entries, pack_checksum)
-        idx = load_pack_index(filename)
-        self.assertEquals(idx.version, self._expected_version)
-        self.assertSucceeds(idx.check)
+        idx = self.index('single.idx', my_entries, pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(1, len(idx))
         actual_entries = list(idx.iterentries())
@@ -333,32 +346,70 @@ class BaseTestPackIndexWriting(object):
                 self.assertTrue(actual_crc is None)
 
 
-class TestPackIndexWritingv1(TestCase, BaseTestPackIndexWriting):
+class BaseTestFilePackIndexWriting(BaseTestPackIndexWriting):
+
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
+    def index(self, filename, entries, pack_checksum):
+        path = os.path.join(self.tempdir, filename)
+        self.writeIndex(path, entries, pack_checksum)
+        idx = load_pack_index(path)
+        self.assertSucceeds(idx.check)
+        self.assertEquals(idx.version, self._expected_version)
+        return idx
+
+    def writeIndex(self, filename, entries, pack_checksum):
+        # FIXME: Write to StringIO rather than hitting disk?
+        f = GitFile(filename, "wb")
+        try:
+            self._write_fn(f, entries, pack_checksum)
+        finally:
+            f.close()
+
+
+class TestMemoryIndexWriting(TestCase, BaseTestPackIndexWriting):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self._has_crc32_checksum = True
+
+    def index(self, filename, entries, pack_checksum):
+        return MemoryPackIndex(entries, pack_checksum)
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+
+
+class TestPackIndexWritingv1(TestCase, BaseTestFilePackIndexWriting):
 
     def setUp(self):
         TestCase.setUp(self)
-        BaseTestPackIndexWriting.setUp(self)
+        BaseTestFilePackIndexWriting.setUp(self)
         self._has_crc32_checksum = False
         self._expected_version = 1
         self._write_fn = write_pack_index_v1
 
     def tearDown(self):
         TestCase.tearDown(self)
-        BaseTestPackIndexWriting.tearDown(self)
+        BaseTestFilePackIndexWriting.tearDown(self)
 
 
-class TestPackIndexWritingv2(TestCase, BaseTestPackIndexWriting):
+class TestPackIndexWritingv2(TestCase, BaseTestFilePackIndexWriting):
 
     def setUp(self):
         TestCase.setUp(self)
-        BaseTestPackIndexWriting.setUp(self)
+        BaseTestFilePackIndexWriting.setUp(self)
         self._has_crc32_checksum = True
         self._expected_version = 2
         self._write_fn = write_pack_index_v2
 
     def tearDown(self):
         TestCase.tearDown(self)
-        BaseTestPackIndexWriting.tearDown(self)
+        BaseTestFilePackIndexWriting.tearDown(self)
 
 
 class ReadZlibTests(TestCase):

+ 26 - 1
dulwich/tests/test_patch.py

@@ -28,7 +28,9 @@ from dulwich.patch import (
     git_am_patch_split,
     write_commit_patch,
     )
-from dulwich.tests import TestCase
+from dulwich.tests import (
+    TestCase,
+    )
 
 
 class WriteCommitPatchTests(TestCase):
@@ -80,9 +82,32 @@ Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a lintian war
         c, diff, version = git_am_patch_split(StringIO(text))
         self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.author)
+        self.assertEquals("Remove executable bit from prey.ico "
+            "(triggers a lintian warning).\n", c.message)
         self.assertEquals(""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
  1 files changed, 0 insertions(+), 0 deletions(-)
  mode change 100755 => 100644 pixmaps/prey.ico
 
 """, diff)
         self.assertEquals("1.7.0.4", version)
+
+    def test_extract_spaces(self):
+        text = """From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
+From: Jelmer Vernooij <jelmer@samba.org>
+Date: Thu, 15 Apr 2010 15:40:28 +0200
+Subject:  [Dulwich-users] [PATCH] Added unit tests for
+ dulwich.object_store.tree_lookup_path.
+
+* dulwich/tests/test_object_store.py
+  (TreeLookupPathTests): This test case contains a few tests that ensure the
+   tree_lookup_path function works as expected.
+---
+ pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+ 1 files changed, 0 insertions(+), 0 deletions(-)
+ mode change 100755 => 100644 pixmaps/prey.ico
+
+-- 
+1.7.0.4
+"""
+        c, diff, version = git_am_patch_split(StringIO(text))
+        self.assertEquals('Added unit tests for dulwich.object_store.tree_lookup_path.\n\n* dulwich/tests/test_object_store.py\n  (TreeLookupPathTests): This test case contains a few tests that ensure the\n   tree_lookup_path function works as expected.\n', c.message)

+ 91 - 35
dulwich/tests/test_protocol.py

@@ -30,6 +30,7 @@ from dulwich.protocol import (
     SINGLE_ACK,
     MULTI_ACK,
     MULTI_ACK_DETAILED,
+    BufferedPktLineWriter,
     )
 from dulwich.tests import TestCase
 
@@ -38,42 +39,42 @@ class BaseProtocolTests(object):
 
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
-        self.assertEquals(self.rout.getvalue(), "0000")
+        self.assertEquals(self.rout.getvalue(), '0000')
 
     def test_write_pkt_line(self):
-        self.proto.write_pkt_line("bla")
-        self.assertEquals(self.rout.getvalue(), "0007bla")
+        self.proto.write_pkt_line('bla')
+        self.assertEquals(self.rout.getvalue(), '0007bla')
 
     def test_read_pkt_line(self):
-        self.rin.write("0008cmd ")
+        self.rin.write('0008cmd ')
         self.rin.seek(0)
-        self.assertEquals("cmd ", self.proto.read_pkt_line())
+        self.assertEquals('cmd ', self.proto.read_pkt_line())
 
     def test_read_pkt_seq(self):
-        self.rin.write("0008cmd 0005l0000")
+        self.rin.write('0008cmd 0005l0000')
         self.rin.seek(0)
-        self.assertEquals(["cmd ", "l"], list(self.proto.read_pkt_seq()))
+        self.assertEquals(['cmd ', 'l'], list(self.proto.read_pkt_seq()))
 
     def test_read_pkt_line_none(self):
-        self.rin.write("0000")
+        self.rin.write('0000')
         self.rin.seek(0)
         self.assertEquals(None, self.proto.read_pkt_line())
 
     def test_write_sideband(self):
-        self.proto.write_sideband(3, "bloe")
-        self.assertEquals(self.rout.getvalue(), "0009\x03bloe")
+        self.proto.write_sideband(3, 'bloe')
+        self.assertEquals(self.rout.getvalue(), '0009\x03bloe')
 
     def test_send_cmd(self):
-        self.proto.send_cmd("fetch", "a", "b")
-        self.assertEquals(self.rout.getvalue(), "000efetch a\x00b\x00")
+        self.proto.send_cmd('fetch', 'a', 'b')
+        self.assertEquals(self.rout.getvalue(), '000efetch a\x00b\x00')
 
     def test_read_cmd(self):
-        self.rin.write("0012cmd arg1\x00arg2\x00")
+        self.rin.write('0012cmd arg1\x00arg2\x00')
         self.rin.seek(0)
-        self.assertEquals(("cmd", ["arg1", "arg2"]), self.proto.read_cmd())
+        self.assertEquals(('cmd', ['arg1', 'arg2']), self.proto.read_cmd())
 
     def test_read_cmd_noend0(self):
-        self.rin.write("0011cmd arg1\x00arg2")
+        self.rin.write('0011cmd arg1\x00arg2')
         self.rin.seek(0)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
@@ -94,7 +95,7 @@ class ReceivableStringIO(StringIO):
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
         if self.tell() == len(self.getvalue()):
-            raise AssertionError("Blocking read past end of socket")
+            raise AssertionError('Blocking read past end of socket')
         if size == 1:
             return self.read(1)
         # calls shouldn't return quite as much as asked for
@@ -111,10 +112,10 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.proto._rbufsize = 8
 
     def test_recv(self):
-        all_data = "1234567" * 10  # not a multiple of bufsize
+        all_data = '1234567' * 10  # not a multiple of bufsize
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = ""
+        data = ''
         # We ask for 8 bytes each time and actually read 7, so it should take
         # exactly 10 iterations.
         for _ in xrange(10):
@@ -124,28 +125,28 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEquals(all_data, data)
 
     def test_recv_read(self):
-        all_data = "1234567"  # recv exactly in one call
+        all_data = '1234567'  # recv exactly in one call
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEquals("1234", self.proto.recv(4))
-        self.assertEquals("567", self.proto.read(3))
+        self.assertEquals('1234', self.proto.recv(4))
+        self.assertEquals('567', self.proto.read(3))
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
     def test_read_recv(self):
-        all_data = "12345678abcdefg"
+        all_data = '12345678abcdefg'
         self.rin.write(all_data)
         self.rin.seek(0)
-        self.assertEquals("1234", self.proto.read(4))
-        self.assertEquals("5678abc", self.proto.recv(8))
-        self.assertEquals("defg", self.proto.read(4))
+        self.assertEquals('1234', self.proto.read(4))
+        self.assertEquals('5678abc', self.proto.recv(8))
+        self.assertEquals('defg', self.proto.read(4))
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
     def test_mixed(self):
         # arbitrary non-repeating string
-        all_data = ",".join(str(i) for i in xrange(100))
+        all_data = ','.join(str(i) for i in xrange(100))
         self.rin.write(all_data)
         self.rin.seek(0)
-        data = ""
+        data = ''
 
         for i in xrange(1, 100):
             data += self.proto.recv(i)
@@ -168,20 +169,20 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
 class CapabilitiesTestCase(TestCase):
 
     def test_plain(self):
-        self.assertEquals(("bla", []), extract_capabilities("bla"))
+        self.assertEquals(('bla', []), extract_capabilities('bla'))
 
     def test_caps(self):
-        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la"))
-        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la\n"))
-        self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la la"))
+        self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la'))
+        self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la\n'))
+        self.assertEquals(('bla', ['la', 'la']), extract_capabilities('bla\0la la'))
 
     def test_plain_want_line(self):
-        self.assertEquals(("want bla", []), extract_want_line_capabilities("want bla"))
+        self.assertEquals(('want bla', []), extract_want_line_capabilities('want bla'))
 
     def test_caps_want_line(self):
-        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la"))
-        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la\n"))
-        self.assertEquals(("want bla", ["la", "la"]), extract_want_line_capabilities("want bla la la"))
+        self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la'))
+        self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la\n'))
+        self.assertEquals(('want bla', ['la', 'la']), extract_want_line_capabilities('want bla la la'))
 
     def test_ack_type(self):
         self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
@@ -192,3 +193,58 @@ class CapabilitiesTestCase(TestCase):
         self.assertEquals(MULTI_ACK_DETAILED,
                           ack_type(['foo', 'bar', 'multi_ack',
                                     'multi_ack_detailed']))
+
+
+class BufferedPktLineWriterTests(TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self._output = StringIO()
+        self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
+
+    def assertOutputEquals(self, expected):
+        self.assertEquals(expected, self._output.getvalue())
+
+    def _truncate(self):
+        self._output.seek(0)
+        self._output.truncate()
+
+    def test_write(self):
+        self._writer.write('foo')
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0007foo')
+
+    def test_write_none(self):
+        self._writer.write(None)
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0000')
+
+    def test_flush_empty(self):
+        self._writer.flush()
+        self.assertOutputEquals('')
+
+    def test_write_multiple(self):
+        self._writer.write('foo')
+        self._writer.write('bar')
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0007foo0007bar')
+
+    def test_write_across_boundary(self):
+        self._writer.write('foo')
+        self._writer.write('barbaz')
+        self.assertOutputEquals('0007foo000abarba')
+        self._truncate()
+        self._writer.flush()
+        self.assertOutputEquals('z')
+
+    def test_write_to_boundary(self):
+        self._writer.write('foo')
+        self._writer.write('barba')
+        self.assertOutputEquals('0007foo0009barba')
+        self._truncate()
+        self._writer.write('z')
+        self._writer.flush()
+        self.assertOutputEquals('0005z')

+ 5 - 2
dulwich/tests/test_repository.py

@@ -26,6 +26,9 @@ import tempfile
 import warnings
 
 from dulwich import errors
+from dulwich.file import (
+    GitFile,
+    )
 from dulwich.object_store import (
     tree_lookup_path,
     )
@@ -767,10 +770,10 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
 
     def test_remove_packed_without_peeled(self):
         refs_file = os.path.join(self._repo.path, 'packed-refs')
-        f = open(refs_file)
+        f = GitFile(refs_file)
         refs_data = f.read()
         f.close()
-        f = open(refs_file, 'wb')
+        f = GitFile(refs_file, 'wb')
         f.write('\n'.join(l for l in refs_data.split('\n')
                           if not l or l[0] not in '#^'))
         f.close()

+ 91 - 94
dulwich/tests/test_server.py

@@ -21,20 +21,26 @@
 
 from dulwich.errors import (
     GitProtocolError,
+    UnexpectedCommandError,
+    )
+from dulwich.repo import (
+    MemoryRepo,
     )
 from dulwich.server import (
     Backend,
     DictBackend,
-    BackendRepo,
     Handler,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    _split_proto_line,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     UploadPackHandler,
     )
 from dulwich.tests import TestCase
-
+from utils import (
+    make_commit,
+    )
 
 
 ONE = '1' * 40
@@ -76,13 +82,25 @@ class TestProto(object):
             return None
 
 
+class TestGenericHandler(Handler):
+
+    def __init__(self):
+        Handler.__init__(self, Backend(), None)
+
+    @classmethod
+    def capabilities(cls):
+        return ('cap1', 'cap2', 'cap3')
+
+    @classmethod
+    def required_capabilities(cls):
+        return ('cap2',)
+
+
 class HandlerTestCase(TestCase):
 
     def setUp(self):
         super(HandlerTestCase, self).setUp()
-        self._handler = Handler(Backend(), None)
-        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
-        self._handler.required_capabilities = lambda: ('cap2',)
+        self._handler = TestGenericHandler()
 
     def assertSucceeds(self, func, *args, **kwargs):
         try:
@@ -124,10 +142,10 @@ class UploadPackHandlerTestCase(TestCase):
 
     def setUp(self):
         super(UploadPackHandlerTestCase, self).setUp()
-        self._backend = DictBackend({"/": BackendRepo()})
-        self._handler = UploadPackHandler(self._backend,
-                ["/", "host=lolcathost"], None, None)
-        self._handler.proto = TestProto()
+        self._repo = MemoryRepo.init_bare([], {})
+        backend = DictBackend({'/': self._repo})
+        self._handler = UploadPackHandler(
+          backend, ['/', 'host=lolcathost'], TestProto())
 
     def test_progress(self):
         caps = self._handler.required_capabilities()
@@ -153,63 +171,30 @@ class UploadPackHandlerTestCase(TestCase):
             'refs/tags/tag2': TWO,
             'refs/heads/master': FOUR,  # not a tag, no peeled value
             }
+        # repo needs to peel this object
+        self._repo.object_store.add_object(make_commit(id=FOUR))
+        self._repo.refs._update(refs)
         peeled = {
-            'refs/tags/tag1': '1234',
-            'refs/tags/tag2': '5678',
+            'refs/tags/tag1': '1234' * 10,
+            'refs/tags/tag2': '5678' * 10,
             }
-
-        class TestRepo(object):
-            def get_peeled(self, ref):
-                return peeled.get(ref, refs[ref])
+        self._repo.refs._update_peeled(peeled)
 
         caps = list(self._handler.required_capabilities()) + ['include-tag']
         self._handler.set_client_capabilities(caps)
-        self.assertEquals({'1234': ONE, '5678': TWO},
-                          self._handler.get_tagged(refs, repo=TestRepo()))
+        self.assertEquals({'1234' * 10: ONE, '5678' * 10: TWO},
+                          self._handler.get_tagged(refs, repo=self._repo))
 
         # non-include-tag case
         caps = self._handler.required_capabilities()
         self._handler.set_client_capabilities(caps)
-        self.assertEquals({}, self._handler.get_tagged(refs, repo=TestRepo()))
-
-
-class TestCommit(object):
-
-    def __init__(self, sha, parents, commit_time):
-        self.id = sha
-        self.parents = parents
-        self.commit_time = commit_time
-        self.type_name = "commit"
-
-    def __repr__(self):
-        return '%s(%s)' % (self.__class__.__name__, self._sha)
-
-
-class TestRepo(object):
-    def __init__(self):
-        self.peeled = {}
+        self.assertEquals({}, self._handler.get_tagged(refs, repo=self._repo))
 
-    def get_peeled(self, name):
-        return self.peeled[name]
 
-
-class TestBackend(object):
-
-    def __init__(self, repo, objects):
-        self.repo = repo
-        self.object_store = objects
-
-
-class TestUploadPackHandler(Handler):
-
-    def __init__(self, objects, proto):
-        self.backend = TestBackend(TestRepo(), objects)
-        self.proto = proto
-        self.stateless_rpc = False
-        self.advertise_refs = False
-
-    def capabilities(self):
-        return ('multi_ack',)
+class TestUploadPackHandler(UploadPackHandler):
+    @classmethod
+    def required_capabilities(self):
+        return ()
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
@@ -220,17 +205,18 @@ class ProtocolGraphWalkerTestCase(TestCase):
         #   3---5
         #  /
         # 1---2---4
-        self._objects = {
-          ONE: TestCommit(ONE, [], 111),
-          TWO: TestCommit(TWO, [ONE], 222),
-          THREE: TestCommit(THREE, [ONE], 333),
-          FOUR: TestCommit(FOUR, [TWO], 444),
-          FIVE: TestCommit(FIVE, [THREE], 555),
-          }
-
+        commits = [
+          make_commit(id=ONE, parents=[], commit_time=111),
+          make_commit(id=TWO, parents=[ONE], commit_time=222),
+          make_commit(id=THREE, parents=[ONE], commit_time=333),
+          make_commit(id=FOUR, parents=[TWO], commit_time=444),
+          make_commit(id=FIVE, parents=[THREE], commit_time=555),
+          ]
+        self._repo = MemoryRepo.init_bare(commits, {})
+        backend = DictBackend({'/': self._repo})
         self._walker = ProtocolGraphWalker(
-            TestUploadPackHandler(self._objects, TestProto()),
-            self._objects, None)
+            TestUploadPackHandler(backend, ['/', 'host=lolcats'], TestProto()),
+            self._repo.object_store, self._repo.get_peeled)
 
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
@@ -257,22 +243,21 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertFalse(self._walker.all_wants_satisfied([THREE]))
         self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
 
-    def test_read_proto_line(self):
-        self._walker.proto.set_output([
-          'want %s' % ONE,
-          'want %s' % TWO,
-          'have %s' % THREE,
-          'foo %s' % FOUR,
-          'bar',
-          'done',
-          ])
-        self.assertEquals(('want', ONE), self._walker.read_proto_line())
-        self.assertEquals(('want', TWO), self._walker.read_proto_line())
-        self.assertEquals(('have', THREE), self._walker.read_proto_line())
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertEquals(('done', None), self._walker.read_proto_line())
-        self.assertEquals((None, None), self._walker.read_proto_line())
+    def test_split_proto_line(self):
+        allowed = ('want', 'done', None)
+        self.assertEquals(('want', ONE),
+                          _split_proto_line('want %s\n' % ONE, allowed))
+        self.assertEquals(('want', TWO),
+                          _split_proto_line('want %s\n' % TWO, allowed))
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'want xxxx\n', allowed)
+        self.assertRaises(UnexpectedCommandError, _split_proto_line,
+                          'have %s\n' % THREE, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'foo %s\n' % FOUR, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line, 'bar', allowed)
+        self.assertEquals(('done', None), _split_proto_line('done\n', allowed))
+        self.assertEquals((None, None), _split_proto_line('', allowed))
 
     def test_determine_wants(self):
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
@@ -281,8 +266,12 @@ class ProtocolGraphWalkerTestCase(TestCase):
           'want %s multi_ack' % ONE,
           'want %s' % TWO,
           ])
-        heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
-        self._walker.get_peeled = heads.get
+        heads = {
+          'refs/heads/ref1': ONE,
+          'refs/heads/ref2': TWO,
+          'refs/heads/ref3': THREE,
+          }
+        self._repo.refs._update(heads)
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
 
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
@@ -300,9 +289,14 @@ class ProtocolGraphWalkerTestCase(TestCase):
     def test_determine_wants_advertisement(self):
         self._walker.proto.set_output([])
         # advertise branch tips plus tag
-        heads = {'ref4': FOUR, 'ref5': FIVE, 'tag6': SIX}
-        peeled = {'ref4': FOUR, 'ref5': FIVE, 'tag6': FIVE}
-        self._walker.get_peeled = peeled.get
+        heads = {
+          'refs/heads/ref4': FOUR,
+          'refs/heads/ref5': FIVE,
+          'refs/heads/tag6': SIX,
+          }
+        self._repo.refs._update(heads)
+        self._repo.refs._update_peeled(heads)
+        self._repo.refs._update_peeled({'refs/heads/tag6': FIVE})
         self._walker.determine_wants(heads)
         lines = []
         while True:
@@ -315,16 +309,16 @@ class ProtocolGraphWalkerTestCase(TestCase):
             lines.append(line.rstrip())
 
         self.assertEquals([
-          '%s ref4' % FOUR,
-          '%s ref5' % FIVE,
-          '%s tag6^{}' % FIVE,
-          '%s tag6' % SIX,
+          '%s refs/heads/ref4' % FOUR,
+          '%s refs/heads/ref5' % FIVE,
+          '%s refs/heads/tag6^{}' % FIVE,
+          '%s refs/heads/tag6' % SIX,
           ], sorted(lines))
 
         # ensure peeled tag was advertised immediately following tag
         for i, line in enumerate(lines):
-            if line.endswith(' tag6'):
-                self.assertEquals('%s tag6^{}' % FIVE, lines[i+1])
+            if line.endswith(' refs/heads/tag6'):
+                self.assertEquals('%s refs/heads/tag6^{}' % FIVE, lines[i+1])
 
     # TODO: test commit time cutoff
 
@@ -338,8 +332,11 @@ class TestProtocolGraphWalker(object):
         self.stateless_rpc = False
         self.advertise_refs = False
 
-    def read_proto_line(self):
-        return self.lines.pop(0)
+    def read_proto_line(self, allowed):
+        command, sha = self.lines.pop(0)
+        if allowed is not None:
+            assert command in allowed
+        return command, sha
 
     def send_ack(self, sha, ack_type=''):
         self.acks.append((sha, ack_type))

+ 214 - 87
dulwich/tests/test_web.py

@@ -21,8 +21,20 @@
 from cStringIO import StringIO
 import re
 
+from dulwich.object_store import (
+    MemoryObjectStore,
+    )
 from dulwich.objects import (
     Blob,
+    Tag,
+    )
+from dulwich.repo import (
+    BaseRepo,
+    DictRefsContainer,
+    MemoryRepo,
+    )
+from dulwich.server import (
+    DictBackend,
     )
 from dulwich.tests import (
     TestCase,
@@ -31,33 +43,73 @@ from dulwich.web import (
     HTTP_OK,
     HTTP_NOT_FOUND,
     HTTP_FORBIDDEN,
+    HTTP_ERROR,
     send_file,
+    get_text_file,
+    get_loose_object,
+    get_pack_file,
+    get_idx_file,
     get_info_refs,
+    get_info_packs,
     handle_service_request,
     _LengthLimitedFile,
     HTTPGitRequest,
     HTTPGitApplication,
     )
 
+from utils import make_object
+
+
+class TestHTTPGitRequest(HTTPGitRequest):
+    """HTTPGitRequest with overridden methods to help test caching."""
+
+    def __init__(self, *args, **kwargs):
+        HTTPGitRequest.__init__(self, *args, **kwargs)
+        self.cached = None
+
+    def nocache(self):
+        self.cached = False
+
+    def cache_forever(self):
+        self.cached = True
+
 
 class WebTestCase(TestCase):
-    """Base TestCase that sets up some useful instance vars."""
+    """Base TestCase with useful instance vars and utility functions."""
+
+    _req_class = TestHTTPGitRequest
 
     def setUp(self):
         super(WebTestCase, self).setUp()
         self._environ = {}
-        self._req = HTTPGitRequest(self._environ, self._start_response,
-                                   handlers=self._handlers())
+        self._req = self._req_class(self._environ, self._start_response,
+                                    handlers=self._handlers())
         self._status = None
         self._headers = []
+        self._output = StringIO()
 
     def _start_response(self, status, headers):
         self._status = status
         self._headers = list(headers)
+        return self._output.write
 
     def _handlers(self):
         return None
 
+    def assertContentTypeEquals(self, expected):
+        self.assertTrue(('Content-Type', expected) in self._headers)
+
+
+def _test_backend(objects, refs=None, named_files=None):
+    if not refs:
+        refs = {}
+    if not named_files:
+        named_files = {}
+    repo = MemoryRepo.init_bare(objects, refs)
+    for path, contents in named_files.iteritems():
+        repo._put_named_file(path, contents)
+    return DictBackend({'/': repo})
+
 
 class DumbHandlersTestCase(WebTestCase):
 
@@ -67,10 +119,10 @@ class DumbHandlersTestCase(WebTestCase):
 
     def test_send_file(self):
         f = StringIO('foobar')
-        output = ''.join(send_file(self._req, f, 'text/plain'))
+        output = ''.join(send_file(self._req, f, 'some/thing'))
         self.assertEquals('foobar', output)
         self.assertEquals(HTTP_OK, self._status)
-        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertContentTypeEquals('some/thing')
         self.assertTrue(f.closed)
 
     def test_send_file_buffered(self):
@@ -78,93 +130,152 @@ class DumbHandlersTestCase(WebTestCase):
         xs = 'x' * bufsize
         f = StringIO(2 * xs)
         self.assertEquals([xs, xs],
-                          list(send_file(self._req, f, 'text/plain')))
+                          list(send_file(self._req, f, 'some/thing')))
         self.assertEquals(HTTP_OK, self._status)
-        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertContentTypeEquals('some/thing')
         self.assertTrue(f.closed)
 
     def test_send_file_error(self):
         class TestFile(object):
-            def __init__(self):
+            def __init__(self, exc_class):
                 self.closed = False
+                self._exc_class = exc_class
 
             def read(self, size=-1):
-                raise IOError
+                raise self._exc_class()
 
             def close(self):
                 self.closed = True
 
-        f = TestFile()
-        list(send_file(self._req, f, 'text/plain'))
-        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        f = TestFile(IOError)
+        list(send_file(self._req, f, 'some/thing'))
+        self.assertEquals(HTTP_ERROR, self._status)
+        self.assertTrue(f.closed)
+        self.assertFalse(self._req.cached)
+
+        # non-IOErrors are reraised
+        f = TestFile(AttributeError)
+        self.assertRaises(AttributeError, list,
+                          send_file(self._req, f, 'some/thing'))
         self.assertTrue(f.closed)
+        self.assertFalse(self._req.cached)
+
+    def test_get_text_file(self):
+        backend = _test_backend([], named_files={'description': 'foo'})
+        mat = re.search('.*', 'description')
+        output = ''.join(get_text_file(self._req, backend, mat))
+        self.assertEquals('foo', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('text/plain')
+        self.assertFalse(self._req.cached)
+
+    def test_get_loose_object(self):
+        blob = make_object(Blob, data='foo')
+        backend = _test_backend([blob])
+        mat = re.search('^(..)(.{38})$', blob.id)
+        output = ''.join(get_loose_object(self._req, backend, mat))
+        self.assertEquals(blob.as_legacy_object(), output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('application/x-git-loose-object')
+        self.assertTrue(self._req.cached)
+
+    def test_get_loose_object_missing(self):
+        mat = re.search('^(..)(.{38})$', '1' * 40)
+        list(get_loose_object(self._req, _test_backend([]), mat))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+
+    def test_get_loose_object_error(self):
+        blob = make_object(Blob, data='foo')
+        backend = _test_backend([blob])
+        mat = re.search('^(..)(.{38})$', blob.id)
+
+        def as_legacy_object_error():
+            raise IOError
+
+        blob.as_legacy_object = as_legacy_object_error
+        list(get_loose_object(self._req, backend, mat))
+        self.assertEquals(HTTP_ERROR, self._status)
+
+    def test_get_pack_file(self):
+        pack_name = 'objects/pack/pack-%s.pack' % ('1' * 40)
+        backend = _test_backend([], named_files={pack_name: 'pack contents'})
+        mat = re.search('.*', pack_name)
+        output = ''.join(get_pack_file(self._req, backend, mat))
+        self.assertEquals('pack contents', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('application/x-git-packed-objects')
+        self.assertTrue(self._req.cached)
+
+    def test_get_idx_file(self):
+        idx_name = 'objects/pack/pack-%s.idx' % ('1' * 40)
+        backend = _test_backend([], named_files={idx_name: 'idx contents'})
+        mat = re.search('.*', idx_name)
+        output = ''.join(get_idx_file(self._req, backend, mat))
+        self.assertEquals('idx contents', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('application/x-git-packed-objects-toc')
+        self.assertTrue(self._req.cached)
 
     def test_get_info_refs(self):
         self._environ['QUERY_STRING'] = ''
 
-        class TestTag(object):
-            def __init__(self, sha, obj_class, obj_sha):
-                self.sha = lambda: sha
-                self.object = (obj_class, obj_sha)
-
-        class TestBlob(object):
-            def __init__(self, sha):
-                self.sha = lambda: sha
-
-        blob1 = TestBlob('111')
-        blob2 = TestBlob('222')
-        blob3 = TestBlob('333')
-
-        tag1 = TestTag('aaa', Blob, '222')
-
-        class TestRepo(object):
-
-            def __init__(self, objects, peeled):
-                self._objects = dict((o.sha(), o) for o in objects)
-                self._peeled = peeled
-
-            def get_peeled(self, sha):
-                return self._peeled[sha]
-
-            def __getitem__(self, sha):
-                return self._objects[sha]
-
-            def get_refs(self):
-                return {
-                    'HEAD': '000',
-                    'refs/heads/master': blob1.sha(),
-                    'refs/tags/tag-tag': tag1.sha(),
-                    'refs/tags/blob-tag': blob3.sha(),
-                    }
-
-        class TestBackend(object):
-            def __init__(self):
-                objects = [blob1, blob2, blob3, tag1]
-                self.repo = TestRepo(objects, {
-                  'HEAD': '000',
-                  'refs/heads/master': blob1.sha(),
-                  'refs/tags/tag-tag': blob2.sha(),
-                  'refs/tags/blob-tag': blob3.sha(),
-                  })
-
-            def open_repository(self, path):
-                assert path == '/'
-                return self.repo
-
-            def get_refs(self):
-                return {
-                  'HEAD': '000',
-                  'refs/heads/master': blob1.sha(),
-                  'refs/tags/tag-tag': tag1.sha(),
-                  'refs/tags/blob-tag': blob3.sha(),
-                  }
+        blob1 = make_object(Blob, data='1')
+        blob2 = make_object(Blob, data='2')
+        blob3 = make_object(Blob, data='3')
+
+        tag1 = make_object(Tag, name='tag-tag',
+                           tagger='Test <test@example.com>',
+                           tag_time=12345,
+                           tag_timezone=0,
+                           message='message',
+                           object=(Blob, blob2.id))
+
+        objects = [blob1, blob2, blob3, tag1]
+        refs = {
+          'HEAD': '000',
+          'refs/heads/master': blob1.id,
+          'refs/tags/tag-tag': tag1.id,
+          'refs/tags/blob-tag': blob3.id,
+          }
+        backend = _test_backend(objects, refs=refs)
 
         mat = re.search('.*', '//info/refs')
-        self.assertEquals(['111\trefs/heads/master\n',
-                           '333\trefs/tags/blob-tag\n',
-                           'aaa\trefs/tags/tag-tag\n',
-                           '222\trefs/tags/tag-tag^{}\n'],
-                          list(get_info_refs(self._req, TestBackend(), mat)))
+        self.assertEquals(['%s\trefs/heads/master\n' % blob1.id,
+                           '%s\trefs/tags/blob-tag\n' % blob3.id,
+                           '%s\trefs/tags/tag-tag\n' % tag1.id,
+                           '%s\trefs/tags/tag-tag^{}\n' % blob2.id],
+                          list(get_info_refs(self._req, backend, mat)))
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('text/plain')
+        self.assertFalse(self._req.cached)
+
+    def test_get_info_packs(self):
+        class TestPack(object):
+            def __init__(self, sha):
+                self._sha = sha
+
+            def name(self):
+                return self._sha
+
+        packs = [TestPack(str(i) * 40) for i in xrange(1, 4)]
+
+        class TestObjectStore(MemoryObjectStore):
+            # property must be overridden, can't be assigned
+            @property
+            def packs(self):
+                return packs
+
+        store = TestObjectStore()
+        repo = BaseRepo(store, None)
+        backend = DictBackend({'/': repo})
+        mat = re.search('.*', '//info/packs')
+        output = ''.join(get_info_packs(self._req, backend, mat))
+        expected = 'P pack-%s.pack\n' * 3
+        expected %= ('1' * 40, '2' * 40, '3' * 40)
+        self.assertEquals(expected, output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('text/plain')
+        self.assertFalse(self._req.cached)
 
 
 class SmartHandlersTestCase(WebTestCase):
@@ -191,43 +302,55 @@ class SmartHandlersTestCase(WebTestCase):
         mat = re.search('.*', '/git-evil-handler')
         list(handle_service_request(self._req, 'backend', mat))
         self.assertEquals(HTTP_FORBIDDEN, self._status)
+        self.assertFalse(self._req.cached)
 
-    def test_handle_service_request(self):
+    def _run_handle_service_request(self, content_length=None):
         self._environ['wsgi.input'] = StringIO('foo')
+        if content_length is not None:
+            self._environ['CONTENT_LENGTH'] = content_length
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat))
-        self.assertEqual('handled input: foo', output)
-        response_type = 'application/x-git-upload-pack-response'
-        self.assertTrue(('Content-Type', response_type) in self._headers)
+        handler_output = ''.join(
+          handle_service_request(self._req, 'backend', mat))
+        write_output = self._output.getvalue()
+        # Ensure all output was written via the write callback.
+        self.assertEqual('', handler_output)
+        self.assertEqual('handled input: foo', write_output)
+        self.assertContentTypeEquals('application/x-git-upload-pack-response')
         self.assertFalse(self._handler.advertise_refs)
         self.assertTrue(self._handler.stateless_rpc)
+        self.assertFalse(self._req.cached)
+
+    def test_handle_service_request(self):
+        self._run_handle_service_request()
 
     def test_handle_service_request_with_length(self):
-        self._environ['wsgi.input'] = StringIO('foobar')
-        self._environ['CONTENT_LENGTH'] = 3
-        mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat))
-        self.assertEqual('handled input: foo', output)
-        response_type = 'application/x-git-upload-pack-response'
-        self.assertTrue(('Content-Type', response_type) in self._headers)
+        self._run_handle_service_request(content_length='3')
+
+    def test_handle_service_request_empty_length(self):
+        self._run_handle_service_request(content_length='')
 
     def test_get_info_refs_unknown(self):
         self._environ['QUERY_STRING'] = 'service=git-evil-handler'
         list(get_info_refs(self._req, 'backend', None))
         self.assertEquals(HTTP_FORBIDDEN, self._status)
+        self.assertFalse(self._req.cached)
 
     def test_get_info_refs(self):
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'
 
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(get_info_refs(self._req, 'backend', mat))
+        handler_output = ''.join(get_info_refs(self._req, 'backend', mat))
+        write_output = self._output.getvalue()
         self.assertEquals(('001e# service=git-upload-pack\n'
                            '0000'
                            # input is ignored by the handler
-                           'handled input: '), output)
+                           'handled input: '), write_output)
+        # Ensure all output was written via the write callback.
+        self.assertEquals('', handler_output)
         self.assertTrue(self._handler.advertise_refs)
         self.assertTrue(self._handler.stateless_rpc)
+        self.assertFalse(self._req.cached)
 
 
 class LengthLimitedFileTestCase(TestCase):
@@ -248,6 +371,10 @@ class LengthLimitedFileTestCase(TestCase):
 
 
 class HTTPGitRequestTestCase(WebTestCase):
+
+    # This class tests the contents of the actual cache headers
+    _req_class = HTTPGitRequest
+
     def test_not_found(self):
         self._req.cache_forever()  # cache headers should be discarded
         message = 'Something not found'

+ 24 - 3
dulwich/tests/utils.py

@@ -26,7 +26,10 @@ import shutil
 import tempfile
 import time
 
-from dulwich.objects import Commit
+from dulwich.objects import (
+    FixedSha,
+    Commit,
+    )
 from dulwich.repo import Repo
 
 
@@ -57,12 +60,30 @@ def tear_down_repo(repo):
 def make_object(cls, **attrs):
     """Make an object for testing and assign some members.
 
+    This method creates a new subclass to allow arbitrary attribute
+    reassignment, which is not otherwise possible with objects having __slots__.
+
     :param attrs: dict of attributes to set on the new object.
     :return: A newly initialized object of type cls.
     """
-    obj = cls()
+
+    class TestObject(cls):
+        """Class that inherits from the given class, but without __slots__.
+
+        Note that classes with __slots__ can't have arbitrary attributes monkey-
+        patched in, so this is a class that is exactly the same only with a
+        __dict__ instead of __slots__.
+        """
+        pass
+
+    obj = TestObject()
     for name, value in attrs.iteritems():
-        setattr(obj, name, value)
+        if name == 'id':
+            # id property is read-only, so we overwrite sha instead.
+            sha = FixedSha(value)
+            obj.sha = lambda: sha
+        else:
+            setattr(obj, name, value)
     return obj
 
 

+ 29 - 16
dulwich/web.py

@@ -48,10 +48,16 @@ logger = log_utils.getLogger(__name__)
 HTTP_OK = '200 OK'
 HTTP_NOT_FOUND = '404 Not Found'
 HTTP_FORBIDDEN = '403 Forbidden'
+HTTP_ERROR = '500 Internal Server Error'
 
 
 def date_time_string(timestamp=None):
-    # Based on BaseHTTPServer.py in python2.5
+    # From BaseHTTPRequestHandler.date_time_string in BaseHTTPServer.py in the
+    # Python 2.6.5 standard library, following modifications:
+    #  - Made a global rather than an instance method.
+    #  - weekdayname and monthname are renamed and locals rather than class
+    #    variables.
+    # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     months = [None,
               'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
@@ -100,7 +106,7 @@ def send_file(req, f, content_type):
         f.close()
     except IOError:
         f.close()
-        yield req.not_found('Error reading file')
+        yield req.error('Error reading file')
     except:
         f.close()
         raise
@@ -128,7 +134,8 @@ def get_loose_object(req, backend, mat):
     try:
         data = object_store[sha].as_legacy_object()
     except IOError:
-        yield req.not_found('Error reading object')
+        yield req.error('Error reading object')
+        return
     req.cache_forever()
     req.respond(HTTP_OK, 'application/x-git-loose-object')
     yield data
@@ -159,15 +166,13 @@ def get_info_refs(req, backend, mat):
             yield req.forbidden('Unsupported service %s' % service)
             return
         req.nocache()
-        req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
-        output = StringIO()
-        proto = ReceivableProtocol(StringIO().read, output.write)
+        write = req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
+        proto = ReceivableProtocol(StringIO().read, write)
         handler = handler_cls(backend, [url_prefix(mat)], proto,
                               stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
         handler.handle()
-        yield output.getvalue()
     else:
         # non-smart fallback
         # TODO: select_getanyfile() (see http-backend.c)
@@ -230,20 +235,19 @@ def handle_service_request(req, backend, mat):
         yield req.forbidden('Unsupported service %s' % service)
         return
     req.nocache()
-    req.respond(HTTP_OK, 'application/x-%s-response' % service)
+    write = req.respond(HTTP_OK, 'application/x-%s-response' % service)
 
-    output = StringIO()
     input = req.environ['wsgi.input']
     # This is not necessary if this app is run from a conforming WSGI server.
     # Unfortunately, there's no way to tell that at this point.
     # TODO: git may used HTTP/1.1 chunked encoding instead of specifying
     # content-length
-    if 'CONTENT_LENGTH' in req.environ:
-        input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    proto = ReceivableProtocol(input.read, output.write)
+    content_length = req.environ.get('CONTENT_LENGTH', '')
+    if content_length:
+        input = _LengthLimitedFile(input, int(content_length))
+    proto = ReceivableProtocol(input.read, write)
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler.handle()
-    yield output.getvalue()
 
 
 class HTTPGitRequest(object):
@@ -255,7 +259,7 @@ class HTTPGitRequest(object):
     def __init__(self, environ, start_response, dumb=False, handlers=None):
         self.environ = environ
         self.dumb = dumb
-        self.handlers = handlers and handlers or DEFAULT_HANDLERS
+        self.handlers = handlers
         self._start_response = start_response
         self._cache_headers = []
         self._headers = []
@@ -272,7 +276,7 @@ class HTTPGitRequest(object):
             self._headers.append(('Content-Type', content_type))
         self._headers.extend(self._cache_headers)
 
-        self._start_response(status, self._headers)
+        return self._start_response(status, self._headers)
 
     def not_found(self, message):
         """Begin a HTTP 404 response and return the text of a message."""
@@ -288,6 +292,13 @@ class HTTPGitRequest(object):
         self.respond(HTTP_FORBIDDEN, 'text/plain')
         return message
 
+    def error(self, message):
+        """Begin a HTTP 500 response and return the text of a message."""
+        self._cache_headers = []
+        logger.error('Error: %s', message)
+        self.respond(HTTP_ERROR, 'text/plain')
+        return message
+
     def nocache(self):
         """Set the response to never be cached by the client."""
         self._cache_headers = [
@@ -329,7 +340,9 @@ class HTTPGitApplication(object):
     def __init__(self, backend, dumb=False, handlers=None):
         self.backend = backend
         self.dumb = dumb
-        self.handlers = handlers
+        self.handlers = dict(DEFAULT_HANDLERS)
+        if handlers is not None:
+            self.handlers.update(handlers)
 
     def __call__(self, environ, start_response):
         path = environ['PATH_INFO']

+ 1 - 1
setup.py

@@ -8,7 +8,7 @@ except ImportError:
     from distutils.core import setup, Extension
 from distutils.core import Distribution
 
-dulwich_version_string = '0.6.1'
+dulwich_version_string = '0.6.2'
 
 include_dirs = []
 # Windows MSVC support