瀏覽代碼

Import upstream version 0.6.1+bzr685

Jelmer Vernooij 14 年之前
父節點
當前提交
6643b18378

+ 0 - 7
.bzrignore

@@ -1,7 +0,0 @@
-_trial_temp
-build
-MANIFEST
-dist
-apidocs
-*,cover
-.testrepository

+ 94 - 0
NEWS

@@ -1,3 +1,85 @@
+0.6.2	UNRELEASED
+
+ BUG FIXES
+
+  * HTTP server correctly handles empty CONTENT_LENGTH. (Dave Borowitz)
+
+  * Don't error when creating GitFiles with the default mode. (Dave Borowitz)
+
+  * ThinPackData.from_file now works with resolve_ext_ref callback.
+    (Dave Borowitz)
+
+  * Provide strnlen() on mingw32 which doesn't have it. (Hans Kolek)
+
+ FEATURES
+
+  * Use __slots__ for core objects to save memory. (Jelmer Vernooij)
+
+  * Web server supports streaming progress/pack output. (Dave Borowitz)
+
+  * New public function dulwich.pack.write_pack_header. (Dave Borowitz)
+
+  * Distinguish between missing files and read errors in HTTP server.
+    (Dave Borowitz)
+
+  * Initial work on support for fastimport using python-fastimport.
+    (Jelmer Vernooij)
+
+  * New dulwich.pack.MemoryPackIndex class. (Jelmer Vernooij)
+
+  * Delegate SHA peeling to the object store.  (Dave Borowitz)
+
+ TESTS
+
+  * Use GitFile when modifying packed-refs in tests. (Dave Borowitz)
+
+  * New tests in test_web with better coverage and fewer ad-hoc mocks.
+    (Dave Borowitz)
+
+  * Standardize quote delimiters in test_protocol. (Dave Borowitz)
+
+  * Fix use when testtools is installed. (Jelmer Vernooij)
+
+  * Add trivial test for write_pack_header. (Jelmer Vernooij)
+
+  * Refactor some of dulwich.tests.compat.server_utils. (Dave Borowitz)
+
+  * Allow overwriting id property of objects in test utils. (Dave Borowitz)
+
+  * Use real in-memory objects rather than stubs for server tests.
+    (Dave Borowitz)
+
+  * Clean up MissingObjectFinder. (Dave Borowitz)
+
+ API CHANGES
+
+  * ObjectStore.iter_tree_contents now walks contents in depth-first, sorted
+    order. (Dave Borowitz)
+
+  * ObjectStore.iter_tree_contents can optionally yield tree objects as well.
+    (Dave Borowitz)
+
+  * Add side-band-64k support to ReceivePackHandler. (Dave Borowitz)
+
+  * Change server capabilities methods to classmethods. (Dave Borowitz)
+
+  * Tweak server handler injection. (Dave Borowitz)
+
+  * PackIndex1 and PackIndex2 now subclass FilePackIndex, which is
+    itself a subclass of PackIndex. (Jelmer Vernooij)
+
+ DOCUMENTATION
+
+  * Add docstrings for various functions in dulwich.objects. (Jelmer Vernooij)
+
+  * Clean up docstrings in dulwich.protocol. (Dave Borowitz)
+
+  * Explicitly specify allowed protocol commands to
+    ProtocolGraphWalker.read_proto_line.  (Dave Borowitz)
+
+  * Add utility functions to DictRefsContainer. (Dave Borowitz)
+
+
 0.6.1	2010-07-22
 0.6.1	2010-07-22
 
 
  BUG FIXES
  BUG FIXES
@@ -31,10 +113,18 @@
 
 
   * Quiet logging output from web tests. (Dave Borowitz)
   * Quiet logging output from web tests. (Dave Borowitz)
 
 
+  * More flexible version checking for compat tests. (Dave Borowitz)
+
+  * Compat tests for servers with and without side-band-64k. (Dave Borowitz)
+
  CLEANUP
  CLEANUP
 
 
   * Clean up file headers. (Dave Borowitz)
   * Clean up file headers. (Dave Borowitz)
 
 
+ TESTS
+
+  * Use GitFile when modifying packed-refs in tests. (Dave Borowitz)
+
  API CHANGES
  API CHANGES
 
 
   * dulwich.pack.write_pack_index_v{1,2} now take a file-like object
   * dulwich.pack.write_pack_index_v{1,2} now take a file-like object
@@ -45,6 +135,10 @@
 
 
   * Move reference WSGI handler to web.py. (Dave Borowitz)
   * Move reference WSGI handler to web.py. (Dave Borowitz)
 
 
+  * Factor out _report_status in ReceivePackHandler. (Dave Borowitz)
+
+  * Factor out a function to convert a line to a pkt-line. (Dave Borowitz)
+
 
 
 0.6.0	2010-05-22
 0.6.0	2010-05-22
 
 

+ 1 - 1
dulwich/__init__.py

@@ -27,4 +27,4 @@ import protocol
 import repo
 import repo
 import server
 import server
 
 
-__version__ = (0, 6, 1)
+__version__ = (0, 6, 2)

+ 8 - 0
dulwich/_objects.c

@@ -25,6 +25,14 @@
 typedef int Py_ssize_t;
 typedef int Py_ssize_t;
 #endif
 #endif
 
 
+#if defined(__MINGW32_VERSION) || defined(__APPLE__)
+size_t strnlen(char *text, size_t maxlen)
+{
+	const char *last = memchr(text, '\0', maxlen);
+	return last ? (size_t) (last - text) : maxlen;
+}
+#endif
+
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
 
 
 static PyObject *sha_to_pyhex(const unsigned char *sha)
 static PyObject *sha_to_pyhex(const unsigned char *sha)

+ 11 - 0
dulwich/errors.py

@@ -137,6 +137,17 @@ class HangupException(GitProtocolError):
             "The remote server unexpectedly closed the connection.")
             "The remote server unexpectedly closed the connection.")
 
 
 
 
+class UnexpectedCommandError(GitProtocolError):
+    """Unexpected command received in a proto line."""
+
+    def __init__(self, command):
+        if command is None:
+            command = 'flush-pkt'
+        else:
+            command = 'command %s' % command
+        GitProtocolError.__init__(self, 'Protocol got unexpected %s' % command)
+
+
 class FileFormatException(Exception):
 class FileFormatException(Exception):
     """Base class for exceptions relating to reading git file formats."""
     """Base class for exceptions relating to reading git file formats."""
 
 

+ 226 - 39
dulwich/fastexport.py

@@ -20,13 +20,30 @@
 
 
 """Fast export/import functionality."""
 """Fast export/import functionality."""
 
 
+from dulwich.index import (
+    commit_tree,
+    )
 from dulwich.objects import (
 from dulwich.objects import (
-    format_timezone,
+    Blob,
+    Commit,
+    Tag,
+    parse_timezone,
+    )
+from fastimport import (
+    commands,
+    errors as fastimport_errors,
+    processor,
     )
     )
 
 
 import stat
 import stat
 
 
-class FastExporter(object):
+
+def split_email(text):
+    (name, email) = text.rsplit(" <", 1)
+    return (name, email.rstrip(">"))
+
+
+class GitFastExporter(object):
     """Generate a fast-export output stream for Git objects."""
     """Generate a fast-export output stream for Git objects."""
 
 
     def __init__(self, outf, store):
     def __init__(self, outf, store):
@@ -35,47 +52,217 @@ class FastExporter(object):
         self.markers = {}
         self.markers = {}
         self._marker_idx = 0
         self._marker_idx = 0
 
 
+    def print_cmd(self, cmd):
+        self.outf.write("%r\n" % cmd)
+
     def _allocate_marker(self):
     def _allocate_marker(self):
         self._marker_idx+=1
         self._marker_idx+=1
-        return self._marker_idx
-
-    def _dump_blob(self, blob, marker):
-        self.outf.write("blob\nmark :%s\n" % marker)
-        self.outf.write("data %s\n" % blob.raw_length())
-        for chunk in blob.as_raw_chunks():
-            self.outf.write(chunk)
-        self.outf.write("\n")
-
-    def export_blob(self, blob):
-        i = self._allocate_marker()
-        self.markers[i] = blob.id
-        self._dump_blob(blob, i)
-        return i
-
-    def _dump_commit(self, commit, marker, ref, file_changes):
-        self.outf.write("commit %s\n" % ref)
-        self.outf.write("mark :%s\n" % marker)
-        self.outf.write("author %s %s %s\n" % (commit.author,
-            commit.author_time, format_timezone(commit.author_timezone)))
-        self.outf.write("committer %s %s %s\n" % (commit.committer,
-            commit.commit_time, format_timezone(commit.commit_timezone)))
-        self.outf.write("data %s\n" % len(commit.message))
-        self.outf.write(commit.message)
-        self.outf.write("\n")
-        self.outf.write('\n'.join(file_changes))
-        self.outf.write("\n\n")
-
-    def export_commit(self, commit, ref, base_tree=None):
-        file_changes = []
+        return str(self._marker_idx)
+
+    def _export_blob(self, blob):
+        marker = self._allocate_marker()
+        self.markers[marker] = blob.id
+        return (commands.BlobCommand(marker, blob.data), marker)
+
+    def emit_blob(self, blob):
+        (cmd, marker) = self._export_blob(blob)
+        self.print_cmd(cmd)
+        return marker
+
+    def _iter_files(self, base_tree, new_tree):
         for (old_path, new_path), (old_mode, new_mode), (old_hexsha, new_hexsha) in \
         for (old_path, new_path), (old_mode, new_mode), (old_hexsha, new_hexsha) in \
-                self.store.tree_changes(base_tree, commit.tree):
+                self.store.tree_changes(base_tree, new_tree):
             if new_path is None:
             if new_path is None:
-                file_changes.append("D %s" % old_path)
+                yield commands.FileDeleteCommand(old_path)
                 continue
                 continue
             if not stat.S_ISDIR(new_mode):
             if not stat.S_ISDIR(new_mode):
-                marker = self.export_blob(self.store[new_hexsha])
-            file_changes.append("M %o :%s %s" % (new_mode, marker, new_path))
+                blob = self.store[new_hexsha]
+                marker = self.emit_blob(blob)
+            if old_path != new_path and old_path is not None:
+                yield commands.FileRenameCommand(old_path, new_path)
+            if old_mode != new_mode or old_hexsha != new_hexsha:
+                yield commands.FileModifyCommand(new_path, new_mode, marker, None)
+
+    def _export_commit(self, commit, ref, base_tree=None):
+        file_cmds = list(self._iter_files(base_tree, commit.tree))
+        marker = self._allocate_marker()
+        if commit.parents:
+            from_ = commit.parents[0]
+            merges = commit.parents[1:]
+        else:
+            from_ = None
+            merges = []
+        author, author_email = split_email(commit.author)
+        committer, committer_email = split_email(commit.committer)
+        cmd = commands.CommitCommand(ref, marker,
+            (author, author_email, commit.author_time, commit.author_timezone),
+            (committer, committer_email, commit.commit_time, commit.commit_timezone),
+            commit.message, from_, merges, file_cmds)
+        return (cmd, marker)
+
+    def emit_commit(self, commit, ref, base_tree=None):
+        cmd, marker = self._export_commit(commit, ref, base_tree)
+        self.print_cmd(cmd)
+        return marker
+
+
+class FastImporter(object):
+    """Class for importing fastimport streams.
+
+    Please note that this is mostly a stub implementation at the moment,
+    doing the bare minimum.
+    """
+
+    def __init__(self, repo):
+        self.repo = repo
+
+    def _parse_person(self, line):
+        (name, timestr, timezonestr) = line.rsplit(" ", 2)
+        return name, int(timestr), parse_timezone(timezonestr)[0]
+
+    def _read_blob(self, stream):
+        line = stream.readline()
+        if line.startswith("mark :"):
+            mark = line[len("mark :"):-1]
+            line = stream.readline()
+        else:
+            mark = None
+        if not line.startswith("data "):
+            raise ValueError("Blob without valid data line: %s" % line)
+        size = int(line[len("data "):])
+        o = Blob()
+        o.data = stream.read(size)
+        stream.readline()
+        self.repo.object_store.add_object(o)
+        return mark, o.id
+
+    def _read_commit(self, stream, contents, marks):
+        line = stream.readline()
+        if line.startswith("mark :"):
+            mark = line[len("mark :"):-1]
+            line = stream.readline()
+        else:
+            mark = None
+        o = Commit()
+        o.author = None
+        o.author_time = None
+        while line.startswith("author "):
+            (o.author, o.author_time, o.author_timezone) = \
+                    self._parse_person(line[len("author "):-1])
+            line = stream.readline()
+        while line.startswith("committer "):
+            (o.committer, o.commit_time, o.commit_timezone) = \
+                    self._parse_person(line[len("committer "):-1])
+            line = stream.readline()
+        if o.author is None:
+            o.author = o.committer
+        if o.author_time is None:
+            o.author_time = o.commit_time
+            o.author_timezone = o.commit_timezone
+        if not line.startswith("data "):
+            raise ValueError("Blob without valid data line: %s" % line)
+        size = int(line[len("data "):])
+        o.message = stream.read(size)
+        stream.readline()
+        line = stream.readline()[:-1]
+        while line:
+            if line.startswith("M "):
+                (kind, modestr, val, path) = line.split(" ")
+                if val[0] == ":":
+                    val = marks[val[1:]]
+                contents[path] = (int(modestr, 8), val)
+            else:
+                raise ValueError(line)
+            line = stream.readline()[:-1]
+        try:
+            o.parents = (self.repo.head(),)
+        except KeyError:
+            o.parents = ()
+        o.tree = commit_tree(self.repo.object_store,
+            ((path, hexsha, mode) for (path, (mode, hexsha)) in
+                contents.iteritems()))
+        self.repo.object_store.add_object(o)
+        return mark, o.id
+
+    def import_stream(self, stream):
+        """Import from a file-like object.
+
+        :param stream: File-like object to read a fastimport stream from.
+        :return: Dictionary with marks
+        """
+        contents = {}
+        marks = {}
+        while True:
+            line = stream.readline()
+            if not line:
+                break
+            line = line[:-1]
+            if line == "" or line[0] == "#":
+                continue
+            if line.startswith("blob"):
+                mark, hexsha = self._read_blob(stream)
+                if mark is not None:
+                    marks[mark] = hexsha
+            elif line.startswith("commit "):
+                ref = line[len("commit "):-1]
+                mark, hexsha = self._read_commit(stream, contents, marks)
+                if mark is not None:
+                    marks[mark] = hexsha
+                self.repo.refs["HEAD"] = self.repo.refs[ref] = hexsha
+            else:
+                raise ValueError("invalid command '%s'" % line)
+        return marks
+
+
+class GitImportProcessor(processor.ImportProcessor):
+    """An import processor that imports into a Git repository using Dulwich.
+
+    """
+
+    def __init__(self, repo, params=None, verbose=False, outf=None):
+        processor.ImportProcessor.__init__(self, params, verbose)
+        self.repo = repo
+        self.last_commit = None
+
+    def blob_handler(self, cmd):
+        """Process a BlobCommand."""
+        self.repo.object_store.add_object(Blob.from_string(cmd.data))
+
+    def checkpoint_handler(self, cmd):
+        """Process a CheckpointCommand."""
+        pass
+
+    def commit_handler(self, cmd):
+        """Process a CommitCommand."""
+        commit = Commit()
+        commit.author = cmd.author
+        commit.committer = cmd.committer
+        commit.message = cmd.message
+        commit.parents = []
+        if self.last_commit is not None:
+            commit.parents.append(self.last_commit)
+        commit.parents += cmd.merges
+        self.repo[cmd.ref] = commit.id
+        self.last_commit = commit.id
+
+    def progress_handler(self, cmd):
+        """Process a ProgressCommand."""
+        pass
+
+    def reset_handler(self, cmd):
+        """Process a ResetCommand."""
+        self.last_commit = cmd.from_
+        self.rep.refs[cmd.from_] = cmd.id
+
+    def tag_handler(self, cmd):
+        """Process a TagCommand."""
+        tag = Tag()
+        tag.tagger = cmd.tagger
+        tag.message = cmd.message
+        tag.name = cmd.tag
+        self.repo.add_object(tag)
+        self.repo.refs["refs/tags/" + tag.name] = tag.id
 
 
-        i = self._allocate_marker()
-        self._dump_commit(commit, i, ref, file_changes)
-        return i
+    def feature_handler(self, cmd):
+        """Process a FeatureCommand."""
+        raise fastimport_errors.UnknownFeature(cmd.feature_name)

+ 1 - 1
dulwich/file.py

@@ -60,7 +60,7 @@ def fancy_rename(oldname, newname):
     os.remove(tmpfile)
     os.remove(tmpfile)
 
 
 
 
-def GitFile(filename, mode='r', bufsize=-1):
+def GitFile(filename, mode='rb', bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
     """Create a file object that obeys the git file locking protocol.
 
 
     :return: a builtin file object or a _GitFile object
     :return: a builtin file object or a _GitFile object

+ 43 - 17
dulwich/object_store.py

@@ -41,6 +41,7 @@ from dulwich.objects import (
     sha_to_hex,
     sha_to_hex,
     hex_to_filename,
     hex_to_filename,
     S_ISGITLINK,
     S_ISGITLINK,
+    object_class,
     )
     )
 from dulwich.pack import (
 from dulwich.pack import (
     Pack,
     Pack,
@@ -175,21 +176,26 @@ class BaseObjectStore(object):
                     else:
                     else:
                         todo.add((None, newhexsha, childpath))
                         todo.add((None, newhexsha, childpath))
 
 
-    def iter_tree_contents(self, tree):
-        """Yield (path, mode, hexsha) tuples for all non-Tree objects in a tree.
+    def iter_tree_contents(self, tree_id, include_trees=False):
+        """Iterate the contents of a tree and all subtrees.
 
 
-        :param tree: SHA1 of the root of the tree
+        Iteration is depth-first pre-order, as in e.g. os.walk.
+
+        :param tree_id: SHA1 of the tree.
+        :param include_trees: If True, include tree objects in the iteration.
+        :yield: Tuples of (path, mode, hexsha) for objects in a tree.
         """
         """
-        todo = set([(tree, "")])
+        todo = [('', stat.S_IFDIR, tree_id)]
         while todo:
         while todo:
-            (tid, tpath) = todo.pop()
-            tree = self[tid]
-            for name, mode, hexsha in tree.iteritems():
-                path = posixpath.join(tpath, name)
-                if stat.S_ISDIR(mode):
-                    todo.add((hexsha, path))
-                else:
-                    yield path, mode, hexsha
+            path, mode, hexsha = todo.pop()
+            is_subtree = stat.S_ISDIR(mode)
+            if not is_subtree or include_trees:
+                yield path, mode, hexsha
+            if is_subtree:
+                entries = reversed(list(self[hexsha].iteritems()))
+                for name, entry_mode, entry_hexsha in entries:
+                    entry_path = posixpath.join(path, name)
+                    todo.append((entry_path, entry_mode, entry_hexsha))
 
 
     def find_missing_objects(self, haves, wants, progress=None,
     def find_missing_objects(self, haves, wants, progress=None,
                              get_tagged=None):
                              get_tagged=None):
@@ -238,6 +244,21 @@ class BaseObjectStore(object):
         """
         """
         return self.iter_shas(self.find_missing_objects(have, want, progress))
         return self.iter_shas(self.find_missing_objects(have, want, progress))
 
 
+    def peel_sha(self, sha):
+        """Peel all tags from a SHA.
+
+        :param sha: The object SHA to peel.
+        :return: The fully-peeled object that the SHA refers to, after
+            peeling all intermediate tags; if the SHA does not point to a
+            tag, this is the object the SHA itself names.
+        """
+        obj = self[sha]
+        obj_class = object_class(obj.type_name)
+        while obj_class is Tag:
+            obj_class, sha = obj.object
+            obj = self[sha]
+        return obj
+
 
 
 class PackBasedObjectStore(BaseObjectStore):
 class PackBasedObjectStore(BaseObjectStore):
 
 
@@ -588,7 +609,7 @@ class ObjectImporter(object):
         raise NotImplementedError(self.add_object)
         raise NotImplementedError(self.add_object)
 
 
     def finish(self, object):
     def finish(self, object):
-        """Finish the imoprt and write objects to disk."""
+        """Finish the import and write objects to disk."""
         raise NotImplementedError(self.finish)
         raise NotImplementedError(self.finish)
 
 
 
 
@@ -690,8 +711,10 @@ class MissingObjectFinder(object):
 
 
     def __init__(self, object_store, haves, wants, progress=None,
     def __init__(self, object_store, haves, wants, progress=None,
                  get_tagged=None):
                  get_tagged=None):
-        self.sha_done = set(haves)
-        self.objects_to_send = set([(w, None, False) for w in wants if w not in haves])
+        haves = set(haves)
+        self.sha_done = haves
+        self.objects_to_send = set([(w, None, False) for w in wants
+                                    if w not in haves])
         self.object_store = object_store
         self.object_store = object_store
         if progress is None:
         if progress is None:
             self.progress = lambda x: None
             self.progress = lambda x: None
@@ -700,10 +723,13 @@ class MissingObjectFinder(object):
         self._tagged = get_tagged and get_tagged() or {}
         self._tagged = get_tagged and get_tagged() or {}
 
 
     def add_todo(self, entries):
     def add_todo(self, entries):
-        self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
+        self.objects_to_send.update([e for e in entries
+                                     if not e[0] in self.sha_done])
 
 
     def parse_tree(self, tree):
     def parse_tree(self, tree):
-        self.add_todo([(sha, name, not stat.S_ISDIR(mode)) for (mode, name, sha) in tree.entries() if not S_ISGITLINK(mode)])
+        self.add_todo([(sha, name, not stat.S_ISDIR(mode))
+                       for mode, name, sha in tree.entries()
+                       if not S_ISGITLINK(mode)])
 
 
     def parse_commit(self, commit):
     def parse_commit(self, commit):
         self.add_todo([(commit.tree, "", False)])
         self.add_todo([(commit.tree, "", False)])

+ 32 - 0
dulwich/objects.py

@@ -157,6 +157,8 @@ def check_identity(identity, error_msg):
 class FixedSha(object):
 class FixedSha(object):
     """SHA object that behaves like hashlib's but is given a fixed value."""
     """SHA object that behaves like hashlib's but is given a fixed value."""
 
 
+    __slots__ = ('_hexsha', '_sha')
+
     def __init__(self, hexsha):
     def __init__(self, hexsha):
         self._hexsha = hexsha
         self._hexsha = hexsha
         self._sha = hex_to_sha(hexsha)
         self._sha = hex_to_sha(hexsha)
@@ -171,6 +173,9 @@ class FixedSha(object):
 class ShaFile(object):
 class ShaFile(object):
     """A git SHA file."""
     """A git SHA file."""
 
 
+    __slots__ = ('_needs_parsing', '_chunked_text', '_file', '_path', 
+                 '_sha', '_needs_serialization', '_magic')
+
     @staticmethod
     @staticmethod
     def _parse_legacy_object_header(magic, f):
     def _parse_legacy_object_header(magic, f):
         """Parse a legacy object, creating it but not reading the file."""
         """Parse a legacy object, creating it but not reading the file."""
@@ -474,6 +479,8 @@ class ShaFile(object):
 class Blob(ShaFile):
 class Blob(ShaFile):
     """A Git Blob object."""
     """A Git Blob object."""
 
 
+    __slots__ = ()
+
     type_name = 'blob'
     type_name = 'blob'
     type_num = 3
     type_num = 3
 
 
@@ -555,6 +562,10 @@ class Tag(ShaFile):
     type_name = 'tag'
     type_name = 'tag'
     type_num = 4
     type_num = 4
 
 
+    __slots__ = ('_tag_timezone_neg_utc', '_name', '_object_sha', 
+                 '_object_class', '_tag_time', '_tag_timezone',
+                 '_tagger', '_message')
+
     def __init__(self):
     def __init__(self):
         super(Tag, self).__init__()
         super(Tag, self).__init__()
         self._tag_timezone_neg_utc = False
         self._tag_timezone_neg_utc = False
@@ -740,6 +751,8 @@ class Tree(ShaFile):
     type_name = 'tree'
     type_name = 'tree'
     type_num = 2
     type_num = 2
 
 
+    __slots__ = ('_entries')
+
     def __init__(self):
     def __init__(self):
         super(Tree, self).__init__()
         super(Tree, self).__init__()
         self._entries = {}
         self._entries = {}
@@ -865,6 +878,13 @@ class Tree(ShaFile):
 
 
 
 
 def parse_timezone(text):
 def parse_timezone(text):
+    """Parse a timezone text fragment (e.g. '+0100').
+
+    :param text: Text to parse.
+    :return: Tuple with timezone as seconds difference to UTC 
+        and a boolean indicating whether this was a UTC timezone
+        prefixed with a negative sign (-0000).
+    """
     offset = int(text)
     offset = int(text)
     negative_utc = (offset == 0 and text[0] == '-')
     negative_utc = (offset == 0 and text[0] == '-')
     signum = (offset < 0) and -1 or 1
     signum = (offset < 0) and -1 or 1
@@ -875,6 +895,12 @@ def parse_timezone(text):
 
 
 
 
 def format_timezone(offset, negative_utc=False):
 def format_timezone(offset, negative_utc=False):
+    """Format a timezone for Git serialization.
+
+    :param offset: Timezone offset as seconds difference to UTC
+    :param negative_utc: Whether to use a minus sign for UTC
+        (-0000 rather than +0000).
+    """
     if offset % 60 != 0:
     if offset % 60 != 0:
         raise ValueError("Unable to handle non-minute offset.")
         raise ValueError("Unable to handle non-minute offset.")
     if offset < 0 or (offset == 0 and negative_utc):
     if offset < 0 or (offset == 0 and negative_utc):
@@ -895,6 +921,12 @@ class Commit(ShaFile):
     type_name = 'commit'
     type_name = 'commit'
     type_num = 1
     type_num = 1
 
 
+    __slots__ = ('_parents', '_encoding', '_extra', '_author_timezone_neg_utc',
+                 '_commit_timezone_neg_utc', '_commit_time',
+                 '_author_time', '_author_timezone', '_commit_timezone',
+                 '_author', '_committer', '_parents', '_extra',
+                 '_encoding', '_tree', '_message')
+
     def __init__(self):
     def __init__(self):
         super(Commit, self).__init__()
         super(Commit, self).__init__()
         self._parents = []
         self._parents = []

+ 133 - 46
dulwich/pack.py

@@ -222,6 +222,111 @@ class PackIndex(object):
 
 
     Given a sha id of an object a pack index can tell you the location in the
     Given a sha id of an object a pack index can tell you the location in the
     packfile of that object if it has it.
     packfile of that object if it has it.
+    """
+
+    def __eq__(self, other):
+        if not isinstance(other, PackIndex):
+            return False
+
+        for (name1, _, _), (name2, _, _) in izip(self.iterentries(),
+                                                 other.iterentries()):
+            if name1 != name2:
+                return False
+        return True
+
+    def __ne__(self, other):
+        return not self.__eq__(other)
+
+    def __len__(self):
+        """Return the number of entries in this pack index."""
+        raise NotImplementedError(self.__len__)
+
+    def __iter__(self):
+        """Iterate over the SHAs in this pack."""
+        raise NotImplementedError(self.__iter__)
+
+    def iterentries(self):
+        """Iterate over the entries in this pack index.
+
+        :return: iterator over tuples with object name, offset in packfile and
+            crc32 checksum.
+        """
+        raise NotImplementedError(self.iterentries)
+
+    def get_pack_checksum(self):
+        """Return the SHA1 checksum stored for the corresponding packfile.
+
+        :return: 20-byte binary digest
+        """
+        raise NotImplementedError(self.get_pack_checksum)
+
+    def object_index(self, sha):
+        """Return the index in to the corresponding packfile for the object.
+
+        Given the name of an object it will return the offset that object
+        lives at within the corresponding pack file. If the pack file doesn't
+        have the object then None will be returned.
+        """
+        if len(sha) == 40:
+            sha = hex_to_sha(sha)
+        return self._object_index(sha)
+
+    def _object_index(self, sha):
+        """See object_index.
+
+        :param sha: A *binary* SHA string. (20 characters long)
+        """
+        raise NotImplementedError(self._object_index)
+
+    def __iter__(self):
+        """Iterate over the SHAs in this pack."""
+        return imap(sha_to_hex, self._itersha())
+
+    def objects_sha1(self):
+        """Return the hex SHA1 over all the shas of all objects in this pack.
+
+        :note: This is used for the filename of the pack.
+        """
+        return iter_sha1(self._itersha())
+
+    def _itersha(self):
+        """Yield all the SHA1's of the objects in the index, sorted."""
+        raise NotImplementedError(self._itersha)
+
+
+class MemoryPackIndex(PackIndex):
+    """Pack index that is stored entirely in memory."""
+
+    def __init__(self, entries, pack_checksum=None):
+        """Create a new MemoryPackIndex.
+
+        :param entries: Sequence of name, idx, crc32 (sorted)
+        :param pack_checksum: Optional pack checksum
+        """
+        self._by_sha = {}
+        for name, idx, crc32 in entries:
+            self._by_sha[name] = idx
+        self._entries = entries
+        self._pack_checksum = pack_checksum
+
+    def get_pack_checksum(self):
+        return self._pack_checksum
+
+    def __len__(self):
+        return len(self._entries)
+
+    def _object_index(self, sha):
+        return self._by_sha[sha][0]
+
+    def _itersha(self):
+        return iter(self._by_sha)
+
+    def iterentries(self):
+        return iter(self._entries)
+
+
+class FilePackIndex(PackIndex):
+    """Pack index that is based on a file.
 
 
     To do the loop it opens the file, and indexes first 256 4 byte groups
     To do the loop it opens the file, and indexes first 256 4 byte groups
     with the first byte of the sha id. The value in the four byte group indexed
     with the first byte of the sha id. The value in the four byte group indexed
@@ -250,20 +355,12 @@ class PackIndex(object):
             self._contents, self._size = (contents, size)
             self._contents, self._size = (contents, size)
 
 
     def __eq__(self, other):
     def __eq__(self, other):
-        if not isinstance(other, PackIndex):
-            return False
-
-        if self._fan_out_table != other._fan_out_table:
+        # Quick optimization:
+        if (isinstance(other, FilePackIndex) and
+            self._fan_out_table != other._fan_out_table):
             return False
             return False
 
 
-        for (name1, _, _), (name2, _, _) in izip(self.iterentries(),
-                                                 other.iterentries()):
-            if name1 != name2:
-                return False
-        return True
-
-    def __ne__(self, other):
-        return not self.__eq__(other)
+        return super(FilePackIndex, self).__eq__(other)
 
 
     def close(self):
     def close(self):
         self._file.close()
         self._file.close()
@@ -292,21 +389,10 @@ class PackIndex(object):
         """Unpack the crc32 checksum for the i-th object from the index file."""
         """Unpack the crc32 checksum for the i-th object from the index file."""
         raise NotImplementedError(self._unpack_crc32_checksum)
         raise NotImplementedError(self._unpack_crc32_checksum)
 
 
-    def __iter__(self):
-        """Iterate over the SHAs in this pack."""
-        return imap(sha_to_hex, self._itersha())
-
     def _itersha(self):
     def _itersha(self):
         for i in range(len(self)):
         for i in range(len(self)):
             yield self._unpack_name(i)
             yield self._unpack_name(i)
 
 
-    def objects_sha1(self):
-        """Return the hex SHA1 over all the shas of all objects in this pack.
-
-        :note: This is used for the filename of the pack.
-        """
-        return iter_sha1(self._itersha())
-
     def iterentries(self):
     def iterentries(self):
         """Iterate over the entries in this pack index.
         """Iterate over the entries in this pack index.
 
 
@@ -351,17 +437,6 @@ class PackIndex(object):
         """
         """
         return str(self._contents[-20:])
         return str(self._contents[-20:])
 
 
-    def object_index(self, sha):
-        """Return the index in to the corresponding packfile for the object.
-
-        Given the name of an object it will return the offset that object
-        lives at within the corresponding pack file. If the pack file doesn't
-        have the object then None will be returned.
-        """
-        if len(sha) == 40:
-            sha = hex_to_sha(sha)
-        return self._object_index(sha)
-
     def _object_index(self, sha):
     def _object_index(self, sha):
         """See object_index.
         """See object_index.
 
 
@@ -380,11 +455,11 @@ class PackIndex(object):
         return self._unpack_offset(i)
         return self._unpack_offset(i)
 
 
 
 
-class PackIndex1(PackIndex):
-    """Version 1 Pack Index."""
+class PackIndex1(FilePackIndex):
+    """Version 1 Pack Index file."""
 
 
     def __init__(self, filename, file=None, contents=None, size=None):
     def __init__(self, filename, file=None, contents=None, size=None):
-        PackIndex.__init__(self, filename, file, contents, size)
+        super(PackIndex1, self).__init__(filename, file, contents, size)
         self.version = 1
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
         self._fan_out_table = self._read_fan_out_table(0)
 
 
@@ -406,11 +481,11 @@ class PackIndex1(PackIndex):
         return None
         return None
 
 
 
 
-class PackIndex2(PackIndex):
-    """Version 2 Pack Index."""
+class PackIndex2(FilePackIndex):
+    """Version 2 Pack Index file."""
 
 
     def __init__(self, filename, file=None, contents=None, size=None):
     def __init__(self, filename, file=None, contents=None, size=None):
-        PackIndex.__init__(self, filename, file, contents, size)
+        super(PackIndex2, self).__init__(filename, file, contents, size)
         assert self._contents[:4] == '\377tOc', "Not a v2 pack index file"
         assert self._contents[:4] == '\377tOc', "Not a v2 pack index file"
         (self.version, ) = unpack_from(">L", self._contents, 4)
         (self.version, ) = unpack_from(">L", self._contents, 4)
         assert self.version == 2, "Version was %d" % self.version
         assert self.version == 2, "Version was %d" % self.version
@@ -888,6 +963,10 @@ class ThinPackData(PackData):
         super(ThinPackData, self).__init__(*args, **kwargs)
         super(ThinPackData, self).__init__(*args, **kwargs)
         self.resolve_ext_ref = resolve_ext_ref
         self.resolve_ext_ref = resolve_ext_ref
 
 
+    @classmethod
+    def from_file(cls, resolve_ext_ref, file, size):
+        return cls(resolve_ext_ref, str(file), file=file, size=size)
+
     def get_ref(self, sha):
     def get_ref(self, sha):
         """Resolve a reference looking in both this pack and the store."""
         """Resolve a reference looking in both this pack and the store."""
         try:
         try:
@@ -1061,11 +1140,21 @@ def write_pack(filename, objects, num_objects):
         f.close()
         f.close()
 
 
 
 
+def write_pack_header(f, num_objects):
+    """Write a pack header for the given number of objects."""
+    f.write('PACK')                          # Pack header
+    f.write(struct.pack('>L', 2))            # Pack version
+    f.write(struct.pack('>L', num_objects))  # Number of objects in pack
+
+
 def write_pack_data(f, objects, num_objects, window=10):
 def write_pack_data(f, objects, num_objects, window=10):
-    """Write a new pack file.
+    """Write a new pack data file.
 
 
-    :param filename: The filename of the new pack file.
-    :param objects: List of objects to write (tuples with object and path)
+    :param f: File to write to
+    :param objects: Iterable over (object, path) tuples to write
+    :param num_objects: Number of objects to write
+    :param window: Sliding window size for searching for deltas; currently
+                   unimplemented
     :return: List with (name, offset, crc32 checksum) entries, pack checksum
     :return: List with (name, offset, crc32 checksum) entries, pack checksum
     """
     """
     recency = list(objects)
     recency = list(objects)
@@ -1085,9 +1174,7 @@ def write_pack_data(f, objects, num_objects, window=10):
     # Write the pack
     # Write the pack
     entries = []
     entries = []
     f = SHA1Writer(f)
     f = SHA1Writer(f)
-    f.write("PACK")               # Pack header
-    f.write(struct.pack(">L", 2)) # Pack version
-    f.write(struct.pack(">L", num_objects)) # Number of objects in pack
+    write_pack_header(f, num_objects)
     for o, path in recency:
     for o, path in recency:
         sha1 = o.sha().digest()
         sha1 = o.sha().digest()
         orig_t = o.type_num
         orig_t = o.type_num

+ 11 - 4
dulwich/patch.py

@@ -154,14 +154,21 @@ def git_am_patch_split(f):
     c = Commit()
     c = Commit()
     c.author = msg["from"]
     c.author = msg["from"]
     c.committer = msg["from"]
     c.committer = msg["from"]
-    if msg["subject"].startswith("[PATCH"):
-        subject = msg["subject"].split("]", 1)[1][1:]
-    else:
+    try:
+        patch_tag_start = msg["subject"].index("[PATCH")
+    except ValueError:
         subject = msg["subject"]
         subject = msg["subject"]
-    c.message = subject
+    else:
+        close = msg["subject"].index("] ", patch_tag_start)
+        subject = msg["subject"][close+2:]
+    c.message = subject.replace("\n", "") + "\n"
+    first = True
     for l in f:
     for l in f:
         if l == "---\n":
         if l == "---\n":
             break
             break
+        if first:
+            c.message += "\n"
+            first = False
         c.message += l
         c.message += l
     diff = ""
     diff = ""
     for l in f:
     for l in f:

+ 92 - 31
dulwich/protocol.py

@@ -41,10 +41,7 @@ MULTI_ACK_DETAILED = 2
 
 
 
 
 class ProtocolFile(object):
 class ProtocolFile(object):
-    """
-    Some network ops are like file ops. The file ops expect to operate on
-    file objects, so provide them with a dummy file.
-    """
+    """A dummy file for network ops that expect file-like objects."""
 
 
     def __init__(self, read, write):
     def __init__(self, read, write):
         self.read = read
         self.read = read
@@ -57,7 +54,29 @@ class ProtocolFile(object):
         pass
         pass
 
 
 
 
+def pkt_line(data):
+    """Wrap data in a pkt-line.
+
+    :param data: The data to wrap, as a str or None.
+    :return: The data prefixed with its length in pkt-line format; if data was
+        None, returns the flush-pkt ('0000').
+    """
+    if data is None:
+        return '0000'
+    return '%04x%s' % (len(data) + 4, data)
+
+
 class Protocol(object):
 class Protocol(object):
+    """Class for interacting with a remote git process over the wire.
+
+    Parts of the git wire protocol use 'pkt-lines' to communicate. A pkt-line
+    consists of the length of the line as a 4-byte hex string, followed by the
+    payload data. The length includes the 4-byte header. The special line '0000'
+    indicates the end of a section of input and is called a 'flush-pkt'.
+
+    For details on the pkt-line format, see the cgit distribution:
+        Documentation/technical/protocol-common.txt
+    """
 
 
     def __init__(self, read, write, report_activity=None):
     def __init__(self, read, write, report_activity=None):
         self.read = read
         self.read = read
@@ -65,10 +84,10 @@ class Protocol(object):
         self.report_activity = report_activity
         self.report_activity = report_activity
 
 
     def read_pkt_line(self):
     def read_pkt_line(self):
-        """
-        Reads a 'pkt line' from the remote git process
+        """Reads a pkt-line from the remote git process.
 
 
-        :return: The next string from the stream
+        :return: The next string from the stream, without the length prefix, or
+            None for a flush-pkt ('0000').
         """
         """
         try:
         try:
             sizestr = self.read(4)
             sizestr = self.read(4)
@@ -86,30 +105,32 @@ class Protocol(object):
             raise GitProtocolError(e)
             raise GitProtocolError(e)
 
 
     def read_pkt_seq(self):
     def read_pkt_seq(self):
+        """Read a sequence of pkt-lines from the remote git process.
+
+        :yield: Each line of data up to but not including the next flush-pkt.
+        """
         pkt = self.read_pkt_line()
         pkt = self.read_pkt_line()
         while pkt:
         while pkt:
             yield pkt
             yield pkt
             pkt = self.read_pkt_line()
             pkt = self.read_pkt_line()
 
 
     def write_pkt_line(self, line):
     def write_pkt_line(self, line):
-        """
-        Sends a 'pkt line' to the remote git process
+        """Sends a pkt-line to the remote git process.
 
 
-        :param line: A string containing the data to send
+        :param line: A string containing the data to send, without the length
+            prefix.
         """
         """
         try:
         try:
-            if line is None:
-                self.write("0000")
-                if self.report_activity:
-                    self.report_activity(4, 'write')
-            else:
-                self.write("%04x%s" % (len(line)+4, line))
-                if self.report_activity:
-                    self.report_activity(4+len(line), 'write')
+            line = pkt_line(line)
+            self.write(line)
+            if self.report_activity:
+                self.report_activity(len(line), 'write')
         except socket.error, e:
         except socket.error, e:
             raise GitProtocolError(e)
             raise GitProtocolError(e)
 
 
     def write_file(self):
     def write_file(self):
+        """Return a writable file-like object for this protocol."""
+
         class ProtocolFile(object):
         class ProtocolFile(object):
 
 
             def __init__(self, proto):
             def __init__(self, proto):
@@ -129,11 +150,10 @@ class Protocol(object):
         return ProtocolFile(self)
         return ProtocolFile(self)
 
 
     def write_sideband(self, channel, blob):
     def write_sideband(self, channel, blob):
-        """
-        Write data to the sideband (a git multiplexing method)
+        """Write multiplexed data to the sideband.
 
 
-        :param channel: int specifying which channel to write to
-        :param blob: a blob of data (as a string) to send on this channel
+        :param channel: An int specifying the channel to write to.
+        :param blob: A blob of data (as a string) to send on this channel.
         """
         """
         # a pktline can be a max of 65520. a sideband line can therefore be
         # a pktline can be a max of 65520. a sideband line can therefore be
         # 65520-5 = 65515
         # 65520-5 = 65515
@@ -143,23 +163,21 @@ class Protocol(object):
             blob = blob[65515:]
             blob = blob[65515:]
 
 
     def send_cmd(self, cmd, *args):
     def send_cmd(self, cmd, *args):
-        """
-        Send a command and some arguments to a git server
+        """Send a command and some arguments to a git server.
 
 
-        Only used for git://
+        Only used for the TCP git protocol (git://).
 
 
-        :param cmd: The remote service to access
-        :param args: List of arguments to send to remove service
+        :param cmd: The remote service to access.
+        :param args: List of arguments to send to remote service.
         """
         """
         self.write_pkt_line("%s %s" % (cmd, "".join(["%s\0" % a for a in args])))
         self.write_pkt_line("%s %s" % (cmd, "".join(["%s\0" % a for a in args])))
 
 
     def read_cmd(self):
     def read_cmd(self):
-        """
-        Read a command and some arguments from the git client
+        """Read a command and some arguments from the git client
 
 
-        Only used for git://
+        Only used for the TCP git protocol (git://).
 
 
-        :return: A tuple of (command, [list of arguments])
+        :return: A tuple of (command, [list of arguments]).
         """
         """
         line = self.read_pkt_line()
         line = self.read_pkt_line()
         splice_at = line.find(" ")
         splice_at = line.find(" ")
@@ -310,3 +328,46 @@ def ack_type(capabilities):
     elif 'multi_ack' in capabilities:
     elif 'multi_ack' in capabilities:
         return MULTI_ACK
         return MULTI_ACK
     return SINGLE_ACK
     return SINGLE_ACK
+
+
+class BufferedPktLineWriter(object):
+    """Writer that wraps its data in pkt-lines and has an independent buffer.
+
+    Consecutive calls to write() wrap the data in a pkt-line and then buffer it
+    until enough lines have been written such that their total length (including
+    length prefix) reaches the buffer size.
+    """
+
+    def __init__(self, write, bufsize=65515):
+        """Initialize the BufferedPktLineWriter.
+
+        :param write: A write callback for the underlying writer.
+        :param bufsize: The internal buffer size, including length prefixes.
+        """
+        self._write = write
+        self._bufsize = bufsize
+        self._wbuf = StringIO()
+        self._buflen = 0
+
+    def write(self, data):
+        """Write data, wrapping it in a pkt-line."""
+        line = pkt_line(data)
+        line_len = len(line)
+        over = self._buflen + line_len - self._bufsize
+        if over >= 0:
+            start = line_len - over
+            self._wbuf.write(line[:start])
+            self.flush()
+        else:
+            start = 0
+        saved = line[start:]
+        self._wbuf.write(saved)
+        self._buflen += len(saved)
+
+    def flush(self):
+        """Flush all data from the buffer."""
+        data = self._wbuf.getvalue()
+        if data:
+            self._write(data)
+        self._len = 0
+        self._wbuf = StringIO()

+ 17 - 8
dulwich/repo.py

@@ -342,6 +342,7 @@ class DictRefsContainer(RefsContainer):
 
 
     def __init__(self, refs):
     def __init__(self, refs):
         self._refs = refs
         self._refs = refs
+        self._peeled = {}
 
 
     def allkeys(self):
     def allkeys(self):
         return self._refs.keys()
         return self._refs.keys()
@@ -374,6 +375,19 @@ class DictRefsContainer(RefsContainer):
         del self._refs[name]
         del self._refs[name]
         return True
         return True
 
 
+    def get_peeled(self, name):
+        return self._peeled.get(name)
+
+    def _update(self, refs):
+        """Update multiple refs; intended only for testing."""
+        # TODO(dborowitz): replace this with a public function that uses
+        # set_if_equal.
+        self._refs.update(refs)
+
+    def _update_peeled(self, peeled):
+        """Update cached peeled refs; intended only for testing."""
+        self._peeled.update(peeled)
+
 
 
 class DiskRefsContainer(RefsContainer):
 class DiskRefsContainer(RefsContainer):
     """Refs container that reads refs from disk."""
     """Refs container that reads refs from disk."""
@@ -924,20 +938,15 @@ class BaseRepo(object):
     def get_peeled(self, ref):
     def get_peeled(self, ref):
         """Get the peeled value of a ref.
         """Get the peeled value of a ref.
 
 
-        :param ref: the refname to peel
-        :return: the fully-peeled SHA1 of a tag object, after peeling all
+        :param ref: The refname to peel.
+        :return: The fully-peeled SHA1 of a tag object, after peeling all
             intermediate tags; if the original ref does not point to a tag, this
             intermediate tags; if the original ref does not point to a tag, this
             will equal the original SHA1.
             will equal the original SHA1.
         """
         """
         cached = self.refs.get_peeled(ref)
         cached = self.refs.get_peeled(ref)
         if cached is not None:
         if cached is not None:
             return cached
             return cached
-        obj = self[ref]
-        obj_class = object_class(obj.type_name)
-        while obj_class is Tag:
-            obj_class, sha = obj.object
-            obj = self.get_object(sha)
-        return obj.id
+        return self.object_store.peel_sha(self.refs[ref]).id
 
 
     def revision_history(self, head):
     def revision_history(self, head):
         """Returns a list of the commits reachable from head.
         """Returns a list of the commits reachable from head.

+ 96 - 53
dulwich/server.py

@@ -36,6 +36,7 @@ from dulwich.errors import (
     ApplyDeltaError,
     ApplyDeltaError,
     ChecksumMismatch,
     ChecksumMismatch,
     GitProtocolError,
     GitProtocolError,
+    UnexpectedCommandError,
     ObjectFormatException,
     ObjectFormatException,
     )
     )
 from dulwich import log_utils
 from dulwich import log_utils
@@ -57,6 +58,7 @@ from dulwich.protocol import (
     ack_type,
     ack_type,
     extract_capabilities,
     extract_capabilities,
     extract_want_line_capabilities,
     extract_want_line_capabilities,
+    BufferedPktLineWriter,
     )
     )
 from dulwich.repo import (
 from dulwich.repo import (
     Repo,
     Repo,
@@ -161,16 +163,20 @@ class Handler(object):
         self.proto = proto
         self.proto = proto
         self._client_capabilities = None
         self._client_capabilities = None
 
 
-    def capability_line(self):
-        return " ".join(self.capabilities())
+    @classmethod
+    def capability_line(cls):
+        return " ".join(cls.capabilities())
 
 
-    def capabilities(self):
-        raise NotImplementedError(self.capabilities)
+    @classmethod
+    def capabilities(cls):
+        raise NotImplementedError(cls.capabilities)
 
 
-    def innocuous_capabilities(self):
+    @classmethod
+    def innocuous_capabilities(cls):
         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
         return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
 
 
-    def required_capabilities(self):
+    @classmethod
+    def required_capabilities(cls):
         """Return a list of capabilities that we require the client to have."""
         """Return a list of capabilities that we require the client to have."""
         return []
         return []
 
 
@@ -206,11 +212,13 @@ class UploadPackHandler(Handler):
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
 
 
-    def capabilities(self):
+    @classmethod
+    def capabilities(cls):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
                 "ofs-delta", "no-progress", "include-tag")
                 "ofs-delta", "no-progress", "include-tag")
 
 
-    def required_capabilities(self):
+    @classmethod
+    def required_capabilities(cls):
         return ("side-band-64k", "thin-pack", "ofs-delta")
         return ("side-band-64k", "thin-pack", "ofs-delta")
 
 
     def progress(self, message):
     def progress(self, message):
@@ -269,6 +277,41 @@ class UploadPackHandler(Handler):
         self.proto.write("0000")
         self.proto.write("0000")
 
 
 
 
+def _split_proto_line(line, allowed):
+    """Split a line read from the wire.
+
+    :param line: The line read from the wire.
+    :param allowed: An iterable of command names that should be allowed.
+        Command names not listed below as possible return values will be
+        ignored.  If None, any commands from the possible return values are
+        allowed.
+    :return: a tuple having one of the following forms:
+        ('want', obj_id)
+        ('have', obj_id)
+        ('done', None)
+        (None, None)  (for a flush-pkt)
+
+    :raise UnexpectedCommandError: if the line cannot be parsed into one of the
+        allowed return values.
+    """
+    if not line:
+        fields = [None]
+    else:
+        fields = line.rstrip('\n').split(' ', 1)
+    command = fields[0]
+    if allowed is not None and command not in allowed:
+        raise UnexpectedCommandError(command)
+    try:
+        if len(fields) == 1 and command in ('done', None):
+            return (command, None)
+        elif len(fields) == 2 and command in ('want', 'have'):
+            hex_to_sha(fields[1])
+            return tuple(fields)
+    except (TypeError, AssertionError), e:
+        raise GitProtocolError(e)
+    raise GitProtocolError('Received invalid line from client: %s' % line)
+
+
 class ProtocolGraphWalker(object):
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
     """A graph walker that knows the git protocol.
 
 
@@ -333,18 +376,16 @@ class ProtocolGraphWalker(object):
         line, caps = extract_want_line_capabilities(want)
         line, caps = extract_want_line_capabilities(want)
         self.handler.set_client_capabilities(caps)
         self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         self.set_ack_type(ack_type(caps))
-        command, sha = self._split_proto_line(line)
+        allowed = ('want', None)
+        command, sha = _split_proto_line(line, allowed)
 
 
         want_revs = []
         want_revs = []
         while command != None:
         while command != None:
-            if command != 'want':
-                raise GitProtocolError(
-                  'Protocol got unexpected command %s' % command)
             if sha not in values:
             if sha not in values:
                 raise GitProtocolError(
                 raise GitProtocolError(
                   'Client wants invalid object %s' % sha)
                   'Client wants invalid object %s' % sha)
             want_revs.append(sha)
             want_revs.append(sha)
-            command, sha = self.read_proto_line()
+            command, sha = self.read_proto_line(allowed)
 
 
         self.set_wants(want_revs)
         self.set_wants(want_revs)
         return want_revs
         return want_revs
@@ -366,34 +407,14 @@ class ProtocolGraphWalker(object):
             return None
             return None
         return self._cache[self._cache_index]
         return self._cache[self._cache_index]
 
 
-    def _split_proto_line(self, line):
-        fields = line.rstrip('\n').split(' ', 1)
-        if len(fields) == 1 and fields[0] == 'done':
-            return ('done', None)
-        elif len(fields) == 2 and fields[0] in ('want', 'have'):
-            try:
-                hex_to_sha(fields[1])
-                return tuple(fields)
-            except (TypeError, AssertionError), e:
-                raise GitProtocolError(e)
-        raise GitProtocolError('Received invalid line from client:\n%s' % line)
-
-    def read_proto_line(self):
+    def read_proto_line(self, allowed):
         """Read a line from the wire.
         """Read a line from the wire.
 
 
-        :return: a tuple having one of the following forms:
-            ('want', obj_id)
-            ('have', obj_id)
-            ('done', None)
-            (None, None)  (for a flush-pkt)
-
-        :raise GitProtocolError: if the line cannot be parsed into one of the
-            possible return values.
+        :param allowed: An iterable of command names that should be allowed.
+        :return: A tuple of (command, value); see _split_proto_line.
+        :raise GitProtocolError: If an error occurred reading the line.
         """
         """
-        line = self.proto.read_pkt_line()
-        if not line:
-            return (None, None)
-        return self._split_proto_line(line)
+        return _split_proto_line(self.proto.read_pkt_line(), allowed)
 
 
     def send_ack(self, sha, ack_type=''):
     def send_ack(self, sha, ack_type=''):
         if ack_type:
         if ack_type:
@@ -457,6 +478,9 @@ class ProtocolGraphWalker(object):
         self._impl = impl_classes[ack_type](self)
         self._impl = impl_classes[ack_type](self)
 
 
 
 
+_GRAPH_WALKER_COMMANDS = ('have', 'done', None)
+
+
 class SingleAckGraphWalkerImpl(object):
 class SingleAckGraphWalkerImpl(object):
     """Graph walker implementation that speaks the single-ack protocol."""
     """Graph walker implementation that speaks the single-ack protocol."""
 
 
@@ -470,7 +494,7 @@ class SingleAckGraphWalkerImpl(object):
             self._sent_ack = True
             self._sent_ack = True
 
 
     def next(self):
     def next(self):
-        command, sha = self.walker.read_proto_line()
+        command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
         if command in (None, 'done'):
         if command in (None, 'done'):
             if not self._sent_ack:
             if not self._sent_ack:
                 self.walker.send_nak()
                 self.walker.send_nak()
@@ -497,7 +521,7 @@ class MultiAckGraphWalkerImpl(object):
 
 
     def next(self):
     def next(self):
         while True:
         while True:
-            command, sha = self.walker.read_proto_line()
+            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
             if command is None:
                 self.walker.send_nak()
                 self.walker.send_nak()
                 # in multi-ack mode, a flush-pkt indicates the client wants to
                 # in multi-ack mode, a flush-pkt indicates the client wants to
@@ -537,7 +561,7 @@ class MultiAckDetailedGraphWalkerImpl(object):
 
 
     def next(self):
     def next(self):
         while True:
         while True:
-            command, sha = self.walker.read_proto_line()
+            command, sha = self.walker.read_proto_line(_GRAPH_WALKER_COMMANDS)
             if command is None:
             if command is None:
                 self.walker.send_nak()
                 self.walker.send_nak()
                 if self.walker.stateless_rpc:
                 if self.walker.stateless_rpc:
@@ -569,8 +593,9 @@ class ReceivePackHandler(Handler):
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
 
 
-    def capabilities(self):
-        return ("report-status", "delete-refs")
+    @classmethod
+    def capabilities(cls):
+        return ("report-status", "delete-refs", "side-band-64k")
 
 
     def _apply_pack(self, refs):
     def _apply_pack(self, refs):
         f, commit = self.repo.object_store.add_thin_pack()
         f, commit = self.repo.object_store.add_thin_pack()
@@ -614,6 +639,29 @@ class ReceivePackHandler(Handler):
 
 
         return status
         return status
 
 
+    def _report_status(self, status):
+        if self.has_capability('side-band-64k'):
+            writer = BufferedPktLineWriter(
+              lambda d: self.proto.write_sideband(1, d))
+            write = writer.write
+
+            def flush():
+                writer.flush()
+                self.proto.write_pkt_line(None)
+        else:
+            write = self.proto.write_pkt_line
+            flush = lambda: None
+
+        for name, msg in status:
+            if name == 'unpack':
+                write('unpack %s\n' % msg)
+            elif msg == 'ok':
+                write('ok %s\n' % name)
+            else:
+                write('ng %s %s\n' % (name, msg))
+        write(None)
+        flush()
+
     def handle(self):
     def handle(self):
         refs = self.repo.get_refs().items()
         refs = self.repo.get_refs().items()
 
 
@@ -654,14 +702,7 @@ class ReceivePackHandler(Handler):
         # when we have read all the pack from the client, send a status report
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
         # if the client asked for it
         if self.has_capability('report-status'):
         if self.has_capability('report-status'):
-            for name, msg in status:
-                if name == 'unpack':
-                    self.proto.write_pkt_line('unpack %s\n' % msg)
-                elif msg == 'ok':
-                    self.proto.write_pkt_line('ok %s\n' % name)
-                else:
-                    self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
-            self.proto.write_pkt_line(None)
+            self._report_status(status)
 
 
 
 
 # Default handler classes for git services.
 # Default handler classes for git services.
@@ -674,7 +715,7 @@ DEFAULT_HANDLERS = {
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 
 
     def __init__(self, handlers, *args, **kwargs):
     def __init__(self, handlers, *args, **kwargs):
-        self.handlers = handlers and handlers or DEFAULT_HANDLERS
+        self.handlers = handlers
         SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
         SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
 
 
     def handle(self):
     def handle(self):
@@ -698,8 +739,10 @@ class TCPGitServer(SocketServer.TCPServer):
         return TCPGitRequestHandler(self.handlers, *args, **kwargs)
         return TCPGitRequestHandler(self.handlers, *args, **kwargs)
 
 
     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
+        self.handlers = dict(DEFAULT_HANDLERS)
+        if handlers is not None:
+            self.handlers.update(handlers)
         self.backend = backend
         self.backend = backend
-        self.handlers = handlers
         logger.info('Listening for TCP connections on %s:%d', listen_addr, port)
         logger.info('Listening for TCP connections on %s:%d', listen_addr, port)
         SocketServer.TCPServer.__init__(self, (listen_addr, port),
         SocketServer.TCPServer.__init__(self, (listen_addr, port),
                                         self._make_handler)
                                         self._make_handler)

+ 40 - 17
dulwich/tests/compat/server_utils.py

@@ -24,12 +24,15 @@ import select
 import socket
 import socket
 import threading
 import threading
 
 
+from dulwich.server import (
+    ReceivePackHandler,
+    )
 from dulwich.tests.utils import (
 from dulwich.tests.utils import (
     tear_down_repo,
     tear_down_repo,
     )
     )
 from utils import (
 from utils import (
     import_repo,
     import_repo,
-    run_git,
+    run_git_or_fail,
     )
     )
 
 
 
 
@@ -40,41 +43,49 @@ class ServerTests(object):
     """
     """
 
 
     def setUp(self):
     def setUp(self):
-        self._old_repo = import_repo('server_old.export')
-        self._new_repo = import_repo('server_new.export')
+        self._old_repo = None
+        self._new_repo = None
         self._server = None
         self._server = None
 
 
     def tearDown(self):
     def tearDown(self):
         if self._server is not None:
         if self._server is not None:
             self._server.shutdown()
             self._server.shutdown()
             self._server = None
             self._server = None
-        tear_down_repo(self._old_repo)
-        tear_down_repo(self._new_repo)
+        if self._old_repo is not None:
+            tear_down_repo(self._old_repo)
+        if self._new_repo is not None:
+            tear_down_repo(self._new_repo)
+
+    def import_repos(self):
+        self._old_repo = import_repo('server_old.export')
+        self._new_repo = import_repo('server_new.export')
+
+    def url(self, port):
+        return '%s://localhost:%s/' % (self.protocol, port)
+
+    def branch_args(self, branches=None):
+        if branches is None:
+            branches = ['master', 'branch']
+        return ['%s:%s' % (b, b) for b in branches]
 
 
     def test_push_to_dulwich(self):
     def test_push_to_dulwich(self):
+        self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._old_repo)
         port = self._start_server(self._old_repo)
 
 
-        all_branches = ['master', 'branch']
-        branch_args = ['%s:%s' % (b, b) for b in all_branches]
-        url = '%s://localhost:%s/' % (self.protocol, port)
-        returncode, _ = run_git(['push', url] + branch_args,
-                                cwd=self._new_repo.path)
-        self.assertEqual(0, returncode)
+        run_git_or_fail(['push', self.url(port)] + self.branch_args(),
+                        cwd=self._new_repo.path)
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
 
 
     def test_fetch_from_dulwich(self):
     def test_fetch_from_dulwich(self):
+        self.import_repos()
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         self.assertReposNotEqual(self._old_repo, self._new_repo)
         port = self._start_server(self._new_repo)
         port = self._start_server(self._new_repo)
 
 
-        all_branches = ['master', 'branch']
-        branch_args = ['%s:%s' % (b, b) for b in all_branches]
-        url = '%s://localhost:%s/' % (self.protocol, port)
-        returncode, _ = run_git(['fetch', url] + branch_args,
-                                cwd=self._old_repo.path)
+        run_git_or_fail(['fetch', self.url(port)] + self.branch_args(),
+                        cwd=self._old_repo.path)
         # flush the pack cache so any new packs are picked up
         # flush the pack cache so any new packs are picked up
         self._old_repo.object_store._pack_cache = None
         self._old_repo.object_store._pack_cache = None
-        self.assertEqual(0, returncode)
         self.assertReposEqual(self._old_repo, self._new_repo)
         self.assertReposEqual(self._old_repo, self._new_repo)
 
 
 
 
@@ -155,3 +166,15 @@ class ShutdownServerMixIn:
             except:
             except:
                 self.handle_error(request, client_address)
                 self.handle_error(request, client_address)
                 self.close_request(request)
                 self.close_request(request)
+
+
+# TODO(dborowitz): Come up with a better way of testing various permutations of
+# capabilities. The only reason it is the way it is now is that side-band-64k
+# was only recently introduced into git-receive-pack.
+class NoSideBand64kReceivePackHandler(ReceivePackHandler):
+    """ReceivePackHandler that does not support side-band-64k."""
+
+    @classmethod
+    def capabilities(cls):
+        return tuple(c for c in ReceivePackHandler.capabilities()
+                     if c != 'side-band-64k')

+ 5 - 5
dulwich/tests/compat/test_client.py

@@ -40,7 +40,7 @@ from utils import (
     CompatTestCase,
     CompatTestCase,
     check_for_daemon,
     check_for_daemon,
     import_repo_to_dir,
     import_repo_to_dir,
-    run_git,
+    run_git_or_fail,
     )
     )
 
 
 class DulwichClientTestBase(object):
 class DulwichClientTestBase(object):
@@ -50,7 +50,7 @@ class DulwichClientTestBase(object):
         self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
         self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
         dest = os.path.join(self.gitroot, 'dest')
         dest = os.path.join(self.gitroot, 'dest')
         file.ensure_dir_exists(dest)
         file.ensure_dir_exists(dest)
-        run_git(['init', '--quiet', '--bare'], cwd=dest)
+        run_git_or_fail(['init', '--quiet', '--bare'], cwd=dest)
 
 
     def tearDown(self):
     def tearDown(self):
         shutil.rmtree(self.gitroot)
         shutil.rmtree(self.gitroot)
@@ -99,8 +99,8 @@ class DulwichClientTestBase(object):
     def disable_ff_and_make_dummy_commit(self):
     def disable_ff_and_make_dummy_commit(self):
         # disable non-fast-forward pushes to the server
         # disable non-fast-forward pushes to the server
         dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
         dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
-        run_git(['config', 'receive.denyNonFastForwards', 'true'],
-                cwd=dest.path)
+        run_git_or_fail(['config', 'receive.denyNonFastForwards', 'true'],
+                        cwd=dest.path)
         b = objects.Blob.from_string('hi')
         b = objects.Blob.from_string('hi')
         dest.object_store.add_object(b)
         dest.object_store.add_object(b)
         t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
         t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
@@ -176,7 +176,7 @@ class DulwichTCPClientTest(CompatTestCase, DulwichClientTestBase):
         fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
         fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
                                             suffix=".pid")
                                             suffix=".pid")
         os.fdopen(fd).close()
         os.fdopen(fd).close()
-        run_git(
+        run_git_or_fail(
             ['daemon', '--verbose', '--export-all',
             ['daemon', '--verbose', '--export-all',
              '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
              '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
              '--detach', '--reuseaddr', '--enable=receive-pack',
              '--detach', '--reuseaddr', '--enable=receive-pack',

+ 2 - 5
dulwich/tests/compat/test_pack.py

@@ -34,7 +34,7 @@ from dulwich.tests.test_pack import (
     )
     )
 from utils import (
 from utils import (
     require_git_version,
     require_git_version,
-    run_git,
+    run_git_or_fail,
     )
     )
 
 
 
 
@@ -56,10 +56,7 @@ class TestPack(PackTests):
         pack_path = os.path.join(self._tempdir, "Elch")
         pack_path = os.path.join(self._tempdir, "Elch")
         write_pack(pack_path, [(x, "") for x in origpack.iterobjects()],
         write_pack(pack_path, [(x, "") for x in origpack.iterobjects()],
                    len(origpack))
                    len(origpack))
-
-        returncode, output = run_git(['verify-pack', '-v', pack_path],
-                                     capture_stdout=True)
-        self.assertEquals(0, returncode)
+        output = run_git_or_fail(['verify-pack', '-v', pack_path])
 
 
         pack_shas = set()
         pack_shas = set()
         for line in output.splitlines():
         for line in output.splitlines():

+ 2 - 5
dulwich/tests/compat/test_repository.py

@@ -35,7 +35,7 @@ from dulwich.tests.utils import (
     )
     )
 
 
 from utils import (
 from utils import (
-    run_git,
+    run_git_or_fail,
     import_repo,
     import_repo,
     CompatTestCase,
     CompatTestCase,
     )
     )
@@ -53,10 +53,7 @@ class ObjectStoreTestCase(CompatTestCase):
         tear_down_repo(self._repo)
         tear_down_repo(self._repo)
 
 
     def _run_git(self, args):
     def _run_git(self, args):
-        returncode, output = run_git(args, capture_stdout=True,
-                                     cwd=self._repo.path)
-        self.assertEqual(0, returncode)
-        return output
+        return run_git_or_fail(args, cwd=self._repo.path)
 
 
     def _parse_refs(self, output):
     def _parse_refs(self, output):
         refs = {}
         refs = {}

+ 32 - 2
dulwich/tests/compat/test_server.py

@@ -29,10 +29,12 @@ import threading
 from dulwich.server import (
 from dulwich.server import (
     DictBackend,
     DictBackend,
     TCPGitServer,
     TCPGitServer,
+    ReceivePackHandler,
     )
     )
 from server_utils import (
 from server_utils import (
     ServerTests,
     ServerTests,
     ShutdownServerMixIn,
     ShutdownServerMixIn,
+    NoSideBand64kReceivePackHandler,
     )
     )
 from utils import (
 from utils import (
     CompatTestCase,
     CompatTestCase,
@@ -54,7 +56,10 @@ if not getattr(TCPGitServer, 'shutdown', None):
 
 
 
 
 class GitServerTestCase(ServerTests, CompatTestCase):
 class GitServerTestCase(ServerTests, CompatTestCase):
-    """Tests for client/server compatibility."""
+    """Tests for client/server compatibility.
+
+    This server test case does not use side-band-64k in git-receive-pack.
+    """
 
 
     protocol = 'git'
     protocol = 'git'
 
 
@@ -66,10 +71,35 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         ServerTests.tearDown(self)
         ServerTests.tearDown(self)
         CompatTestCase.tearDown(self)
         CompatTestCase.tearDown(self)
 
 
+    def _handlers(self):
+        return {'git-receive-pack': NoSideBand64kReceivePackHandler}
+
+    def _check_server(self, dul_server):
+        receive_pack_handler_cls = dul_server.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertFalse('side-band-64k' in caps)
+
     def _start_server(self, repo):
     def _start_server(self, repo):
         backend = DictBackend({'/': repo})
         backend = DictBackend({'/': repo})
-        dul_server = TCPGitServer(backend, 'localhost', 0)
+        dul_server = TCPGitServer(backend, 'localhost', 0,
+                                  handlers=self._handlers())
+        self._check_server(dul_server)
         threading.Thread(target=dul_server.serve).start()
         threading.Thread(target=dul_server.serve).start()
         self._server = dul_server
         self._server = dul_server
         _, port = self._server.socket.getsockname()
         _, port = self._server.socket.getsockname()
         return port
         return port
+
+
+class GitServerSideBand64kTestCase(GitServerTestCase):
+    """Tests for client/server compatibility with side-band-64k support."""
+
+    # side-band-64k in git-receive-pack was introduced in git 1.7.0.2
+    min_git_version = (1, 7, 0, 2)
+
+    def _handlers(self):
+        return None  # default handlers include side-band-64k
+
+    def _check_server(self, server):
+        receive_pack_handler_cls = server.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertTrue('side-band-64k' in caps)

+ 91 - 0
dulwich/tests/compat/test_utils.py

@@ -0,0 +1,91 @@
+# test_utils.py -- Tests for git compatibility utilities
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+
+# as published by the Free Software Foundation; either version 2
+# of the License, or (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor,
+# Boston, MA  02110-1301, USA.
+
+"""Tests for git compatibility utilities."""
+
+from unittest import TestCase
+
+from dulwich.tests import (
+    TestSkipped,
+    )
+import utils
+
+
+class GitVersionTests(TestCase):
+
+    def setUp(self):
+        self._orig_run_git = utils.run_git
+        self._version_str = None  # tests can override to set stub version
+
+        def run_git(args, **unused_kwargs):
+            self.assertEqual(['--version'], args)
+            return 0, self._version_str
+        utils.run_git = run_git
+
+    def tearDown(self):
+        utils.run_git = self._orig_run_git
+
+    def test_git_version_none(self):
+        self._version_str = 'not a git version'
+        self.assertEqual(None, utils.git_version())
+
+    def test_git_version_3(self):
+        self._version_str = 'git version 1.6.6'
+        self.assertEqual((1, 6, 6, 0), utils.git_version())
+
+    def test_git_version_4(self):
+        self._version_str = 'git version 1.7.0.2'
+        self.assertEqual((1, 7, 0, 2), utils.git_version())
+
+    def test_git_version_extra(self):
+        self._version_str = 'git version 1.7.0.3.295.gd8fa2'
+        self.assertEqual((1, 7, 0, 3), utils.git_version())
+
+    def assertRequireSucceeds(self, required_version):
+        try:
+            utils.require_git_version(required_version)
+        except TestSkipped:
+            self.fail()
+
+    def assertRequireFails(self, required_version):
+        self.assertRaises(TestSkipped, utils.require_git_version,
+                          required_version)
+
+    def test_require_git_version(self):
+        try:
+            self._version_str = 'git version 1.6.6'
+            self.assertRequireSucceeds((1, 6, 6))
+            self.assertRequireSucceeds((1, 6, 6, 0))
+            self.assertRequireSucceeds((1, 6, 5))
+            self.assertRequireSucceeds((1, 6, 5, 99))
+            self.assertRequireFails((1, 7, 0))
+            self.assertRequireFails((1, 7, 0, 2))
+            self.assertRaises(ValueError, utils.require_git_version,
+                              (1, 6, 6, 0, 0))
+
+            self._version_str = 'git version 1.7.0.2'
+            self.assertRequireSucceeds((1, 6, 6))
+            self.assertRequireSucceeds((1, 6, 6, 0))
+            self.assertRequireSucceeds((1, 7, 0))
+            self.assertRequireSucceeds((1, 7, 0, 2))
+            self.assertRequireFails((1, 7, 0, 3))
+            self.assertRequireFails((1, 7, 1))
+        except TestSkipped, e:
+            # This test is designed to catch all TestSkipped exceptions.
+            self.fail('Test unexpectedly skipped: %s' % e)

+ 31 - 2
dulwich/tests/compat/test_web.py

@@ -42,6 +42,7 @@ from dulwich.web import (
 from server_utils import (
 from server_utils import (
     ServerTests,
     ServerTests,
     ShutdownServerMixIn,
     ShutdownServerMixIn,
+    NoSideBand64kReceivePackHandler,
     )
     )
 from utils import (
 from utils import (
     CompatTestCase,
     CompatTestCase,
@@ -84,7 +85,10 @@ class WebTests(ServerTests):
 
 
 
 
 class SmartWebTestCase(WebTests, CompatTestCase):
 class SmartWebTestCase(WebTests, CompatTestCase):
-    """Test cases for smart HTTP server."""
+    """Test cases for smart HTTP server.
+
+    This server test case does not use side-band-64k in git-receive-pack.
+    """
 
 
     min_git_version = (1, 6, 6)
     min_git_version = (1, 6, 6)
 
 
@@ -96,8 +100,33 @@ class SmartWebTestCase(WebTests, CompatTestCase):
         WebTests.tearDown(self)
         WebTests.tearDown(self)
         CompatTestCase.tearDown(self)
         CompatTestCase.tearDown(self)
 
 
+    def _handlers(self):
+        return {'git-receive-pack': NoSideBand64kReceivePackHandler}
+
+    def _check_app(self, app):
+        receive_pack_handler_cls = app.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertFalse('side-band-64k' in caps)
+
     def _make_app(self, backend):
     def _make_app(self, backend):
-        return HTTPGitApplication(backend)
+        app = HTTPGitApplication(backend, handlers=self._handlers())
+        self._check_app(app)
+        return app
+
+
+class SmartWebSideBand64kTestCase(SmartWebTestCase):
+    """Test cases for smart HTTP server with side-band-64k support."""
+
+    # side-band-64k in git-receive-pack was introduced in git 1.7.0.2
+    min_git_version = (1, 7, 0, 2)
+
+    def _handlers(self):
+        return None  # default handlers include side-band-64k
+
+    def _check_app(self, app):
+        receive_pack_handler_cls = app.handlers['git-receive-pack']
+        caps = receive_pack_handler_cls.capabilities()
+        self.assertTrue('side-band-64k' in caps)
 
 
 
 
 class DumbWebTestCase(WebTests, CompatTestCase):
 class DumbWebTestCase(WebTests, CompatTestCase):

+ 34 - 15
dulwich/tests/compat/utils.py

@@ -35,6 +35,7 @@ from dulwich.tests import (
     )
     )
 
 
 _DEFAULT_GIT = 'git'
 _DEFAULT_GIT = 'git'
+_VERSION_LEN = 4
 
 
 
 
 def git_version(git_path=_DEFAULT_GIT):
 def git_version(git_path=_DEFAULT_GIT):
@@ -42,32 +43,50 @@ def git_version(git_path=_DEFAULT_GIT):
 
 
     :param git_path: Path to the git executable; defaults to the version in
     :param git_path: Path to the git executable; defaults to the version in
         the system path.
         the system path.
-    :return: A tuple of ints of the form (major, minor, point), or None if no
-        git installation was found.
+    :return: A tuple of ints of the form (major, minor, point, sub-point), or
+        None if no git installation was found.
     """
     """
     try:
     try:
-        _, output = run_git(['--version'], git_path=git_path,
-                            capture_stdout=True)
+        output = run_git_or_fail(['--version'], git_path=git_path)
     except OSError:
     except OSError:
         return None
         return None
     version_prefix = 'git version '
     version_prefix = 'git version '
     if not output.startswith(version_prefix):
     if not output.startswith(version_prefix):
         return None
         return None
-    output = output[len(version_prefix):]
-    nums = output.split('.')
-    if len(nums) == 2:
-        nums.add('0')
-    else:
-        nums = nums[:3]
-    try:
-        return tuple(int(x) for x in nums)
-    except ValueError:
-        return None
+
+    parts = output[len(version_prefix):].split('.')
+    nums = []
+    for part in parts:
+        try:
+            nums.append(int(part))
+        except ValueError:
+            break
+
+    while len(nums) < _VERSION_LEN:
+        nums.append(0)
+    return tuple(nums[:_VERSION_LEN])
 
 
 
 
 def require_git_version(required_version, git_path=_DEFAULT_GIT):
 def require_git_version(required_version, git_path=_DEFAULT_GIT):
-    """Require git version >= version, or skip the calling test."""
+    """Require git version >= version, or skip the calling test.
+
+    :param required_version: A tuple of ints of the form (major, minor, point,
+        sub-point); omitted components default to 0.
+    :param git_path: Path to the git executable; defaults to the version in
+        the system path.
+    :raise ValueError: if the required version tuple has too many parts.
+    :raise TestSkipped: if no suitable git version was found at the given path.
+    """
     found_version = git_version(git_path=git_path)
     found_version = git_version(git_path=git_path)
+    if len(required_version) > _VERSION_LEN:
+        raise ValueError('Invalid version tuple %s, expected %i parts' %
+                         (required_version, _VERSION_LEN))
+
+    required_version = list(required_version)
+    while len(found_version) < len(required_version):
+        required_version.append(0)
+    required_version = tuple(required_version)
+
     if found_version < required_version:
     if found_version < required_version:
         required_version = '.'.join(map(str, required_version))
         required_version = '.'.join(map(str, required_version))
         found_version = '.'.join(map(str, found_version))
         found_version = '.'.join(map(str, found_version))

+ 16 - 16
dulwich/tests/test_fastexport.py

@@ -20,9 +20,6 @@
 from cStringIO import StringIO
 from cStringIO import StringIO
 import stat
 import stat
 
 
-from dulwich.fastexport import (
-    FastExporter,
-    )
 from dulwich.object_store import (
 from dulwich.object_store import (
     MemoryObjectStore,
     MemoryObjectStore,
     )
     )
@@ -33,48 +30,51 @@ from dulwich.objects import (
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
+    TestSkipped,
     )
     )
 
 
 
 
-class FastExporterTests(TestCase):
+class GitFastExporterTests(TestCase):
 
 
     def setUp(self):
     def setUp(self):
-        super(FastExporterTests, self).setUp()
+        super(GitFastExporterTests, self).setUp()
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
         self.stream = StringIO()
         self.stream = StringIO()
-        self.fastexporter = FastExporter(self.stream, self.store)
+        try:
+            from dulwich.fastexport import GitFastExporter
+        except ImportError:
+            raise TestSkipped("python-fastimport not available")
+        self.fastexporter = GitFastExporter(self.stream, self.store)
 
 
-    def test_export_blob(self):
+    def test_emit_blob(self):
         b = Blob()
         b = Blob()
         b.data = "fooBAR"
         b.data = "fooBAR"
-        self.assertEquals(1, self.fastexporter.export_blob(b))
+        self.fastexporter.emit_blob(b)
         self.assertEquals('blob\nmark :1\ndata 6\nfooBAR\n',
         self.assertEquals('blob\nmark :1\ndata 6\nfooBAR\n',
             self.stream.getvalue())
             self.stream.getvalue())
 
 
-    def test_export_commit(self):
+    def test_emit_commit(self):
         b = Blob()
         b = Blob()
         b.data = "FOO"
         b.data = "FOO"
         t = Tree()
         t = Tree()
         t.add(stat.S_IFREG | 0644, "foo", b.id)
         t.add(stat.S_IFREG | 0644, "foo", b.id)
         c = Commit()
         c = Commit()
         c.committer = c.author = "Jelmer <jelmer@host>"
         c.committer = c.author = "Jelmer <jelmer@host>"
-        c.author_time = c.commit_time = 1271345553.47
+        c.author_time = c.commit_time = 1271345553
         c.author_timezone = c.commit_timezone = 0
         c.author_timezone = c.commit_timezone = 0
         c.message = "msg"
         c.message = "msg"
         c.tree = t.id
         c.tree = t.id
         self.store.add_objects([(b, None), (t, None), (c, None)])
         self.store.add_objects([(b, None), (t, None), (c, None)])
-        self.assertEquals(2,
-                self.fastexporter.export_commit(c, "refs/heads/master"))
+        self.fastexporter.emit_commit(c, "refs/heads/master")
         self.assertEquals("""blob
         self.assertEquals("""blob
 mark :1
 mark :1
 data 3
 data 3
 FOO
 FOO
 commit refs/heads/master
 commit refs/heads/master
 mark :2
 mark :2
-author Jelmer <jelmer@host> 1271345553.47 +0000
-committer Jelmer <jelmer@host> 1271345553.47 +0000
+author Jelmer <jelmer@host> 1271345553 +0000
+committer Jelmer <jelmer@host> 1271345553 +0000
 data 3
 data 3
 msg
 msg
-M 100644 :1 foo
-
+M 644 1 foo
 """, self.stream.getvalue())
 """, self.stream.getvalue())

+ 5 - 0
dulwich/tests/test_file.py

@@ -119,6 +119,11 @@ class GitFileTests(TestCase):
         self.assertEquals('contents', f.read())
         self.assertEquals('contents', f.read())
         f.close()
         f.close()
 
 
+    def test_default_mode(self):
+        f = GitFile(self.path('foo'))
+        self.assertEquals('foo contents', f.read())
+        f.close()
+
     def test_write(self):
     def test_write(self):
         foo = self.path('foo')
         foo = self.path('foo')
         foo_lock = '%s.lock' % foo
         foo_lock = '%s.lock' % foo

+ 132 - 0
dulwich/tests/test_object_store.py

@@ -23,12 +23,26 @@ import os
 import shutil
 import shutil
 import tempfile
 import tempfile
 
 
+from dulwich.index import (
+    commit_tree,
+    )
+from dulwich.errors import (
+    NotTreeError,
+    )
 from dulwich.objects import (
 from dulwich.objects import (
+    object_class,
     Blob,
     Blob,
+    ShaFile,
+    Tag,
+    Tree,
     )
     )
 from dulwich.object_store import (
 from dulwich.object_store import (
     DiskObjectStore,
     DiskObjectStore,
     MemoryObjectStore,
     MemoryObjectStore,
+    tree_lookup_path,
+    )
+from dulwich.pack import (
+    write_pack_data,
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
@@ -75,6 +89,68 @@ class ObjectStoreTests(object):
         r = self.store[testobject.id]
         r = self.store[testobject.id]
         self.assertEquals(r, testobject)
         self.assertEquals(r, testobject)
 
 
+    def test_iter_tree_contents(self):
+        blob_a = make_object(Blob, data='a')
+        blob_b = make_object(Blob, data='b')
+        blob_c = make_object(Blob, data='c')
+        for blob in [blob_a, blob_b, blob_c]:
+            self.store.add_object(blob)
+
+        blobs = [
+          ('a', blob_a.id, 0100644),
+          ('ad/b', blob_b.id, 0100644),
+          ('ad/bd/c', blob_c.id, 0100755),
+          ('ad/c', blob_c.id, 0100644),
+          ('c', blob_c.id, 0100644),
+          ]
+        tree_id = commit_tree(self.store, blobs)
+        self.assertEquals([(p, m, h) for (p, h, m) in blobs],
+                          list(self.store.iter_tree_contents(tree_id)))
+
+    def test_iter_tree_contents_include_trees(self):
+        blob_a = make_object(Blob, data='a')
+        blob_b = make_object(Blob, data='b')
+        blob_c = make_object(Blob, data='c')
+        for blob in [blob_a, blob_b, blob_c]:
+            self.store.add_object(blob)
+
+        blobs = [
+          ('a', blob_a.id, 0100644),
+          ('ad/b', blob_b.id, 0100644),
+          ('ad/bd/c', blob_c.id, 0100755),
+          ]
+        tree_id = commit_tree(self.store, blobs)
+        tree = self.store[tree_id]
+        tree_ad = self.store[tree['ad'][1]]
+        tree_bd = self.store[tree_ad['bd'][1]]
+
+        expected = [
+          ('', 0040000, tree_id),
+          ('a', 0100644, blob_a.id),
+          ('ad', 0040000, tree_ad.id),
+          ('ad/b', 0100644, blob_b.id),
+          ('ad/bd', 0040000, tree_bd.id),
+          ('ad/bd/c', 0100755, blob_c.id),
+          ]
+        actual = self.store.iter_tree_contents(tree_id, include_trees=True)
+        self.assertEquals(expected, list(actual))
+
+    def make_tag(self, name, obj):
+        tag = make_object(Tag, name=name, message='',
+                          tag_time=12345, tag_timezone=0,
+                          tagger='Test Tagger <test@example.com>',
+                          object=(object_class(obj.type_name), obj.id))
+        self.store.add_object(tag)
+        return tag
+
+    def test_peel_sha(self):
+        self.store.add_object(testobject)
+        tag1 = self.make_tag('1', testobject)
+        tag2 = self.make_tag('2', testobject)
+        tag3 = self.make_tag('3', testobject)
+        for obj in [testobject, tag1, tag2, tag3]:
+            self.assertEqual(testobject, self.store.peel_sha(obj.id))
+
 
 
 class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
 
@@ -114,4 +190,60 @@ class DiskObjectStoreTests(PackBasedObjectStoreTests, TestCase):
         o = DiskObjectStore(self.store_dir)
         o = DiskObjectStore(self.store_dir)
         self.assertEquals(os.path.join(self.store_dir, "pack"), o.pack_dir)
         self.assertEquals(os.path.join(self.store_dir, "pack"), o.pack_dir)
 
 
+    def test_add_pack(self):
+        o = DiskObjectStore(self.store_dir)
+        f, commit = o.add_pack()
+        b = make_object(Blob, data="more yummy data")
+        write_pack_data(f, [(b, None)], 1)
+        commit()
+
+    def test_add_thin_pack(self):
+        o = DiskObjectStore(self.store_dir)
+        f, commit = o.add_thin_pack()
+        b = make_object(Blob, data="more yummy data")
+        write_pack_data(f, [(b, None)], 1)
+        commit()
+
+
+class TreeLookupPathTests(TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.store = MemoryObjectStore()
+        blob_a = make_object(Blob, data='a')
+        blob_b = make_object(Blob, data='b')
+        blob_c = make_object(Blob, data='c')
+        for blob in [blob_a, blob_b, blob_c]:
+            self.store.add_object(blob)
+
+        blobs = [
+          ('a', blob_a.id, 0100644),
+          ('ad/b', blob_b.id, 0100644),
+          ('ad/bd/c', blob_c.id, 0100755),
+          ('ad/c', blob_c.id, 0100644),
+          ('c', blob_c.id, 0100644),
+          ]
+        self.tree_id = commit_tree(self.store, blobs)
+
+    def get_object(self, sha):
+        return self.store[sha]
+
+    def test_lookup_blob(self):
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'a')[1]
+        self.assertTrue(isinstance(self.store[o_id], Blob))
+
+    def test_lookup_tree(self):
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'ad')[1]
+        self.assertTrue(isinstance(self.store[o_id], Tree))
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'ad/bd')[1]
+        self.assertTrue(isinstance(self.store[o_id], Tree))
+        o_id = tree_lookup_path(self.get_object, self.tree_id, 'ad/bd/')[1]
+        self.assertTrue(isinstance(self.store[o_id], Tree))
+
+    def test_lookup_nonexistent(self):
+        self.assertRaises(KeyError, tree_lookup_path, self.get_object, self.tree_id, 'j')
+
+    def test_lookup_not_tree(self):
+        self.assertRaises(NotTreeError, tree_lookup_path, self.get_object, self.tree_id, 'ad/b/j')
+
 # TODO: MissingObjectFinderTests
 # TODO: MissingObjectFinderTests

+ 7 - 1
dulwich/tests/test_objects.py

@@ -151,7 +151,6 @@ class BlobReadTests(TestCase):
     def test_legacy_from_file(self):
     def test_legacy_from_file(self):
         b1 = Blob.from_string("foo")
         b1 = Blob.from_string("foo")
         b_raw = b1.as_legacy_object()
         b_raw = b1.as_legacy_object()
-        open('x', 'w+').write(b_raw)
         b2 = b1.from_file(StringIO(b_raw))
         b2 = b1.from_file(StringIO(b_raw))
         self.assertEquals(b1, b2)
         self.assertEquals(b1, b2)
 
 
@@ -235,6 +234,13 @@ class BlobReadTests(TestCase):
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
         self.assertEqual(c.message, 'Merge ../b\n')
 
 
+    def test_stub_sha(self):
+        sha = '5' * 40
+        c = make_commit(id=sha, message='foo')
+        self.assertTrue(isinstance(c, Commit))
+        self.assertEqual(sha, c.id)
+        self.assertNotEqual(sha, c._make_sha())
+
 
 
 class ShaFileCheckTests(TestCase):
 class ShaFileCheckTests(TestCase):
 
 

+ 79 - 28
dulwich/tests/test_pack.py

@@ -38,12 +38,15 @@ from dulwich.objects import (
     Tree,
     Tree,
     )
     )
 from dulwich.pack import (
 from dulwich.pack import (
+    MemoryPackIndex,
     Pack,
     Pack,
     PackData,
     PackData,
+    ThinPackData,
     apply_delta,
     apply_delta,
     create_delta,
     create_delta,
     load_pack_index,
     load_pack_index,
     read_zlib_chunks,
     read_zlib_chunks,
+    write_pack_header,
     write_pack_index_v1,
     write_pack_index_v1,
     write_pack_index_v2,
     write_pack_index_v2,
     write_pack,
     write_pack,
@@ -162,6 +165,25 @@ class TestPackData(PackTests):
     def test_create_pack(self):
     def test_create_pack(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
 
 
+    def test_from_file(self):
+        path = os.path.join(self.datadir, 'pack-%s.pack' % pack1_sha)
+        PackData.from_file(open(path), os.path.getsize(path))
+
+    # TODO: more ThinPackData tests.
+    def test_thin_from_file(self):
+        test_sha = '1' * 40
+
+        def resolve(sha):
+            self.assertEqual(test_sha, sha)
+            return 3, 'data'
+
+        path = os.path.join(self.datadir, 'pack-%s.pack' % pack1_sha)
+        data = ThinPackData.from_file(resolve, open(path),
+                                      os.path.getsize(path))
+        idx = self.get_pack_index(pack1_sha)
+        Pack.from_objects(data, idx)
+        self.assertEqual((None, 3, 'data'), data.get_ref(test_sha))
+
     def test_pack_len(self):
     def test_pack_len(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
         self.assertEquals(3, len(p))
         self.assertEquals(3, len(p))
@@ -277,16 +299,19 @@ class TestPack(PackTests):
         self.assertEquals(pack1_sha, p.name())
         self.assertEquals(pack1_sha, p.name())
 
 
 
 
-pack_checksum = hex_to_sha('721980e866af9a5f93ad674144e1459b8ba3e7b7')
+class WritePackHeaderTests(TestCase):
 
 
+    def test_simple(self):
+        f = StringIO()
+        write_pack_header(f, 42)
+        self.assertEquals('PACK\x00\x00\x00\x02\x00\x00\x00*',
+                f.getvalue())
 
 
-class BaseTestPackIndexWriting(object):
 
 
-    def setUp(self):
-        self.tempdir = tempfile.mkdtemp()
+pack_checksum = hex_to_sha('721980e866af9a5f93ad674144e1459b8ba3e7b7')
 
 
-    def tearDown(self):
-        shutil.rmtree(self.tempdir)
+
+class BaseTestPackIndexWriting(object):
 
 
     def assertSucceeds(self, func, *args, **kwargs):
     def assertSucceeds(self, func, *args, **kwargs):
         try:
         try:
@@ -294,30 +319,18 @@ class BaseTestPackIndexWriting(object):
         except ChecksumMismatch, e:
         except ChecksumMismatch, e:
             self.fail(e)
             self.fail(e)
 
 
-    def writeIndex(self, filename, entries, pack_checksum):
-        # FIXME: Write to StringIO instead rather than hitting disk ?
-        f = GitFile(filename, "wb")
-        try:
-            self._write_fn(f, entries, pack_checksum)
-        finally:
-            f.close()
+    def index(self, filename, entries, pack_checksum):
+        raise NotImplementedError(self.index)
 
 
     def test_empty(self):
     def test_empty(self):
-        filename = os.path.join(self.tempdir, 'empty.idx')
-        self.writeIndex(filename, [], pack_checksum)
-        idx = load_pack_index(filename)
-        self.assertSucceeds(idx.check)
+        idx = self.index('empty.idx', [], pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(0, len(idx))
         self.assertEquals(0, len(idx))
 
 
     def test_single(self):
     def test_single(self):
         entry_sha = hex_to_sha('6f670c0fb53f9463760b7295fbb814e965fb20c8')
         entry_sha = hex_to_sha('6f670c0fb53f9463760b7295fbb814e965fb20c8')
         my_entries = [(entry_sha, 178, 42)]
         my_entries = [(entry_sha, 178, 42)]
-        filename = os.path.join(self.tempdir, 'single.idx')
-        self.writeIndex(filename, my_entries, pack_checksum)
-        idx = load_pack_index(filename)
-        self.assertEquals(idx.version, self._expected_version)
-        self.assertSucceeds(idx.check)
+        idx = self.index('single.idx', my_entries, pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(1, len(idx))
         self.assertEquals(1, len(idx))
         actual_entries = list(idx.iterentries())
         actual_entries = list(idx.iterentries())
@@ -333,32 +346,70 @@ class BaseTestPackIndexWriting(object):
                 self.assertTrue(actual_crc is None)
                 self.assertTrue(actual_crc is None)
 
 
 
 
-class TestPackIndexWritingv1(TestCase, BaseTestPackIndexWriting):
+class BaseTestFilePackIndexWriting(BaseTestPackIndexWriting):
+
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
+    def index(self, filename, entries, pack_checksum):
+        path = os.path.join(self.tempdir, filename)
+        self.writeIndex(path, entries, pack_checksum)
+        idx = load_pack_index(path)
+        self.assertSucceeds(idx.check)
+        self.assertEquals(idx.version, self._expected_version)
+        return idx
+
+    def writeIndex(self, filename, entries, pack_checksum):
+        # FIXME: Write to StringIO instead rather than hitting disk ?
+        f = GitFile(filename, "wb")
+        try:
+            self._write_fn(f, entries, pack_checksum)
+        finally:
+            f.close()
+
+
+class TestMemoryIndexWriting(TestCase, BaseTestPackIndexWriting):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self._has_crc32_checksum = True
+
+    def index(self, filename, entries, pack_checksum):
+        return MemoryPackIndex(entries, pack_checksum)
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+
+
+class TestPackIndexWritingv1(TestCase, BaseTestFilePackIndexWriting):
 
 
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
-        BaseTestPackIndexWriting.setUp(self)
+        BaseTestFilePackIndexWriting.setUp(self)
         self._has_crc32_checksum = False
         self._has_crc32_checksum = False
         self._expected_version = 1
         self._expected_version = 1
         self._write_fn = write_pack_index_v1
         self._write_fn = write_pack_index_v1
 
 
     def tearDown(self):
     def tearDown(self):
         TestCase.tearDown(self)
         TestCase.tearDown(self)
-        BaseTestPackIndexWriting.tearDown(self)
+        BaseTestFilePackIndexWriting.tearDown(self)
 
 
 
 
-class TestPackIndexWritingv2(TestCase, BaseTestPackIndexWriting):
+class TestPackIndexWritingv2(TestCase, BaseTestFilePackIndexWriting):
 
 
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
-        BaseTestPackIndexWriting.setUp(self)
+        BaseTestFilePackIndexWriting.setUp(self)
         self._has_crc32_checksum = True
         self._has_crc32_checksum = True
         self._expected_version = 2
         self._expected_version = 2
         self._write_fn = write_pack_index_v2
         self._write_fn = write_pack_index_v2
 
 
     def tearDown(self):
     def tearDown(self):
         TestCase.tearDown(self)
         TestCase.tearDown(self)
-        BaseTestPackIndexWriting.tearDown(self)
+        BaseTestFilePackIndexWriting.tearDown(self)
 
 
 
 
 class ReadZlibTests(TestCase):
 class ReadZlibTests(TestCase):

+ 26 - 1
dulwich/tests/test_patch.py

@@ -28,7 +28,9 @@ from dulwich.patch import (
     git_am_patch_split,
     git_am_patch_split,
     write_commit_patch,
     write_commit_patch,
     )
     )
-from dulwich.tests import TestCase
+from dulwich.tests import (
+    TestCase,
+    )
 
 
 
 
 class WriteCommitPatchTests(TestCase):
 class WriteCommitPatchTests(TestCase):
@@ -80,9 +82,32 @@ Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a lintian war
         c, diff, version = git_am_patch_split(StringIO(text))
         c, diff, version = git_am_patch_split(StringIO(text))
         self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.committer)
         self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.author)
         self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.author)
+        self.assertEquals("Remove executable bit from prey.ico "
+            "(triggers a lintian warning).\n", c.message)
         self.assertEquals(""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
         self.assertEquals(""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
  1 files changed, 0 insertions(+), 0 deletions(-)
  1 files changed, 0 insertions(+), 0 deletions(-)
  mode change 100755 => 100644 pixmaps/prey.ico
  mode change 100755 => 100644 pixmaps/prey.ico
 
 
 """, diff)
 """, diff)
         self.assertEquals("1.7.0.4", version)
         self.assertEquals("1.7.0.4", version)
+
+    def test_extract_spaces(self):
+        text = """From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
+From: Jelmer Vernooij <jelmer@samba.org>
+Date: Thu, 15 Apr 2010 15:40:28 +0200
+Subject:  [Dulwich-users] [PATCH] Added unit tests for
+ dulwich.object_store.tree_lookup_path.
+
+* dulwich/tests/test_object_store.py
+  (TreeLookupPathTests): This test case contains a few tests that ensure the
+   tree_lookup_path function works as expected.
+---
+ pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+ 1 files changed, 0 insertions(+), 0 deletions(-)
+ mode change 100755 => 100644 pixmaps/prey.ico
+
+-- 
+1.7.0.4
+"""
+        c, diff, version = git_am_patch_split(StringIO(text))
+        self.assertEquals('Added unit tests for dulwich.object_store.tree_lookup_path.\n\n* dulwich/tests/test_object_store.py\n  (TreeLookupPathTests): This test case contains a few tests that ensure the\n   tree_lookup_path function works as expected.\n', c.message)

+ 91 - 35
dulwich/tests/test_protocol.py

@@ -30,6 +30,7 @@ from dulwich.protocol import (
     SINGLE_ACK,
     SINGLE_ACK,
     MULTI_ACK,
     MULTI_ACK,
     MULTI_ACK_DETAILED,
     MULTI_ACK_DETAILED,
+    BufferedPktLineWriter,
     )
     )
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
 
 
@@ -38,42 +39,42 @@ class BaseProtocolTests(object):
 
 
     def test_write_pkt_line_none(self):
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
-        self.assertEquals(self.rout.getvalue(), "0000")
+        self.assertEquals(self.rout.getvalue(), '0000')
 
 
     def test_write_pkt_line(self):
     def test_write_pkt_line(self):
-        self.proto.write_pkt_line("bla")
-        self.assertEquals(self.rout.getvalue(), "0007bla")
+        self.proto.write_pkt_line('bla')
+        self.assertEquals(self.rout.getvalue(), '0007bla')
 
 
     def test_read_pkt_line(self):
     def test_read_pkt_line(self):
-        self.rin.write("0008cmd ")
+        self.rin.write('0008cmd ')
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEquals("cmd ", self.proto.read_pkt_line())
+        self.assertEquals('cmd ', self.proto.read_pkt_line())
 
 
     def test_read_pkt_seq(self):
     def test_read_pkt_seq(self):
-        self.rin.write("0008cmd 0005l0000")
+        self.rin.write('0008cmd 0005l0000')
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEquals(["cmd ", "l"], list(self.proto.read_pkt_seq()))
+        self.assertEquals(['cmd ', 'l'], list(self.proto.read_pkt_seq()))
 
 
     def test_read_pkt_line_none(self):
     def test_read_pkt_line_none(self):
-        self.rin.write("0000")
+        self.rin.write('0000')
         self.rin.seek(0)
         self.rin.seek(0)
         self.assertEquals(None, self.proto.read_pkt_line())
         self.assertEquals(None, self.proto.read_pkt_line())
 
 
     def test_write_sideband(self):
     def test_write_sideband(self):
-        self.proto.write_sideband(3, "bloe")
-        self.assertEquals(self.rout.getvalue(), "0009\x03bloe")
+        self.proto.write_sideband(3, 'bloe')
+        self.assertEquals(self.rout.getvalue(), '0009\x03bloe')
 
 
     def test_send_cmd(self):
     def test_send_cmd(self):
-        self.proto.send_cmd("fetch", "a", "b")
-        self.assertEquals(self.rout.getvalue(), "000efetch a\x00b\x00")
+        self.proto.send_cmd('fetch', 'a', 'b')
+        self.assertEquals(self.rout.getvalue(), '000efetch a\x00b\x00')
 
 
     def test_read_cmd(self):
     def test_read_cmd(self):
-        self.rin.write("0012cmd arg1\x00arg2\x00")
+        self.rin.write('0012cmd arg1\x00arg2\x00')
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEquals(("cmd", ["arg1", "arg2"]), self.proto.read_cmd())
+        self.assertEquals(('cmd', ['arg1', 'arg2']), self.proto.read_cmd())
 
 
     def test_read_cmd_noend0(self):
     def test_read_cmd_noend0(self):
-        self.rin.write("0011cmd arg1\x00arg2")
+        self.rin.write('0011cmd arg1\x00arg2')
         self.rin.seek(0)
         self.rin.seek(0)
         self.assertRaises(AssertionError, self.proto.read_cmd)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
@@ -94,7 +95,7 @@ class ReceivableStringIO(StringIO):
         # fail fast if no bytes are available; in a real socket, this would
         # fail fast if no bytes are available; in a real socket, this would
         # block forever
         # block forever
         if self.tell() == len(self.getvalue()):
         if self.tell() == len(self.getvalue()):
-            raise AssertionError("Blocking read past end of socket")
+            raise AssertionError('Blocking read past end of socket')
         if size == 1:
         if size == 1:
             return self.read(1)
             return self.read(1)
         # calls shouldn't return quite as much as asked for
         # calls shouldn't return quite as much as asked for
@@ -111,10 +112,10 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.proto._rbufsize = 8
         self.proto._rbufsize = 8
 
 
     def test_recv(self):
     def test_recv(self):
-        all_data = "1234567" * 10  # not a multiple of bufsize
+        all_data = '1234567' * 10  # not a multiple of bufsize
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        data = ""
+        data = ''
         # We ask for 8 bytes each time and actually read 7, so it should take
         # We ask for 8 bytes each time and actually read 7, so it should take
         # exactly 10 iterations.
         # exactly 10 iterations.
         for _ in xrange(10):
         for _ in xrange(10):
@@ -124,28 +125,28 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
         self.assertEquals(all_data, data)
         self.assertEquals(all_data, data)
 
 
     def test_recv_read(self):
     def test_recv_read(self):
-        all_data = "1234567"  # recv exactly in one call
+        all_data = '1234567'  # recv exactly in one call
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEquals("1234", self.proto.recv(4))
-        self.assertEquals("567", self.proto.read(3))
+        self.assertEquals('1234', self.proto.recv(4))
+        self.assertEquals('567', self.proto.read(3))
         self.assertRaises(AssertionError, self.proto.recv, 10)
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
 
     def test_read_recv(self):
     def test_read_recv(self):
-        all_data = "12345678abcdefg"
+        all_data = '12345678abcdefg'
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        self.assertEquals("1234", self.proto.read(4))
-        self.assertEquals("5678abc", self.proto.recv(8))
-        self.assertEquals("defg", self.proto.read(4))
+        self.assertEquals('1234', self.proto.read(4))
+        self.assertEquals('5678abc', self.proto.recv(8))
+        self.assertEquals('defg', self.proto.read(4))
         self.assertRaises(AssertionError, self.proto.recv, 10)
         self.assertRaises(AssertionError, self.proto.recv, 10)
 
 
     def test_mixed(self):
     def test_mixed(self):
         # arbitrary non-repeating string
         # arbitrary non-repeating string
-        all_data = ",".join(str(i) for i in xrange(100))
+        all_data = ','.join(str(i) for i in xrange(100))
         self.rin.write(all_data)
         self.rin.write(all_data)
         self.rin.seek(0)
         self.rin.seek(0)
-        data = ""
+        data = ''
 
 
         for i in xrange(1, 100):
         for i in xrange(1, 100):
             data += self.proto.recv(i)
             data += self.proto.recv(i)
@@ -168,20 +169,20 @@ class ReceivableProtocolTests(BaseProtocolTests, TestCase):
 class CapabilitiesTestCase(TestCase):
 class CapabilitiesTestCase(TestCase):
 
 
     def test_plain(self):
     def test_plain(self):
-        self.assertEquals(("bla", []), extract_capabilities("bla"))
+        self.assertEquals(('bla', []), extract_capabilities('bla'))
 
 
     def test_caps(self):
     def test_caps(self):
-        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la"))
-        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la\n"))
-        self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la la"))
+        self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la'))
+        self.assertEquals(('bla', ['la']), extract_capabilities('bla\0la\n'))
+        self.assertEquals(('bla', ['la', 'la']), extract_capabilities('bla\0la la'))
 
 
     def test_plain_want_line(self):
     def test_plain_want_line(self):
-        self.assertEquals(("want bla", []), extract_want_line_capabilities("want bla"))
+        self.assertEquals(('want bla', []), extract_want_line_capabilities('want bla'))
 
 
     def test_caps_want_line(self):
     def test_caps_want_line(self):
-        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la"))
-        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la\n"))
-        self.assertEquals(("want bla", ["la", "la"]), extract_want_line_capabilities("want bla la la"))
+        self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la'))
+        self.assertEquals(('want bla', ['la']), extract_want_line_capabilities('want bla la\n'))
+        self.assertEquals(('want bla', ['la', 'la']), extract_want_line_capabilities('want bla la la'))
 
 
     def test_ack_type(self):
     def test_ack_type(self):
         self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
         self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
@@ -192,3 +193,58 @@ class CapabilitiesTestCase(TestCase):
         self.assertEquals(MULTI_ACK_DETAILED,
         self.assertEquals(MULTI_ACK_DETAILED,
                           ack_type(['foo', 'bar', 'multi_ack',
                           ack_type(['foo', 'bar', 'multi_ack',
                                     'multi_ack_detailed']))
                                     'multi_ack_detailed']))
+
+
+class BufferedPktLineWriterTests(TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self._output = StringIO()
+        self._writer = BufferedPktLineWriter(self._output.write, bufsize=16)
+
+    def assertOutputEquals(self, expected):
+        self.assertEquals(expected, self._output.getvalue())
+
+    def _truncate(self):
+        self._output.seek(0)
+        self._output.truncate()
+
+    def test_write(self):
+        self._writer.write('foo')
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0007foo')
+
+    def test_write_none(self):
+        self._writer.write(None)
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0000')
+
+    def test_flush_empty(self):
+        self._writer.flush()
+        self.assertOutputEquals('')
+
+    def test_write_multiple(self):
+        self._writer.write('foo')
+        self._writer.write('bar')
+        self.assertOutputEquals('')
+        self._writer.flush()
+        self.assertOutputEquals('0007foo0007bar')
+
+    def test_write_across_boundary(self):
+        self._writer.write('foo')
+        self._writer.write('barbaz')
+        self.assertOutputEquals('0007foo000abarba')
+        self._truncate()
+        self._writer.flush()
+        self.assertOutputEquals('z')
+
+    def test_write_to_boundary(self):
+        self._writer.write('foo')
+        self._writer.write('barba')
+        self.assertOutputEquals('0007foo0009barba')
+        self._truncate()
+        self._writer.write('z')
+        self._writer.flush()
+        self.assertOutputEquals('0005z')

+ 5 - 2
dulwich/tests/test_repository.py

@@ -26,6 +26,9 @@ import tempfile
 import warnings
 import warnings
 
 
 from dulwich import errors
 from dulwich import errors
+from dulwich.file import (
+    GitFile,
+    )
 from dulwich.object_store import (
 from dulwich.object_store import (
     tree_lookup_path,
     tree_lookup_path,
     )
     )
@@ -767,10 +770,10 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
 
 
     def test_remove_packed_without_peeled(self):
     def test_remove_packed_without_peeled(self):
         refs_file = os.path.join(self._repo.path, 'packed-refs')
         refs_file = os.path.join(self._repo.path, 'packed-refs')
-        f = open(refs_file)
+        f = GitFile(refs_file)
         refs_data = f.read()
         refs_data = f.read()
         f.close()
         f.close()
-        f = open(refs_file, 'wb')
+        f = GitFile(refs_file, 'wb')
         f.write('\n'.join(l for l in refs_data.split('\n')
         f.write('\n'.join(l for l in refs_data.split('\n')
                           if not l or l[0] not in '#^'))
                           if not l or l[0] not in '#^'))
         f.close()
         f.close()

+ 91 - 94
dulwich/tests/test_server.py

@@ -21,20 +21,26 @@
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     GitProtocolError,
     GitProtocolError,
+    UnexpectedCommandError,
+    )
+from dulwich.repo import (
+    MemoryRepo,
     )
     )
 from dulwich.server import (
 from dulwich.server import (
     Backend,
     Backend,
     DictBackend,
     DictBackend,
-    BackendRepo,
     Handler,
     Handler,
     MultiAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    _split_proto_line,
     ProtocolGraphWalker,
     ProtocolGraphWalker,
     SingleAckGraphWalkerImpl,
     SingleAckGraphWalkerImpl,
     UploadPackHandler,
     UploadPackHandler,
     )
     )
 from dulwich.tests import TestCase
 from dulwich.tests import TestCase
-
+from utils import (
+    make_commit,
+    )
 
 
 
 
 ONE = '1' * 40
 ONE = '1' * 40
@@ -76,13 +82,25 @@ class TestProto(object):
             return None
             return None
 
 
 
 
+class TestGenericHandler(Handler):
+
+    def __init__(self):
+        Handler.__init__(self, Backend(), None)
+
+    @classmethod
+    def capabilities(cls):
+        return ('cap1', 'cap2', 'cap3')
+
+    @classmethod
+    def required_capabilities(cls):
+        return ('cap2',)
+
+
 class HandlerTestCase(TestCase):
 class HandlerTestCase(TestCase):
 
 
     def setUp(self):
     def setUp(self):
         super(HandlerTestCase, self).setUp()
         super(HandlerTestCase, self).setUp()
-        self._handler = Handler(Backend(), None)
-        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
-        self._handler.required_capabilities = lambda: ('cap2',)
+        self._handler = TestGenericHandler()
 
 
     def assertSucceeds(self, func, *args, **kwargs):
     def assertSucceeds(self, func, *args, **kwargs):
         try:
         try:
@@ -124,10 +142,10 @@ class UploadPackHandlerTestCase(TestCase):
 
 
     def setUp(self):
     def setUp(self):
         super(UploadPackHandlerTestCase, self).setUp()
         super(UploadPackHandlerTestCase, self).setUp()
-        self._backend = DictBackend({"/": BackendRepo()})
-        self._handler = UploadPackHandler(self._backend,
-                ["/", "host=lolcathost"], None, None)
-        self._handler.proto = TestProto()
+        self._repo = MemoryRepo.init_bare([], {})
+        backend = DictBackend({'/': self._repo})
+        self._handler = UploadPackHandler(
+          backend, ['/', 'host=lolcathost'], TestProto())
 
 
     def test_progress(self):
     def test_progress(self):
         caps = self._handler.required_capabilities()
         caps = self._handler.required_capabilities()
@@ -153,63 +171,30 @@ class UploadPackHandlerTestCase(TestCase):
             'refs/tags/tag2': TWO,
             'refs/tags/tag2': TWO,
             'refs/heads/master': FOUR,  # not a tag, no peeled value
             'refs/heads/master': FOUR,  # not a tag, no peeled value
             }
             }
+        # repo needs to peel this object
+        self._repo.object_store.add_object(make_commit(id=FOUR))
+        self._repo.refs._update(refs)
         peeled = {
         peeled = {
-            'refs/tags/tag1': '1234',
-            'refs/tags/tag2': '5678',
+            'refs/tags/tag1': '1234' * 10,
+            'refs/tags/tag2': '5678' * 10,
             }
             }
-
-        class TestRepo(object):
-            def get_peeled(self, ref):
-                return peeled.get(ref, refs[ref])
+        self._repo.refs._update_peeled(peeled)
 
 
         caps = list(self._handler.required_capabilities()) + ['include-tag']
         caps = list(self._handler.required_capabilities()) + ['include-tag']
         self._handler.set_client_capabilities(caps)
         self._handler.set_client_capabilities(caps)
-        self.assertEquals({'1234': ONE, '5678': TWO},
-                          self._handler.get_tagged(refs, repo=TestRepo()))
+        self.assertEquals({'1234' * 10: ONE, '5678' * 10: TWO},
+                          self._handler.get_tagged(refs, repo=self._repo))
 
 
         # non-include-tag case
         # non-include-tag case
         caps = self._handler.required_capabilities()
         caps = self._handler.required_capabilities()
         self._handler.set_client_capabilities(caps)
         self._handler.set_client_capabilities(caps)
-        self.assertEquals({}, self._handler.get_tagged(refs, repo=TestRepo()))
-
-
-class TestCommit(object):
-
-    def __init__(self, sha, parents, commit_time):
-        self.id = sha
-        self.parents = parents
-        self.commit_time = commit_time
-        self.type_name = "commit"
-
-    def __repr__(self):
-        return '%s(%s)' % (self.__class__.__name__, self._sha)
-
-
-class TestRepo(object):
-    def __init__(self):
-        self.peeled = {}
+        self.assertEquals({}, self._handler.get_tagged(refs, repo=self._repo))
 
 
-    def get_peeled(self, name):
-        return self.peeled[name]
 
 
-
-class TestBackend(object):
-
-    def __init__(self, repo, objects):
-        self.repo = repo
-        self.object_store = objects
-
-
-class TestUploadPackHandler(Handler):
-
-    def __init__(self, objects, proto):
-        self.backend = TestBackend(TestRepo(), objects)
-        self.proto = proto
-        self.stateless_rpc = False
-        self.advertise_refs = False
-
-    def capabilities(self):
-        return ('multi_ack',)
+class TestUploadPackHandler(UploadPackHandler):
+    @classmethod
+    def required_capabilities(self):
+        return ()
 
 
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
 class ProtocolGraphWalkerTestCase(TestCase):
@@ -220,17 +205,18 @@ class ProtocolGraphWalkerTestCase(TestCase):
         #   3---5
         #   3---5
         #  /
         #  /
         # 1---2---4
         # 1---2---4
-        self._objects = {
-          ONE: TestCommit(ONE, [], 111),
-          TWO: TestCommit(TWO, [ONE], 222),
-          THREE: TestCommit(THREE, [ONE], 333),
-          FOUR: TestCommit(FOUR, [TWO], 444),
-          FIVE: TestCommit(FIVE, [THREE], 555),
-          }
-
+        commits = [
+          make_commit(id=ONE, parents=[], commit_time=111),
+          make_commit(id=TWO, parents=[ONE], commit_time=222),
+          make_commit(id=THREE, parents=[ONE], commit_time=333),
+          make_commit(id=FOUR, parents=[TWO], commit_time=444),
+          make_commit(id=FIVE, parents=[THREE], commit_time=555),
+          ]
+        self._repo = MemoryRepo.init_bare(commits, {})
+        backend = DictBackend({'/': self._repo})
         self._walker = ProtocolGraphWalker(
         self._walker = ProtocolGraphWalker(
-            TestUploadPackHandler(self._objects, TestProto()),
-            self._objects, None)
+            TestUploadPackHandler(backend, ['/', 'host=lolcats'], TestProto()),
+            self._repo.object_store, self._repo.get_peeled)
 
 
     def test_is_satisfied_no_haves(self):
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
@@ -257,22 +243,21 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertFalse(self._walker.all_wants_satisfied([THREE]))
         self.assertFalse(self._walker.all_wants_satisfied([THREE]))
         self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
         self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
 
 
-    def test_read_proto_line(self):
-        self._walker.proto.set_output([
-          'want %s' % ONE,
-          'want %s' % TWO,
-          'have %s' % THREE,
-          'foo %s' % FOUR,
-          'bar',
-          'done',
-          ])
-        self.assertEquals(('want', ONE), self._walker.read_proto_line())
-        self.assertEquals(('want', TWO), self._walker.read_proto_line())
-        self.assertEquals(('have', THREE), self._walker.read_proto_line())
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
-        self.assertEquals(('done', None), self._walker.read_proto_line())
-        self.assertEquals((None, None), self._walker.read_proto_line())
+    def test_split_proto_line(self):
+        allowed = ('want', 'done', None)
+        self.assertEquals(('want', ONE),
+                          _split_proto_line('want %s\n' % ONE, allowed))
+        self.assertEquals(('want', TWO),
+                          _split_proto_line('want %s\n' % TWO, allowed))
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'want xxxx\n', allowed)
+        self.assertRaises(UnexpectedCommandError, _split_proto_line,
+                          'have %s\n' % THREE, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line,
+                          'foo %s\n' % FOUR, allowed)
+        self.assertRaises(GitProtocolError, _split_proto_line, 'bar', allowed)
+        self.assertEquals(('done', None), _split_proto_line('done\n', allowed))
+        self.assertEquals((None, None), _split_proto_line('', allowed))
 
 
     def test_determine_wants(self):
     def test_determine_wants(self):
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
@@ -281,8 +266,12 @@ class ProtocolGraphWalkerTestCase(TestCase):
           'want %s multi_ack' % ONE,
           'want %s multi_ack' % ONE,
           'want %s' % TWO,
           'want %s' % TWO,
           ])
           ])
-        heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
-        self._walker.get_peeled = heads.get
+        heads = {
+          'refs/heads/ref1': ONE,
+          'refs/heads/ref2': TWO,
+          'refs/heads/ref3': THREE,
+          }
+        self._repo.refs._update(heads)
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
 
 
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
@@ -300,9 +289,14 @@ class ProtocolGraphWalkerTestCase(TestCase):
     def test_determine_wants_advertisement(self):
     def test_determine_wants_advertisement(self):
         self._walker.proto.set_output([])
         self._walker.proto.set_output([])
         # advertise branch tips plus tag
         # advertise branch tips plus tag
-        heads = {'ref4': FOUR, 'ref5': FIVE, 'tag6': SIX}
-        peeled = {'ref4': FOUR, 'ref5': FIVE, 'tag6': FIVE}
-        self._walker.get_peeled = peeled.get
+        heads = {
+          'refs/heads/ref4': FOUR,
+          'refs/heads/ref5': FIVE,
+          'refs/heads/tag6': SIX,
+          }
+        self._repo.refs._update(heads)
+        self._repo.refs._update_peeled(heads)
+        self._repo.refs._update_peeled({'refs/heads/tag6': FIVE})
         self._walker.determine_wants(heads)
         self._walker.determine_wants(heads)
         lines = []
         lines = []
         while True:
         while True:
@@ -315,16 +309,16 @@ class ProtocolGraphWalkerTestCase(TestCase):
             lines.append(line.rstrip())
             lines.append(line.rstrip())
 
 
         self.assertEquals([
         self.assertEquals([
-          '%s ref4' % FOUR,
-          '%s ref5' % FIVE,
-          '%s tag6^{}' % FIVE,
-          '%s tag6' % SIX,
+          '%s refs/heads/ref4' % FOUR,
+          '%s refs/heads/ref5' % FIVE,
+          '%s refs/heads/tag6^{}' % FIVE,
+          '%s refs/heads/tag6' % SIX,
           ], sorted(lines))
           ], sorted(lines))
 
 
         # ensure peeled tag was advertised immediately following tag
         # ensure peeled tag was advertised immediately following tag
         for i, line in enumerate(lines):
         for i, line in enumerate(lines):
-            if line.endswith(' tag6'):
-                self.assertEquals('%s tag6^{}' % FIVE, lines[i+1])
+            if line.endswith(' refs/heads/tag6'):
+                self.assertEquals('%s refs/heads/tag6^{}' % FIVE, lines[i+1])
 
 
     # TODO: test commit time cutoff
     # TODO: test commit time cutoff
 
 
@@ -338,8 +332,11 @@ class TestProtocolGraphWalker(object):
         self.stateless_rpc = False
         self.stateless_rpc = False
         self.advertise_refs = False
         self.advertise_refs = False
 
 
-    def read_proto_line(self):
-        return self.lines.pop(0)
+    def read_proto_line(self, allowed):
+        command, sha = self.lines.pop(0)
+        if allowed is not None:
+            assert command in allowed
+        return command, sha
 
 
     def send_ack(self, sha, ack_type=''):
     def send_ack(self, sha, ack_type=''):
         self.acks.append((sha, ack_type))
         self.acks.append((sha, ack_type))

+ 214 - 87
dulwich/tests/test_web.py

@@ -21,8 +21,20 @@
 from cStringIO import StringIO
 from cStringIO import StringIO
 import re
 import re
 
 
+from dulwich.object_store import (
+    MemoryObjectStore,
+    )
 from dulwich.objects import (
 from dulwich.objects import (
     Blob,
     Blob,
+    Tag,
+    )
+from dulwich.repo import (
+    BaseRepo,
+    DictRefsContainer,
+    MemoryRepo,
+    )
+from dulwich.server import (
+    DictBackend,
     )
     )
 from dulwich.tests import (
 from dulwich.tests import (
     TestCase,
     TestCase,
@@ -31,33 +43,73 @@ from dulwich.web import (
     HTTP_OK,
     HTTP_OK,
     HTTP_NOT_FOUND,
     HTTP_NOT_FOUND,
     HTTP_FORBIDDEN,
     HTTP_FORBIDDEN,
+    HTTP_ERROR,
     send_file,
     send_file,
+    get_text_file,
+    get_loose_object,
+    get_pack_file,
+    get_idx_file,
     get_info_refs,
     get_info_refs,
+    get_info_packs,
     handle_service_request,
     handle_service_request,
     _LengthLimitedFile,
     _LengthLimitedFile,
     HTTPGitRequest,
     HTTPGitRequest,
     HTTPGitApplication,
     HTTPGitApplication,
     )
     )
 
 
+from utils import make_object
+
+
+class TestHTTPGitRequest(HTTPGitRequest):
+    """HTTPGitRequest with overridden methods to help test caching."""
+
+    def __init__(self, *args, **kwargs):
+        HTTPGitRequest.__init__(self, *args, **kwargs)
+        self.cached = None
+
+    def nocache(self):
+        self.cached = False
+
+    def cache_forever(self):
+        self.cached = True
+
 
 
 class WebTestCase(TestCase):
 class WebTestCase(TestCase):
-    """Base TestCase that sets up some useful instance vars."""
+    """Base TestCase with useful instance vars and utility functions."""
+
+    _req_class = TestHTTPGitRequest
 
 
     def setUp(self):
     def setUp(self):
         super(WebTestCase, self).setUp()
         super(WebTestCase, self).setUp()
         self._environ = {}
         self._environ = {}
-        self._req = HTTPGitRequest(self._environ, self._start_response,
-                                   handlers=self._handlers())
+        self._req = self._req_class(self._environ, self._start_response,
+                                    handlers=self._handlers())
         self._status = None
         self._status = None
         self._headers = []
         self._headers = []
+        self._output = StringIO()
 
 
     def _start_response(self, status, headers):
     def _start_response(self, status, headers):
         self._status = status
         self._status = status
         self._headers = list(headers)
         self._headers = list(headers)
+        return self._output.write
 
 
     def _handlers(self):
     def _handlers(self):
         return None
         return None
 
 
+    def assertContentTypeEquals(self, expected):
+        self.assertTrue(('Content-Type', expected) in self._headers)
+
+
+def _test_backend(objects, refs=None, named_files=None):
+    if not refs:
+        refs = {}
+    if not named_files:
+        named_files = {}
+    repo = MemoryRepo.init_bare(objects, refs)
+    for path, contents in named_files.iteritems():
+        repo._put_named_file(path, contents)
+    return DictBackend({'/': repo})
+
 
 
 class DumbHandlersTestCase(WebTestCase):
 class DumbHandlersTestCase(WebTestCase):
 
 
@@ -67,10 +119,10 @@ class DumbHandlersTestCase(WebTestCase):
 
 
     def test_send_file(self):
     def test_send_file(self):
         f = StringIO('foobar')
         f = StringIO('foobar')
-        output = ''.join(send_file(self._req, f, 'text/plain'))
+        output = ''.join(send_file(self._req, f, 'some/thing'))
         self.assertEquals('foobar', output)
         self.assertEquals('foobar', output)
         self.assertEquals(HTTP_OK, self._status)
         self.assertEquals(HTTP_OK, self._status)
-        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertContentTypeEquals('some/thing')
         self.assertTrue(f.closed)
         self.assertTrue(f.closed)
 
 
     def test_send_file_buffered(self):
     def test_send_file_buffered(self):
@@ -78,93 +130,152 @@ class DumbHandlersTestCase(WebTestCase):
         xs = 'x' * bufsize
         xs = 'x' * bufsize
         f = StringIO(2 * xs)
         f = StringIO(2 * xs)
         self.assertEquals([xs, xs],
         self.assertEquals([xs, xs],
-                          list(send_file(self._req, f, 'text/plain')))
+                          list(send_file(self._req, f, 'some/thing')))
         self.assertEquals(HTTP_OK, self._status)
         self.assertEquals(HTTP_OK, self._status)
-        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertContentTypeEquals('some/thing')
         self.assertTrue(f.closed)
         self.assertTrue(f.closed)
 
 
     def test_send_file_error(self):
     def test_send_file_error(self):
         class TestFile(object):
         class TestFile(object):
-            def __init__(self):
+            def __init__(self, exc_class):
                 self.closed = False
                 self.closed = False
+                self._exc_class = exc_class
 
 
             def read(self, size=-1):
             def read(self, size=-1):
-                raise IOError
+                raise self._exc_class()
 
 
             def close(self):
             def close(self):
                 self.closed = True
                 self.closed = True
 
 
-        f = TestFile()
-        list(send_file(self._req, f, 'text/plain'))
-        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        f = TestFile(IOError)
+        list(send_file(self._req, f, 'some/thing'))
+        self.assertEquals(HTTP_ERROR, self._status)
+        self.assertTrue(f.closed)
+        self.assertFalse(self._req.cached)
+
+        # non-IOErrors are reraised
+        f = TestFile(AttributeError)
+        self.assertRaises(AttributeError, list,
+                          send_file(self._req, f, 'some/thing'))
         self.assertTrue(f.closed)
         self.assertTrue(f.closed)
+        self.assertFalse(self._req.cached)
+
+    def test_get_text_file(self):
+        backend = _test_backend([], named_files={'description': 'foo'})
+        mat = re.search('.*', 'description')
+        output = ''.join(get_text_file(self._req, backend, mat))
+        self.assertEquals('foo', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('text/plain')
+        self.assertFalse(self._req.cached)
+
+    def test_get_loose_object(self):
+        blob = make_object(Blob, data='foo')
+        backend = _test_backend([blob])
+        mat = re.search('^(..)(.{38})$', blob.id)
+        output = ''.join(get_loose_object(self._req, backend, mat))
+        self.assertEquals(blob.as_legacy_object(), output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('application/x-git-loose-object')
+        self.assertTrue(self._req.cached)
+
+    def test_get_loose_object_missing(self):
+        mat = re.search('^(..)(.{38})$', '1' * 40)
+        list(get_loose_object(self._req, _test_backend([]), mat))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+
+    def test_get_loose_object_error(self):
+        blob = make_object(Blob, data='foo')
+        backend = _test_backend([blob])
+        mat = re.search('^(..)(.{38})$', blob.id)
+
+        def as_legacy_object_error():
+            raise IOError
+
+        blob.as_legacy_object = as_legacy_object_error
+        list(get_loose_object(self._req, backend, mat))
+        self.assertEquals(HTTP_ERROR, self._status)
+
+    def test_get_pack_file(self):
+        pack_name = 'objects/pack/pack-%s.pack' % ('1' * 40)
+        backend = _test_backend([], named_files={pack_name: 'pack contents'})
+        mat = re.search('.*', pack_name)
+        output = ''.join(get_pack_file(self._req, backend, mat))
+        self.assertEquals('pack contents', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('application/x-git-packed-objects')
+        self.assertTrue(self._req.cached)
+
+    def test_get_idx_file(self):
+        idx_name = 'objects/pack/pack-%s.idx' % ('1' * 40)
+        backend = _test_backend([], named_files={idx_name: 'idx contents'})
+        mat = re.search('.*', idx_name)
+        output = ''.join(get_idx_file(self._req, backend, mat))
+        self.assertEquals('idx contents', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('application/x-git-packed-objects-toc')
+        self.assertTrue(self._req.cached)
 
 
     def test_get_info_refs(self):
     def test_get_info_refs(self):
         self._environ['QUERY_STRING'] = ''
         self._environ['QUERY_STRING'] = ''
 
 
-        class TestTag(object):
-            def __init__(self, sha, obj_class, obj_sha):
-                self.sha = lambda: sha
-                self.object = (obj_class, obj_sha)
-
-        class TestBlob(object):
-            def __init__(self, sha):
-                self.sha = lambda: sha
-
-        blob1 = TestBlob('111')
-        blob2 = TestBlob('222')
-        blob3 = TestBlob('333')
-
-        tag1 = TestTag('aaa', Blob, '222')
-
-        class TestRepo(object):
-
-            def __init__(self, objects, peeled):
-                self._objects = dict((o.sha(), o) for o in objects)
-                self._peeled = peeled
-
-            def get_peeled(self, sha):
-                return self._peeled[sha]
-
-            def __getitem__(self, sha):
-                return self._objects[sha]
-
-            def get_refs(self):
-                return {
-                    'HEAD': '000',
-                    'refs/heads/master': blob1.sha(),
-                    'refs/tags/tag-tag': tag1.sha(),
-                    'refs/tags/blob-tag': blob3.sha(),
-                    }
-
-        class TestBackend(object):
-            def __init__(self):
-                objects = [blob1, blob2, blob3, tag1]
-                self.repo = TestRepo(objects, {
-                  'HEAD': '000',
-                  'refs/heads/master': blob1.sha(),
-                  'refs/tags/tag-tag': blob2.sha(),
-                  'refs/tags/blob-tag': blob3.sha(),
-                  })
-
-            def open_repository(self, path):
-                assert path == '/'
-                return self.repo
-
-            def get_refs(self):
-                return {
-                  'HEAD': '000',
-                  'refs/heads/master': blob1.sha(),
-                  'refs/tags/tag-tag': tag1.sha(),
-                  'refs/tags/blob-tag': blob3.sha(),
-                  }
+        blob1 = make_object(Blob, data='1')
+        blob2 = make_object(Blob, data='2')
+        blob3 = make_object(Blob, data='3')
+
+        tag1 = make_object(Tag, name='tag-tag',
+                           tagger='Test <test@example.com>',
+                           tag_time=12345,
+                           tag_timezone=0,
+                           message='message',
+                           object=(Blob, blob2.id))
+
+        objects = [blob1, blob2, blob3, tag1]
+        refs = {
+          'HEAD': '000',
+          'refs/heads/master': blob1.id,
+          'refs/tags/tag-tag': tag1.id,
+          'refs/tags/blob-tag': blob3.id,
+          }
+        backend = _test_backend(objects, refs=refs)
 
 
         mat = re.search('.*', '//info/refs')
         mat = re.search('.*', '//info/refs')
-        self.assertEquals(['111\trefs/heads/master\n',
-                           '333\trefs/tags/blob-tag\n',
-                           'aaa\trefs/tags/tag-tag\n',
-                           '222\trefs/tags/tag-tag^{}\n'],
-                          list(get_info_refs(self._req, TestBackend(), mat)))
+        self.assertEquals(['%s\trefs/heads/master\n' % blob1.id,
+                           '%s\trefs/tags/blob-tag\n' % blob3.id,
+                           '%s\trefs/tags/tag-tag\n' % tag1.id,
+                           '%s\trefs/tags/tag-tag^{}\n' % blob2.id],
+                          list(get_info_refs(self._req, backend, mat)))
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('text/plain')
+        self.assertFalse(self._req.cached)
+
+    def test_get_info_packs(self):
+        class TestPack(object):
+            def __init__(self, sha):
+                self._sha = sha
+
+            def name(self):
+                return self._sha
+
+        packs = [TestPack(str(i) * 40) for i in xrange(1, 4)]
+
+        class TestObjectStore(MemoryObjectStore):
+            # property must be overridden, can't be assigned
+            @property
+            def packs(self):
+                return packs
+
+        store = TestObjectStore()
+        repo = BaseRepo(store, None)
+        backend = DictBackend({'/': repo})
+        mat = re.search('.*', '//info/packs')
+        output = ''.join(get_info_packs(self._req, backend, mat))
+        expected = 'P pack-%s.pack\n' * 3
+        expected %= ('1' * 40, '2' * 40, '3' * 40)
+        self.assertEquals(expected, output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertContentTypeEquals('text/plain')
+        self.assertFalse(self._req.cached)
 
 
 
 
 class SmartHandlersTestCase(WebTestCase):
 class SmartHandlersTestCase(WebTestCase):
@@ -191,43 +302,55 @@ class SmartHandlersTestCase(WebTestCase):
         mat = re.search('.*', '/git-evil-handler')
         mat = re.search('.*', '/git-evil-handler')
         list(handle_service_request(self._req, 'backend', mat))
         list(handle_service_request(self._req, 'backend', mat))
         self.assertEquals(HTTP_FORBIDDEN, self._status)
         self.assertEquals(HTTP_FORBIDDEN, self._status)
+        self.assertFalse(self._req.cached)
 
 
-    def test_handle_service_request(self):
+    def _run_handle_service_request(self, content_length=None):
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['wsgi.input'] = StringIO('foo')
+        if content_length is not None:
+            self._environ['CONTENT_LENGTH'] = content_length
         mat = re.search('.*', '/git-upload-pack')
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat))
-        self.assertEqual('handled input: foo', output)
-        response_type = 'application/x-git-upload-pack-response'
-        self.assertTrue(('Content-Type', response_type) in self._headers)
+        handler_output = ''.join(
+          handle_service_request(self._req, 'backend', mat))
+        write_output = self._output.getvalue()
+        # Ensure all output was written via the write callback.
+        self.assertEqual('', handler_output)
+        self.assertEqual('handled input: foo', write_output)
+        self.assertContentTypeEquals('application/x-git-upload-pack-response')
         self.assertFalse(self._handler.advertise_refs)
         self.assertFalse(self._handler.advertise_refs)
         self.assertTrue(self._handler.stateless_rpc)
         self.assertTrue(self._handler.stateless_rpc)
+        self.assertFalse(self._req.cached)
+
+    def test_handle_service_request(self):
+        self._run_handle_service_request()
 
 
     def test_handle_service_request_with_length(self):
     def test_handle_service_request_with_length(self):
-        self._environ['wsgi.input'] = StringIO('foobar')
-        self._environ['CONTENT_LENGTH'] = 3
-        mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat))
-        self.assertEqual('handled input: foo', output)
-        response_type = 'application/x-git-upload-pack-response'
-        self.assertTrue(('Content-Type', response_type) in self._headers)
+        self._run_handle_service_request(content_length='3')
+
+    def test_handle_service_request_empty_length(self):
+        self._run_handle_service_request(content_length='')
 
 
     def test_get_info_refs_unknown(self):
     def test_get_info_refs_unknown(self):
         self._environ['QUERY_STRING'] = 'service=git-evil-handler'
         self._environ['QUERY_STRING'] = 'service=git-evil-handler'
         list(get_info_refs(self._req, 'backend', None))
         list(get_info_refs(self._req, 'backend', None))
         self.assertEquals(HTTP_FORBIDDEN, self._status)
         self.assertEquals(HTTP_FORBIDDEN, self._status)
+        self.assertFalse(self._req.cached)
 
 
     def test_get_info_refs(self):
     def test_get_info_refs(self):
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'
 
 
         mat = re.search('.*', '/git-upload-pack')
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(get_info_refs(self._req, 'backend', mat))
+        handler_output = ''.join(get_info_refs(self._req, 'backend', mat))
+        write_output = self._output.getvalue()
         self.assertEquals(('001e# service=git-upload-pack\n'
         self.assertEquals(('001e# service=git-upload-pack\n'
                            '0000'
                            '0000'
                            # input is ignored by the handler
                            # input is ignored by the handler
-                           'handled input: '), output)
+                           'handled input: '), write_output)
+        # Ensure all output was written via the write callback.
+        self.assertEquals('', handler_output)
         self.assertTrue(self._handler.advertise_refs)
         self.assertTrue(self._handler.advertise_refs)
         self.assertTrue(self._handler.stateless_rpc)
         self.assertTrue(self._handler.stateless_rpc)
+        self.assertFalse(self._req.cached)
 
 
 
 
 class LengthLimitedFileTestCase(TestCase):
 class LengthLimitedFileTestCase(TestCase):
@@ -248,6 +371,10 @@ class LengthLimitedFileTestCase(TestCase):
 
 
 
 
 class HTTPGitRequestTestCase(WebTestCase):
 class HTTPGitRequestTestCase(WebTestCase):
+
+    # This class tests the contents of the actual cache headers
+    _req_class = HTTPGitRequest
+
     def test_not_found(self):
     def test_not_found(self):
         self._req.cache_forever()  # cache headers should be discarded
         self._req.cache_forever()  # cache headers should be discarded
         message = 'Something not found'
         message = 'Something not found'

+ 24 - 3
dulwich/tests/utils.py

@@ -26,7 +26,10 @@ import shutil
 import tempfile
 import tempfile
 import time
 import time
 
 
-from dulwich.objects import Commit
+from dulwich.objects import (
+    FixedSha,
+    Commit,
+    )
 from dulwich.repo import Repo
 from dulwich.repo import Repo
 
 
 
 
@@ -57,12 +60,30 @@ def tear_down_repo(repo):
 def make_object(cls, **attrs):
 def make_object(cls, **attrs):
     """Make an object for testing and assign some members.
     """Make an object for testing and assign some members.
 
 
+    This method creates a new subclass to allow arbitrary attribute
+    reassignment, which is not otherwise possible with objects having __slots__.
+
     :param attrs: dict of attributes to set on the new object.
     :param attrs: dict of attributes to set on the new object.
     :return: A newly initialized object of type cls.
     :return: A newly initialized object of type cls.
     """
     """
-    obj = cls()
+
+    class TestObject(cls):
+        """Class that inherits from the given class, but without __slots__.
+
+        Note that classes with __slots__ can't have arbitrary attributes monkey-
+        patched in, so this is a class that is exactly the same only with a
+        __dict__ instead of __slots__.
+        """
+        pass
+
+    obj = TestObject()
     for name, value in attrs.iteritems():
     for name, value in attrs.iteritems():
-        setattr(obj, name, value)
+        if name == 'id':
+            # id property is read-only, so we overwrite sha instead.
+            sha = FixedSha(value)
+            obj.sha = lambda: sha
+        else:
+            setattr(obj, name, value)
     return obj
     return obj
 
 
 
 

+ 29 - 16
dulwich/web.py

@@ -48,10 +48,16 @@ logger = log_utils.getLogger(__name__)
 HTTP_OK = '200 OK'
 HTTP_OK = '200 OK'
 HTTP_NOT_FOUND = '404 Not Found'
 HTTP_NOT_FOUND = '404 Not Found'
 HTTP_FORBIDDEN = '403 Forbidden'
 HTTP_FORBIDDEN = '403 Forbidden'
+HTTP_ERROR = '500 Internal Server Error'
 
 
 
 
 def date_time_string(timestamp=None):
 def date_time_string(timestamp=None):
-    # Based on BaseHTTPServer.py in python2.5
+    # From BaseHTTPRequestHandler.date_time_string in BaseHTTPServer.py in the
+    # Python 2.6.5 standard library, with the following modifications:
+    #  - Made a global rather than an instance method.
+    #  - weekdayname and monthname are renamed and locals rather than class
+    #    variables.
+    # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     months = [None,
     months = [None,
               'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
               'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
@@ -100,7 +106,7 @@ def send_file(req, f, content_type):
         f.close()
         f.close()
     except IOError:
     except IOError:
         f.close()
         f.close()
-        yield req.not_found('Error reading file')
+        yield req.error('Error reading file')
     except:
     except:
         f.close()
         f.close()
         raise
         raise
@@ -128,7 +134,8 @@ def get_loose_object(req, backend, mat):
     try:
     try:
         data = object_store[sha].as_legacy_object()
         data = object_store[sha].as_legacy_object()
     except IOError:
     except IOError:
-        yield req.not_found('Error reading object')
+        yield req.error('Error reading object')
+        return
     req.cache_forever()
     req.cache_forever()
     req.respond(HTTP_OK, 'application/x-git-loose-object')
     req.respond(HTTP_OK, 'application/x-git-loose-object')
     yield data
     yield data
@@ -159,15 +166,13 @@ def get_info_refs(req, backend, mat):
             yield req.forbidden('Unsupported service %s' % service)
             yield req.forbidden('Unsupported service %s' % service)
             return
             return
         req.nocache()
         req.nocache()
-        req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
-        output = StringIO()
-        proto = ReceivableProtocol(StringIO().read, output.write)
+        write = req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
+        proto = ReceivableProtocol(StringIO().read, write)
         handler = handler_cls(backend, [url_prefix(mat)], proto,
         handler = handler_cls(backend, [url_prefix(mat)], proto,
                               stateless_rpc=True, advertise_refs=True)
                               stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
         handler.proto.write_pkt_line(None)
         handler.handle()
         handler.handle()
-        yield output.getvalue()
     else:
     else:
         # non-smart fallback
         # non-smart fallback
         # TODO: select_getanyfile() (see http-backend.c)
         # TODO: select_getanyfile() (see http-backend.c)
@@ -230,20 +235,19 @@ def handle_service_request(req, backend, mat):
         yield req.forbidden('Unsupported service %s' % service)
         yield req.forbidden('Unsupported service %s' % service)
         return
         return
     req.nocache()
     req.nocache()
-    req.respond(HTTP_OK, 'application/x-%s-response' % service)
+    write = req.respond(HTTP_OK, 'application/x-%s-response' % service)
 
 
-    output = StringIO()
     input = req.environ['wsgi.input']
     input = req.environ['wsgi.input']
     # This is not necessary if this app is run from a conforming WSGI server.
     # This is not necessary if this app is run from a conforming WSGI server.
     # Unfortunately, there's no way to tell that at this point.
     # Unfortunately, there's no way to tell that at this point.
     # TODO: git may use HTTP/1.1 chunked encoding instead of specifying
     # TODO: git may use HTTP/1.1 chunked encoding instead of specifying
     # content-length
     # content-length
-    if 'CONTENT_LENGTH' in req.environ:
-        input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    proto = ReceivableProtocol(input.read, output.write)
+    content_length = req.environ.get('CONTENT_LENGTH', '')
+    if content_length:
+        input = _LengthLimitedFile(input, int(content_length))
+    proto = ReceivableProtocol(input.read, write)
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler.handle()
     handler.handle()
-    yield output.getvalue()
 
 
 
 
 class HTTPGitRequest(object):
 class HTTPGitRequest(object):
@@ -255,7 +259,7 @@ class HTTPGitRequest(object):
     def __init__(self, environ, start_response, dumb=False, handlers=None):
     def __init__(self, environ, start_response, dumb=False, handlers=None):
         self.environ = environ
         self.environ = environ
         self.dumb = dumb
         self.dumb = dumb
-        self.handlers = handlers and handlers or DEFAULT_HANDLERS
+        self.handlers = handlers
         self._start_response = start_response
         self._start_response = start_response
         self._cache_headers = []
         self._cache_headers = []
         self._headers = []
         self._headers = []
@@ -272,7 +276,7 @@ class HTTPGitRequest(object):
             self._headers.append(('Content-Type', content_type))
             self._headers.append(('Content-Type', content_type))
         self._headers.extend(self._cache_headers)
         self._headers.extend(self._cache_headers)
 
 
-        self._start_response(status, self._headers)
+        return self._start_response(status, self._headers)
 
 
     def not_found(self, message):
     def not_found(self, message):
         """Begin a HTTP 404 response and return the text of a message."""
         """Begin a HTTP 404 response and return the text of a message."""
@@ -288,6 +292,13 @@ class HTTPGitRequest(object):
         self.respond(HTTP_FORBIDDEN, 'text/plain')
         self.respond(HTTP_FORBIDDEN, 'text/plain')
         return message
         return message
 
 
+    def error(self, message):
+        """Begin a HTTP 500 response and return the text of a message."""
+        self._cache_headers = []
+        logger.error('Error: %s', message)
+        self.respond(HTTP_ERROR, 'text/plain')
+        return message
+
     def nocache(self):
     def nocache(self):
         """Set the response to never be cached by the client."""
         """Set the response to never be cached by the client."""
         self._cache_headers = [
         self._cache_headers = [
@@ -329,7 +340,9 @@ class HTTPGitApplication(object):
     def __init__(self, backend, dumb=False, handlers=None):
     def __init__(self, backend, dumb=False, handlers=None):
         self.backend = backend
         self.backend = backend
         self.dumb = dumb
         self.dumb = dumb
-        self.handlers = handlers
+        self.handlers = dict(DEFAULT_HANDLERS)
+        if handlers is not None:
+            self.handlers.update(handlers)
 
 
     def __call__(self, environ, start_response):
     def __call__(self, environ, start_response):
         path = environ['PATH_INFO']
         path = environ['PATH_INFO']

+ 1 - 1
setup.py

@@ -8,7 +8,7 @@ except ImportError:
     from distutils.core import setup, Extension
     from distutils.core import setup, Extension
 from distutils.core import Distribution
 from distutils.core import Distribution
 
 
-dulwich_version_string = '0.6.1'
+dulwich_version_string = '0.6.2'
 
 
 include_dirs = []
 include_dirs = []
 # Windows MSVC support
 # Windows MSVC support