Jelmer Vernooij 15 років тому
батько
коміт
520c83293a
34 змінених файлів з 1329 додано та 578 видалено
  1. 4 2
      Makefile
  2. 6 0
      NEWS
  3. 4 0
      README
  4. 4 3
      bin/dul-daemon
  5. 2 2
      bin/dul-web
  6. 17 7
      dulwich/_objects.c
  7. 11 4
      dulwich/client.py
  8. 16 13
      dulwich/errors.py
  9. 8 3
      dulwich/misc.py
  10. 49 32
      dulwich/object_store.py
  11. 505 195
      dulwich/objects.py
  12. 70 55
      dulwich/pack.py
  13. 33 20
      dulwich/repo.py
  14. 106 91
      dulwich/server.py
  15. 4 0
      dulwich/tests/__init__.py
  16. 15 9
      dulwich/tests/compat/test_server.py
  17. 8 5
      dulwich/tests/compat/test_web.py
  18. 5 5
      dulwich/tests/compat/utils.py
  19. BIN
      dulwich/tests/data/blobs/11/11111111111111111111111111111111111111
  20. 0 0
      dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8
  21. 0 0
      dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349
  22. 0 0
      dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391
  23. 0 0
      dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310
  24. 0 0
      dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc
  25. 0 0
      dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e
  26. 0 0
      dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023
  27. 0 0
      dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6
  28. 16 7
      dulwich/tests/test_object_store.py
  29. 349 64
      dulwich/tests/test_objects.py
  30. 3 4
      dulwich/tests/test_pack.py
  31. 22 14
      dulwich/tests/test_repository.py
  32. 16 12
      dulwich/tests/test_server.py
  33. 21 18
      dulwich/tests/test_web.py
  34. 35 13
      dulwich/web.py

+ 4 - 2
Makefile

@@ -2,6 +2,7 @@ PYTHON = python
 SETUP = $(PYTHON) setup.py
 PYDOCTOR ?= pydoctor
 TESTRUNNER = $(shell which nosetests)
+TESTFLAGS =
 
 all: build
 
@@ -19,12 +20,13 @@ install::
 
 check:: build
 	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	which git > /dev/null && PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 check-noextensions:: clean
-	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) dulwich
 
 check-compat:: build
-	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) -i compat
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 clean::
 	$(SETUP) clean --all

+ 6 - 0
NEWS

@@ -13,6 +13,9 @@
 
   * Implement RefsContainer.__contains__. (Jelmer Vernooij)
 
+  * Cope with \r in ref files on Windows. (
+	http://github.com/jelmer/dulwich/issues/#issue/13, Jelmer Vernooij)
+
  FEATURES
 
   * Add include-tag capability to server. (Dave Borowitz)
@@ -29,6 +32,9 @@
   * Repo.get_blob, Repo.commit, Repo.tag and Repo.tree are now deprecated.
     (Jelmer Vernooij)
 
+  * RefsContainer.set_ref() was renamed to RefsContainer.set_symbolic_ref(),
+    for clarity. (Jelmer Vernooij)
+
  API CHANGES
 
   * Blob.chunked was added. (Jelmer Vernooij)

+ 4 - 0
README

@@ -21,3 +21,7 @@ The project is named after the part of London that Mr. and Mrs. Git live in
 in the particular Monty Python sketch. It is based on the Python-Git module 
 that James Westby <jw+debian@jameswestby.net> released in 2007 and is now 
 maintained by Jelmer Vernooij and John Carr.
+
+Please file bugs in the Dulwich project on Launchpad: 
+
+https://bugs.launchpad.net/dulwich/+filebug

+ 4 - 3
bin/dul-daemon

@@ -19,13 +19,14 @@
 
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend, TCPGitServer
+from dulwich.server import DictBackend, TCPGitServer
 
 if __name__ == "__main__":
-    gitdir = None
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
+    else:
+        gitdir = "."
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     server = TCPGitServer(backend, 'localhost')
     server.serve_forever()

+ 2 - 2
bin/dul-web

@@ -20,7 +20,7 @@
 import os
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend
+from dulwich.server import DictBackend
 from dulwich.web import HTTPGitApplication
 from wsgiref.simple_server import make_server
 
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     else:
         gitdir = os.getcwd()
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     app = HTTPGitApplication(backend)
     # TODO: allow serving on other ports via command-line flag
     server = make_server('', 8000, app)

+ 17 - 7
dulwich/_objects.c

@@ -37,18 +37,22 @@ static PyObject *sha_to_pyhex(const unsigned char *sha)
 
 static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 {
-	char *text, *end;
+	char *text, *start, *end;
 	int len, namelen;
 	PyObject *ret, *item, *name;
 
 	if (!PyArg_ParseTuple(args, "s#", &text, &len))
 		return NULL;
 
-	ret = PyDict_New();
+	/* TODO: currently this returns a list; if memory usage is a concern,
+	* consider rewriting as a custom iterator object */
+	ret = PyList_New(0);
+
 	if (ret == NULL) {
 		return NULL;
 	}
 
+	start = text;
 	end = text + len;
 
 	while (text < end) {
@@ -56,14 +60,14 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 		mode = strtol(text, &text, 8);
 
 		if (*text != ' ') {
-			PyErr_SetString(PyExc_RuntimeError, "Expected space");
+			PyErr_SetString(PyExc_ValueError, "Expected space");
 			Py_DECREF(ret);
 			return NULL;
 		}
 
 		text++;
 
-		namelen = strlen(text);
+		namelen = strnlen(text, len - (text - start));
 
 		name = PyString_FromStringAndSize(text, namelen);
 		if (name == NULL) {
@@ -71,19 +75,25 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 			return NULL;
 		}
 
-		item = Py_BuildValue("(lN)", mode,
+		if (text + namelen + 20 >= end) {
+			PyErr_SetString(PyExc_ValueError, "SHA truncated");
+			Py_DECREF(ret);
+			Py_DECREF(name);
+			return NULL;
+		}
+
+		item = Py_BuildValue("(NlN)", name, mode,
 							 sha_to_pyhex((unsigned char *)text+namelen+1));
 		if (item == NULL) {
 			Py_DECREF(ret);
 			Py_DECREF(name);
 			return NULL;
 		}
-		if (PyDict_SetItem(ret, name, item) == -1) {
+		if (PyList_Append(ret, item) == -1) {
 			Py_DECREF(ret);
 			Py_DECREF(item);
 			return NULL;
 		}
-		Py_DECREF(name);
 		Py_DECREF(item);
 
 		text += namelen+21;

+ 11 - 4
dulwich/client.py

@@ -28,6 +28,7 @@ import subprocess
 
 from dulwich.errors import (
     ChecksumMismatch,
+    HangupException,
     )
 from dulwich.protocol import (
     Protocol,
@@ -119,10 +120,16 @@ class GitClient(object):
                                          len(objects))
         
         # read the final confirmation sha
-        client_sha = self.proto.read(20)
-        if not client_sha in (None, "", sha):
-            raise ChecksumMismatch(sha, client_sha)
-            
+        try:
+            client_sha = self.proto.read_pkt_line()
+        except HangupException:
+            # for git-daemon versions before v1.6.6.1-26-g38a81b4, there is
+            # nothing to read; catch this and hide from the user.
+            pass
+        else:
+            if not client_sha in (None, "", sha):
+                raise ChecksumMismatch(sha, client_sha)
+
         return new_refs
 
     def fetch(self, path, target, determine_wants=None, progress=None):

+ 16 - 13
dulwich/errors.py

@@ -37,40 +37,39 @@ class ChecksumMismatch(Exception):
 
 class WrongObjectException(Exception):
     """Baseclass for all the _ is not a _ exceptions on objects.
-  
+
     Do not instantiate directly.
-  
-    Subclasses should define a _type attribute that indicates what
+
+    Subclasses should define a type_name attribute that indicates what
     was expected if they were raised.
     """
-  
+
     def __init__(self, sha, *args, **kwargs):
-        string = "%s is not a %s" % (sha, self._type)
-        Exception.__init__(self, string)
+        Exception.__init__(self, "%s is not a %s" % (sha, self.type_name))
 
 
 class NotCommitError(WrongObjectException):
     """Indicates that the sha requested does not point to a commit."""
-  
-    _type = 'commit'
+
+    type_name = 'commit'
 
 
 class NotTreeError(WrongObjectException):
     """Indicates that the sha requested does not point to a tree."""
-  
-    _type = 'tree'
+
+    type_name = 'tree'
 
 
 class NotTagError(WrongObjectException):
     """Indicates that the sha requested does not point to a tag."""
 
-    _type = 'tag'
+    type_name = 'tag'
 
 
 class NotBlobError(WrongObjectException):
     """Indicates that the sha requested does not point to a blob."""
-  
-    _type = 'blob'
+
+    type_name = 'blob'
 
 
 class MissingCommitError(Exception):
@@ -124,5 +123,9 @@ class PackedRefsException(FileFormatException):
     """Indicates an error parsing a packed-refs file."""
 
 
+class ObjectFormatException(FileFormatException):
+    """Indicates an error parsing an object."""
+
+
 class NoIndexPresent(Exception):
     """No index is present."""

+ 8 - 3
dulwich/misc.py

@@ -15,15 +15,21 @@
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
-"""Misc utilities to work with python2.4.
+"""Misc utilities to work with python <2.6.
 
 These utilities can all be deleted when dulwich decides it wants to stop
-support for python 2.4.
+support for python <2.6.
 """
 try:
     import hashlib
 except ImportError:
     import sha
+
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from cgi import parse_qs
+
 import struct
 
 
@@ -87,4 +93,3 @@ def unpack_from(fmt, buf, offset=0):
     except AttributeError:
         b = buf[offset:offset+struct.calcsize(fmt)]
         return struct.unpack(fmt, b)
-

+ 49 - 32
dulwich/object_store.py

@@ -23,6 +23,7 @@
 import errno
 import itertools
 import os
+import posixpath
 import stat
 import tempfile
 import urllib2
@@ -38,6 +39,7 @@ from dulwich.objects import (
     Tree,
     hex_to_sha,
     sha_to_hex,
+    hex_to_filename,
     S_ISGITLINK,
     )
 from dulwich.pack import (
@@ -91,14 +93,14 @@ class BaseObjectStore(object):
         """Obtain the raw text for an object.
 
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         raise NotImplementedError(self.get_raw)
 
     def __getitem__(self, sha):
         """Obtain an object by SHA1."""
-        type, uncomp = self.get_raw(sha)
-        return ShaFile.from_raw_string(type, uncomp)
+        type_num, uncomp = self.get_raw(sha)
+        return ShaFile.from_raw_string(type_num, uncomp)
 
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
@@ -137,10 +139,7 @@ class BaseObjectStore(object):
             else:
                 ttree = {}
             for name, oldmode, oldhexsha in stree.iteritems():
-                if path == "":
-                    oldchildpath = name
-                else:
-                    oldchildpath = "%s/%s" % (path, name)
+                oldchildpath = posixpath.join(path, name)
                 try:
                     (newmode, newhexsha) = ttree[name]
                     newchildpath = oldchildpath
@@ -166,10 +165,7 @@ class BaseObjectStore(object):
                             yield ((oldchildpath, newchildpath), (oldmode, newmode), (oldhexsha, newhexsha))
 
             for name, newmode, newhexsha in ttree.iteritems():
-                if path == "":
-                    childpath = name
-                else:
-                    childpath = "%s/%s" % (path, name)
+                childpath = posixpath.join(path, name)
                 if not name in stree:
                     if not stat.S_ISDIR(newmode):
                         yield ((None, childpath), (None, newmode), (None, newhexsha))
@@ -186,10 +182,7 @@ class BaseObjectStore(object):
             (tid, tpath) = todo.pop()
             tree = self[tid]
             for name, mode, hexsha in tree.iteritems(): 
-                if tpath == "":
-                    path = name
-                else:
-                    path = "%s/%s" % (tpath, name)
+                path = posixpath.join(tpath, name)
                 if stat.S_ISDIR(mode):
                     todo.add((hexsha, path))
                 else:
@@ -233,13 +226,14 @@ class BaseObjectStore(object):
         """
         return ObjectStoreGraphWalker(heads, lambda sha: self[sha].parents)
 
-    def generate_pack_contents(self, have, want):
+    def generate_pack_contents(self, have, want, progress=None):
         """Iterate over the contents of a pack file.
 
         :param have: List of SHA1s of objects that should not be sent
         :param want: List of SHA1s of objects that should be sent
+        :param progress: Optional progress reporting method
         """
-        return self.iter_shas(self.find_missing_objects(have, want))
+        return self.iter_shas(self.find_missing_objects(have, want, progress))
 
 
 class PackBasedObjectStore(BaseObjectStore):
@@ -292,9 +286,9 @@ class PackBasedObjectStore(BaseObjectStore):
 
     def get_raw(self, name):
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         if len(name) == 40:
             sha = hex_to_sha(name)
@@ -313,7 +307,7 @@ class PackBasedObjectStore(BaseObjectStore):
             hexsha = sha_to_hex(name)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
-            return ret.type, ret.as_raw_string()
+            return ret.type_num, ret.as_raw_string()
         raise KeyError(hexsha)
 
     def add_objects(self, objects):
@@ -369,10 +363,8 @@ class DiskObjectStore(PackBasedObjectStore):
             raise
 
     def _get_shafile_path(self, sha):
-        dir = sha[:2]
-        file = sha[2:]
         # Check from object dir
-        return os.path.join(self.path, dir, file)
+        return hex_to_filename(self.path, sha)
 
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
@@ -385,7 +377,7 @@ class DiskObjectStore(PackBasedObjectStore):
         path = self._get_shafile_path(sha)
         try:
             return ShaFile.from_file(path)
-        except OSError, e:
+        except (OSError, IOError), e:
             if e.errno == errno.ENOENT:
                 return None
             raise
@@ -484,6 +476,17 @@ class DiskObjectStore(PackBasedObjectStore):
         finally:
             f.close()
 
+    @classmethod
+    def init(cls, path):
+        try:
+            os.mkdir(path)
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
+        os.mkdir(os.path.join(path, "info"))
+        os.mkdir(os.path.join(path, PACKDIR))
+        return cls(path)
+
 
 class MemoryObjectStore(BaseObjectStore):
     """Object store that keeps all objects in memory."""
@@ -511,9 +514,9 @@ class MemoryObjectStore(BaseObjectStore):
 
     def get_raw(self, name):
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         return self[name].as_raw_string()
 
@@ -629,7 +632,7 @@ def tree_lookup_path(lookup_obj, root_sha, path):
     mode = None
     for p in parts:
         obj = lookup_obj(sha)
-        if type(obj) is not Tree:
+        if not isinstance(obj, Tree):
             raise NotTreeError(sha)
         if p == '':
             continue
@@ -713,11 +716,25 @@ class ObjectStoreGraphWalker(object):
 
     def ack(self, sha):
         """Ack that a revision and its ancestors are present in the source."""
-        if sha in self.heads:
-            self.heads.remove(sha)
-        if sha in self.parents:
-            for p in self.parents[sha]:
-                self.ack(p)
+        ancestors = set([sha])
+
+        # stop if we run out of heads to remove
+        while self.heads:
+            for a in ancestors:
+                if a in self.heads:
+                    self.heads.remove(a)
+
+            # collect all ancestors
+            new_ancestors = set()
+            for a in ancestors:
+                if a in self.parents:
+                    new_ancestors.update(self.parents[a])
+
+            # no more ancestors; stop
+            if not new_ancestors:
+                break
+
+            ancestors = new_ancestors
 
     def next(self):
         """Iterate over ancestors of heads in the target."""

+ 505 - 195
dulwich/objects.py

@@ -28,36 +28,43 @@ from cStringIO import (
 import mmap
 import os
 import stat
-import time
 import zlib
 
 from dulwich.errors import (
+    ChecksumMismatch,
     NotBlobError,
     NotCommitError,
+    NotTagError,
     NotTreeError,
+    ObjectFormatException,
     )
 from dulwich.file import GitFile
 from dulwich.misc import (
     make_sha,
     )
 
-BLOB_ID = "blob"
-TAG_ID = "tag"
-TREE_ID = "tree"
-COMMIT_ID = "commit"
-PARENT_ID = "parent"
-AUTHOR_ID = "author"
-COMMITTER_ID = "committer"
-OBJECT_ID = "object"
-TYPE_ID = "type"
-TAGGER_ID = "tagger"
-ENCODING_ID = "encoding"
+
+# Header fields for commits
+_TREE_HEADER = "tree"
+_PARENT_HEADER = "parent"
+_AUTHOR_HEADER = "author"
+_COMMITTER_HEADER = "committer"
+_ENCODING_HEADER = "encoding"
+
+
+# Header fields for objects
+_OBJECT_HEADER = "object"
+_TYPE_HEADER = "type"
+_TAG_HEADER = "tag"
+_TAGGER_HEADER = "tagger"
+
 
 S_IFGITLINK = 0160000
 
 def S_ISGITLINK(m):
     return (stat.S_IFMT(m) == S_IFGITLINK)
 
+
 def _decompress(string):
     dcomp = zlib.decompressobj()
     dcomped = dcomp.decompress(string)
@@ -78,6 +85,27 @@ def hex_to_sha(hex):
     return binascii.unhexlify(hex)
 
 
+def hex_to_filename(path, hex):
+    """Takes a hex sha and returns its filename relative to the given path."""
+    dir = hex[:2]
+    file = hex[2:]
+    # Check from object dir
+    return os.path.join(path, dir, file)
+
+
+def filename_to_hex(filename):
+    """Takes an object filename and returns its corresponding hex sha."""
+    # grab the last (up to) two path components
+    names = filename.rsplit(os.path.sep, 2)[-2:]
+    errmsg = "Invalid object filename: %s" % filename
+    assert len(names) == 2, errmsg
+    base, rest = names
+    assert len(base) == 2 and len(rest) == 38, errmsg
+    hex = base + rest
+    hex_to_sha(hex)
+    return hex
+
+
 def serializable_property(name, docstring=None):
     def set(obj, value):
         obj._ensure_parsed()
@@ -89,44 +117,100 @@ def serializable_property(name, docstring=None):
     return property(get, set, doc=docstring)
 
 
+def object_class(type):
+    """Get the object class corresponding to the given type.
+
+    :param type: Either a type name string or a numeric type.
+    :return: The ShaFile subclass corresponding to the given type, or None if
+        type is not a valid type name/number.
+    """
+    return _TYPE_MAP.get(type, None)
+
+
+def check_hexsha(hex, error_msg):
+    try:
+        hex_to_sha(hex)
+    except (TypeError, AssertionError):
+        raise ObjectFormatException("%s %s" % (error_msg, hex))
+
+
+def check_identity(identity, error_msg):
+    email_start = identity.find("<")
+    email_end = identity.find(">")
+    if (email_start < 0 or email_end < 0 or email_end <= email_start
+        or identity.find("<", email_start + 1) >= 0
+        or identity.find(">", email_end + 1) >= 0
+        or not identity.endswith(">")):
+        raise ObjectFormatException(error_msg)
+
+
+class FixedSha(object):
+    """SHA object that behaves like hashlib's but is given a fixed value."""
+
+    def __init__(self, hexsha):
+        self._hexsha = hexsha
+        self._sha = hex_to_sha(hexsha)
+
+    def digest(self):
+        return self._sha
+
+    def hexdigest(self):
+        return self._hexsha
+
+
 class ShaFile(object):
     """A git SHA file."""
 
-    @classmethod
-    def _parse_legacy_object(cls, map):
-        """Parse a legacy object, creating it and setting object._text"""
-        text = _decompress(map)
-        object = None
-        for posstype in type_map.keys():
-            if text.startswith(posstype):
-                object = type_map[posstype]()
-                text = text[len(posstype):]
-                break
-        assert object is not None, "%s is not a known object type" % text[:9]
-        assert text[0] == ' ', "%s is not a space" % text[0]
-        text = text[1:]
-        size = 0
-        i = 0
-        while text[0] >= '0' and text[0] <= '9':
-            if i > 0 and size == 0:
-                raise AssertionError("Size is not in canonical format")
-            size = (size * 10) + int(text[0])
-            text = text[1:]
-            i += 1
-        object._size = size
-        assert text[0] == "\0", "Size not followed by null"
-        text = text[1:]
-        object.set_raw_string(text)
-        return object
+    @staticmethod
+    def _parse_legacy_object_header(magic, f):
+        """Parse a legacy object, creating it but not reading the file."""
+        bufsize = 1024
+        decomp = zlib.decompressobj()
+        header = decomp.decompress(magic)
+        start = 0
+        end = -1
+        while end < 0:
+            header += decomp.decompress(f.read(bufsize))
+            end = header.find("\0", start)
+            start = len(header)
+        header = header[:end]
+        type_name, size = header.split(" ", 1)
+        size = int(size)  # sanity check
+        obj_class = object_class(type_name)
+        if not obj_class:
+            raise ObjectFormatException("Not a known type: %s" % type_name)
+        obj = obj_class()
+        obj._filename = f.name
+        return obj
+
+    def _parse_legacy_object(self, f):
+        """Parse a legacy object, setting the raw string."""
+        size = os.path.getsize(f.name)
+        map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
+        try:
+            text = _decompress(map)
+        finally:
+            map.close()
+        header_end = text.find('\0')
+        if header_end < 0:
+            raise ObjectFormatException("Invalid object header")
+        self.set_raw_string(text[header_end+1:])
+
+    def as_legacy_object_chunks(self):
+        compobj = zlib.compressobj()
+        yield compobj.compress(self._header())
+        for chunk in self.as_raw_chunks():
+            yield compobj.compress(chunk)
+        yield compobj.flush()
 
     def as_legacy_object(self):
-        text = self.as_raw_string()
-        return zlib.compress("%s %d\0%s" % (self._type, len(text), text))
+        return "".join(self.as_legacy_object_chunks())
 
     def as_raw_chunks(self):
-        if self._needs_serialization:
+        if self._needs_parsing:
+            self._ensure_parsed()
+        elif self._needs_serialization:
             self._chunked_text = self._serialize()
-            self._needs_serialization = False
         return self._chunked_text
 
     def as_raw_string(self):
@@ -143,6 +227,9 @@ class ShaFile(object):
 
     def _ensure_parsed(self):
         if self._needs_parsing:
+            if not self._chunked_text:
+                assert self._filename, "ShaFile needs either text or filename"
+                self._parse_file()
             self._deserialize(self._chunked_text)
             self._needs_parsing = False
 
@@ -153,39 +240,60 @@ class ShaFile(object):
 
     def set_raw_chunks(self, chunks):
         self._chunked_text = chunks
+        self._deserialize(chunks)
         self._sha = None
-        self._needs_parsing = True
+        self._needs_parsing = False
         self._needs_serialization = False
 
-    @classmethod
-    def _parse_object(cls, map):
-        """Parse a new style object , creating it and setting object._text"""
-        used = 0
-        byte = ord(map[used])
-        used += 1
-        num_type = (byte >> 4) & 7
+    @staticmethod
+    def _parse_object_header(magic, f):
+        """Parse a new style object, creating it but not reading the file.
+
+        :param magic: The leading bytes already read from the file; the
+            numeric type is taken from bits 4-6 of the first byte.
+        :param f: The open file object; only its name is recorded.
+        :return: A new instance of the class returned by object_class() for
+            the numeric type, with _filename set to f.name.
+        :raise ObjectFormatException: if the numeric type is not known.
+        """
+        num_type = (ord(magic[0]) >> 4) & 7
+        obj_class = object_class(num_type)
+        if not obj_class:
+            # Must be ObjectFormatException (the name imported from
+            # dulwich.errors); any other name would surface as a NameError.
+            raise ObjectFormatException("Not a known type: %d" % num_type)
+        obj = obj_class()
+        obj._filename = f.name
+        return obj
+
+    def _parse_object(self, f):
+        """Parse a new style object, setting the raw string."""
+        size = os.path.getsize(f.name)
+        map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
         try:
-            object = num_type_map[num_type]()
-        except KeyError:
-            raise AssertionError("Not a known type: %d" % num_type)
-        while (byte & 0x80) != 0:
-            byte = ord(map[used])
-            used += 1
-        raw = map[used:]
-        object.set_raw_string(_decompress(raw))
-        return object
+            # skip type and size; type must have already been determined, and we
+            # trust zlib to fail if it's otherwise corrupted
+            byte = ord(map[0])
+            used = 1
+            while (byte & 0x80) != 0:
+                byte = ord(map[used])
+                used += 1
+            raw = map[used:]
+            self.set_raw_string(_decompress(raw))
+        finally:
+            map.close()
+
+    @classmethod
+    def _is_legacy_object(cls, magic):
+        b0, b1 = map(ord, magic)
+        word = (b0 << 8) + b1
+        return b0 == 0x78 and (word % 31) == 0
 
     @classmethod
-    def _parse_file(cls, map):
-        word = (ord(map[0]) << 8) + ord(map[1])
-        if ord(map[0]) == 0x78 and (word % 31) == 0:
-            return cls._parse_legacy_object(map)
+    def _parse_file_header(cls, f):
+        magic = f.read(2)
+        if cls._is_legacy_object(magic):
+            return cls._parse_legacy_object_header(magic, f)
         else:
-            return cls._parse_object(map)
+            return cls._parse_object_header(magic, f)
 
     def __init__(self):
         """Don't call this directly"""
         self._sha = None
+        self._filename = None
+        self._chunked_text = []
+        self._needs_parsing = False
+        self._needs_serialization = True
 
     def _deserialize(self, chunks):
         raise NotImplementedError(self._deserialize)
@@ -193,53 +301,96 @@ class ShaFile(object):
     def _serialize(self):
         raise NotImplementedError(self._serialize)
 
+    def _parse_file(self):
+        f = GitFile(self._filename, 'rb')
+        try:
+            magic = f.read(2)
+            if self._is_legacy_object(magic):
+                self._parse_legacy_object(f)
+            else:
+                self._parse_object(f)
+        finally:
+            f.close()
+
     @classmethod
     def from_file(cls, filename):
-        """Get the contents of a SHA file on disk"""
-        size = os.path.getsize(filename)
+        """Get the contents of a SHA file on disk."""
         f = GitFile(filename, 'rb')
         try:
-            map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
-            shafile = cls._parse_file(map)
-            return shafile
+            try:
+                obj = cls._parse_file_header(f)
+                obj._sha = FixedSha(filename_to_hex(filename))
+                obj._needs_parsing = True
+                obj._needs_serialization = True
+                return obj
+            except (IndexError, ValueError), e:
+                raise ObjectFormatException("invalid object header")
         finally:
             f.close()
 
-    @classmethod
-    def from_raw_string(cls, type, string):
+    @staticmethod
+    def from_raw_string(type_num, string):
         """Creates an object of the indicated type from the raw string given.
 
-        Type is the numeric type of an object. String is the raw uncompressed
-        contents.
+        :param type_num: The numeric type of the object.
+        :param string: The raw uncompressed contents.
         """
-        real_class = num_type_map[type]
-        obj = real_class()
-        obj.type = type
+        obj = object_class(type_num)()
         obj.set_raw_string(string)
         return obj
 
-    @classmethod
-    def from_raw_chunks(cls, type, chunks):
+    @staticmethod
+    def from_raw_chunks(type_num, chunks):
         """Creates an object of the indicated type from the raw chunks given.
 
-        Type is the numeric type of an object. Chunks is a sequence of the raw 
-        uncompressed contents.
+        :param type_num: The numeric type of the object.
+        :param chunks: An iterable of the raw uncompressed contents.
         """
-        real_class = num_type_map[type]
-        obj = real_class()
-        obj.type = type
+        obj = object_class(type_num)()
         obj.set_raw_chunks(chunks)
         return obj
 
     @classmethod
     def from_string(cls, string):
-        """Create a blob from a string."""
-        shafile = cls()
-        shafile.set_raw_string(string)
-        return shafile
+        """Create a ShaFile from a string."""
+        obj = cls()
+        obj.set_raw_string(string)
+        return obj
+
+    def _check_has_member(self, member, error_msg):
+        """Check that the object has a given member variable.
+
+        :param member: the member variable to check for
+        :param error_msg: the message for an error if the member is missing
+        :raise ObjectFormatException: with the given error_msg if member is
+            missing or is None
+        """
+        if getattr(self, member, None) is None:
+            raise ObjectFormatException(error_msg)
+
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        :raise ChecksumMismatch: if the object was created with a SHA that does
+            not match its contents
+        """
+        # TODO: if we find that error-checking during object parsing is a
+        # performance bottleneck, those checks should be moved to the class's
+        # check() method during optimization so we can still check the object
+        # when necessary.
+        old_sha = self.id
+        try:
+            self._deserialize(self.as_raw_chunks())
+            self._sha = None
+            new_sha = self.id
+        except Exception, e:
+            raise ObjectFormatException(e)
+        if old_sha != new_sha:
+            raise ChecksumMismatch(new_sha, old_sha)
 
     def _header(self):
-        return "%s %lu\0" % (self._type, self.raw_length())
+        return "%s %lu\0" % (self.type_name, self.raw_length())
 
     def raw_length(self):
         """Returns the length of the raw string of this object."""
@@ -257,8 +408,13 @@ class ShaFile(object):
 
     def sha(self):
         """The SHA1 object that is the name of this object."""
-        if self._needs_serialization or self._sha is None:
-            self._sha = self._make_sha()
+        if self._sha is None:
+            # this is a local because as_raw_chunks() overwrites self._sha
+            new_sha = make_sha()
+            new_sha.update(self._header())
+            for chunk in self.as_raw_chunks():
+                new_sha.update(chunk)
+            self._sha = new_sha
         return self._sha
 
     @property
@@ -266,11 +422,12 @@ class ShaFile(object):
         return self.sha().hexdigest()
 
     def get_type(self):
-        return self._num_type
+        return self.type_num
 
     def set_type(self, type):
-        self._num_type = type
+        self.type_num = type
 
+    # DEPRECATED: use type_num or type_name as needed.
     type = property(get_type, set_type)
 
     def __repr__(self):
@@ -291,8 +448,8 @@ class ShaFile(object):
 class Blob(ShaFile):
     """A Git Blob object."""
 
-    _type = BLOB_ID
-    _num_type = 3
+    type_name = 'blob'
+    type_num = 3
 
     def __init__(self):
         super(Blob, self).__init__()
@@ -307,60 +464,125 @@ class Blob(ShaFile):
         self.set_raw_string(data)
 
     data = property(_get_data, _set_data,
-            "The text contained within the blob object.")
+                    "The text contained within the blob object.")
 
     def _get_chunked(self):
+        self._ensure_parsed()
         return self._chunked_text
 
     def _set_chunked(self, chunks):
         self._chunked_text = chunks
 
+    def _serialize(self):
+        if not self._chunked_text:
+            self._ensure_parsed()
+        self._needs_serialization = False
+        return self._chunked_text
+
+    def _deserialize(self, chunks):
+        self._chunked_text = chunks
+
     chunked = property(_get_chunked, _set_chunked,
         "The text within the blob object, as chunks (not necessarily lines).")
 
     @classmethod
     def from_file(cls, filename):
         blob = ShaFile.from_file(filename)
-        if blob._type != cls._type:
+        if not isinstance(blob, cls):
             raise NotBlobError(filename)
         return blob
 
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Blob, self).check()
+
+
+def _parse_tag_or_commit(text):
+    """Parse tag or commit text.
+
+    :param text: the raw text of the tag or commit object.
+    :yield: tuples of (field, value), one per header line, in the order read
+        from the text, possibly including duplicates. Includes a field named
+        None for the freeform tag/commit text.
+    """
+    f = StringIO(text)
+    for l in f:
+        l = l.rstrip("\n")
+        if l == "":
+            # Empty line indicates end of headers
+            break
+        yield l.split(" ", 1)
+    yield (None, f.read())
+    f.close()
+
+
+def parse_tag(text):
+    return _parse_tag_or_commit(text)
+
 
 class Tag(ShaFile):
     """A Git Tag object."""
 
-    _type = TAG_ID
-    _num_type = 4
+    type_name = 'tag'
+    type_num = 4
 
     def __init__(self):
         super(Tag, self).__init__()
-        self._needs_parsing = False
-        self._needs_serialization = True
+        self._tag_timezone_neg_utc = False
 
     @classmethod
     def from_file(cls, filename):
-        blob = ShaFile.from_file(filename)
-        if blob._type != cls._type:
-            raise NotBlobError(filename)
-        return blob
+        tag = ShaFile.from_file(filename)
+        if not isinstance(tag, cls):
+            raise NotTagError(filename)
+        return tag
 
-    @classmethod
-    def from_string(cls, string):
-        """Create a blob from a string."""
-        shafile = cls()
-        shafile.set_raw_string(string)
-        return shafile
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Tag, self).check()
+        self._check_has_member("_object_sha", "missing object sha")
+        self._check_has_member("_object_class", "missing object type")
+        self._check_has_member("_name", "missing tag name")
+
+        if not self._name:
+            raise ObjectFormatException("empty tag name")
+
+        check_hexsha(self._object_sha, "invalid object sha")
+
+        if getattr(self, "_tagger", None):
+            check_identity(self._tagger, "invalid tagger")
+
+        last = None
+        for field, _ in parse_tag("".join(self._chunked_text)):
+            if field == _OBJECT_HEADER and last is not None:
+                raise ObjectFormatException("unexpected object")
+            elif field == _TYPE_HEADER and last != _OBJECT_HEADER:
+                raise ObjectFormatException("unexpected type")
+            elif field == _TAG_HEADER and last != _TYPE_HEADER:
+                raise ObjectFormatException("unexpected tag name")
+            elif field == _TAGGER_HEADER and last != _TAG_HEADER:
+                raise ObjectFormatException("unexpected tagger")
+            last = field
 
     def _serialize(self):
         chunks = []
-        chunks.append("%s %s\n" % (OBJECT_ID, self._object_sha))
-        chunks.append("%s %s\n" % (TYPE_ID, num_type_map[self._object_type]._type))
-        chunks.append("%s %s\n" % (TAG_ID, self._name))
+        chunks.append("%s %s\n" % (_OBJECT_HEADER, self._object_sha))
+        chunks.append("%s %s\n" % (_TYPE_HEADER, self._object_class.type_name))
+        chunks.append("%s %s\n" % (_TAG_HEADER, self._name))
         if self._tagger:
             if self._tag_time is None:
-                chunks.append("%s %s\n" % (TAGGER_ID, self._tagger))
+                chunks.append("%s %s\n" % (_TAGGER_HEADER, self._tagger))
             else:
-                chunks.append("%s %s %d %s\n" % (TAGGER_ID, self._tagger, self._tag_time, format_timezone(self._tag_timezone)))
+                chunks.append("%s %s %d %s\n" % (
+                  _TAGGER_HEADER, self._tagger, self._tag_time,
+                  format_timezone(self._tag_timezone,
+                    self._tag_timezone_neg_utc)))
         chunks.append("\n") # To close headers
         chunks.append(self._message)
         return chunks
@@ -368,45 +590,49 @@ class Tag(ShaFile):
     def _deserialize(self, chunks):
         """Grab the metadata attached to the tag"""
         self._tagger = None
-        f = StringIO("".join(chunks))
-        for l in f:
-            l = l.rstrip("\n")
-            if l == "":
-                break # empty line indicates end of headers
-            (field, value) = l.split(" ", 1)
-            if field == OBJECT_ID:
+        for field, value in parse_tag("".join(chunks)):
+            if field == _OBJECT_HEADER:
                 self._object_sha = value
-            elif field == TYPE_ID:
-                self._object_type = type_map[value]
-            elif field == TAG_ID:
+            elif field == _TYPE_HEADER:
+                obj_class = object_class(value)
+                if not obj_class:
+                    raise ObjectFormatException("Not a known type: %s" % value)
+                self._object_class = obj_class
+            elif field == _TAG_HEADER:
                 self._name = value
-            elif field == TAGGER_ID:
+            elif field == _TAGGER_HEADER:
                 try:
                     sep = value.index("> ")
                 except ValueError:
                     self._tagger = value
                     self._tag_time = None
                     self._tag_timezone = None
+                    self._tag_timezone_neg_utc = False
                 else:
                     self._tagger = value[0:sep+1]
-                    (timetext, timezonetext) = value[sep+2:].rsplit(" ", 1)
                     try:
+                        (timetext, timezonetext) = value[sep+2:].rsplit(" ", 1)
                         self._tag_time = int(timetext)
-                    except ValueError: #Not a unix timestamp
-                        self._tag_time = time.strptime(timetext)
-                    self._tag_timezone = parse_timezone(timezonetext)
+                        self._tag_timezone, self._tag_timezone_neg_utc = \
+                                parse_timezone(timezonetext)
+                    except ValueError, e:
+                        raise ObjectFormatException(e)
+            elif field is None:
+                self._message = value
             else:
-                raise AssertionError("Unknown field %s" % field)
-        self._message = f.read()
+                raise ObjectFormatException("Unknown field %s" % field)
 
     def _get_object(self):
-        """Returns the object pointed by this tag, represented as a tuple(type, sha)"""
+        """Get the object pointed to by this tag.
+
+        :return: tuple of (object class, sha).
+        """
         self._ensure_parsed()
-        return (self._object_type, self._object_sha)
+        return (self._object_class, self._object_sha)
 
     def _set_object(self, value):
         self._ensure_parsed()
-        (self._object_type, self._object_sha) = value
+        (self._object_class, self._object_sha) = value
         self._needs_serialization = True
 
     object = property(_get_object, _set_object)
@@ -425,9 +651,8 @@ def parse_tree(text):
     """Parse a tree text.
 
     :param text: Serialized text to parse
-    :return: Dictionary with names as keys, (mode, sha) tuples as values
+    :yields: tuples of (name, mode, sha)
     """
-    ret = {}
     count = 0
     l = len(text)
     while count < l:
@@ -437,8 +662,7 @@ def parse_tree(text):
         name = text[mode_end+1:name_end]
         count = name_end+21
         sha = text[name_end+1:count]
-        ret[name] = (mode, sha_to_hex(sha))
-    return ret
+        yield (name, mode, sha_to_hex(sha))
 
 
 def serialize_tree(items):
@@ -458,32 +682,33 @@ def sorted_tree_items(entries):
     :param entries: Dictionary mapping names to (mode, sha) tuples
     :return: Iterator over (name, mode, sha)
     """
-    def cmp_entry((name1, value1), (name2, value2)):
-        if stat.S_ISDIR(value1[0]):
-            name1 += "/"
-        if stat.S_ISDIR(value2[0]):
-            name2 += "/"
-        return cmp(name1, name2)
     for name, entry in sorted(entries.iteritems(), cmp=cmp_entry):
         yield name, entry[0], entry[1]
 
 
+def cmp_entry((name1, value1), (name2, value2)):
+    """Compare two tree entries."""
+    if stat.S_ISDIR(value1[0]):
+        name1 += "/"
+    if stat.S_ISDIR(value2[0]):
+        name2 += "/"
+    return cmp(name1, name2)
+
+
 class Tree(ShaFile):
     """A Git tree object"""
 
-    _type = TREE_ID
-    _num_type = 2
+    type_name = 'tree'
+    type_num = 2
 
     def __init__(self):
         super(Tree, self).__init__()
         self._entries = {}
-        self._needs_parsing = False
-        self._needs_serialization = True
 
     @classmethod
     def from_file(cls, filename):
         tree = ShaFile.from_file(filename)
-        if tree._type != cls._type:
+        if not isinstance(tree, cls):
             raise NotTreeError(filename)
         return tree
 
@@ -511,6 +736,10 @@ class Tree(ShaFile):
         self._ensure_parsed()
         return len(self._entries)
 
+    def __iter__(self):
+        self._ensure_parsed()
+        return iter(self._entries)
+
     def add(self, mode, name, hexsha):
         assert type(mode) == int
         assert type(name) == str
@@ -528,8 +757,7 @@ class Tree(ShaFile):
             (mode, name, hexsha) for (name, mode, hexsha) in self.iteritems()]
 
     def iteritems(self):
-        """Iterate over all entries in the order in which they would be
-        serialized.
+        """Iterate over entries in the order in which they would be serialized.
 
         :return: Iterator over (name, mode, sha) tuples
         """
@@ -538,7 +766,40 @@ class Tree(ShaFile):
 
     def _deserialize(self, chunks):
         """Grab the entries in the tree"""
-        self._entries = parse_tree("".join(chunks))
+        try:
+            parsed_entries = parse_tree("".join(chunks))
+        except ValueError, e:
+            raise ObjectFormatException(e)
+        # TODO: list comprehension is for efficiency in the common (small) case;
+        # if memory efficiency in the large case is a concern, use a genexp.
+        self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries])
+
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Tree, self).check()
+        last = None
+        allowed_modes = (stat.S_IFREG | 0755, stat.S_IFREG | 0644,
+                         stat.S_IFLNK, stat.S_IFDIR, S_IFGITLINK,
+                         # TODO: optionally exclude as in git fsck --strict
+                         stat.S_IFREG | 0664)
+        for name, mode, sha in parse_tree("".join(self._chunked_text)):
+            check_hexsha(sha, 'invalid sha %s' % sha)
+            if '/' in name or name in ('', '.', '..'):
+                raise ObjectFormatException('invalid name %s' % name)
+
+            if mode not in allowed_modes:
+                raise ObjectFormatException('invalid mode %06o' % mode)
+
+            entry = (name, (mode, sha))
+            if last:
+                if cmp_entry(last, entry) > 0:
+                    raise ObjectFormatException('entries not sorted')
+                if name == last[0]:
+                    raise ObjectFormatException('duplicate entry %s' % name)
+            last = entry
 
     def _serialize(self):
         return list(serialize_tree(self.iteritems()))
@@ -556,39 +817,47 @@ class Tree(ShaFile):
 
 def parse_timezone(text):
     offset = int(text)
+    negative_utc = (offset == 0 and text[0] == '-')
     signum = (offset < 0) and -1 or 1
     offset = abs(offset)
     hours = int(offset / 100)
     minutes = (offset % 100)
-    return signum * (hours * 3600 + minutes * 60)
+    return signum * (hours * 3600 + minutes * 60), negative_utc
 
 
-def format_timezone(offset):
+def format_timezone(offset, negative_utc=False):
     if offset % 60 != 0:
         raise ValueError("Unable to handle non-minute offset.")
-    sign = (offset < 0) and '-' or '+'
+    if offset < 0 or (offset == 0 and negative_utc):
+        sign = '-'
+    else:
+        sign = '+'
     offset = abs(offset)
     return '%c%02d%02d' % (sign, offset / 3600, (offset / 60) % 60)
 
 
+def parse_commit(text):
+    return _parse_tag_or_commit(text)
+
+
 class Commit(ShaFile):
     """A git commit object"""
 
-    _type = COMMIT_ID
-    _num_type = 1
+    type_name = 'commit'
+    type_num = 1
 
     def __init__(self):
         super(Commit, self).__init__()
         self._parents = []
         self._encoding = None
-        self._needs_parsing = False
-        self._needs_serialization = True
         self._extra = {}
+        self._author_timezone_neg_utc = False
+        self._commit_timezone_neg_utc = False
 
     @classmethod
     def from_file(cls, filename):
         commit = ShaFile.from_file(filename)
-        if commit._type != cls._type:
+        if not isinstance(commit, cls):
             raise NotCommitError(filename)
         return commit
 
@@ -596,40 +865,79 @@ class Commit(ShaFile):
         self._parents = []
         self._extra = []
         self._author = None
-        f = StringIO("".join(chunks))
-        for l in f:
-            l = l.rstrip("\n")
-            if l == "":
-                # Empty line indicates end of headers
-                break
-            (field, value) = l.split(" ", 1)
-            if field == TREE_ID:
+        for field, value in parse_commit("".join(chunks)):
+            if field == _TREE_HEADER:
                 self._tree = value
-            elif field == PARENT_ID:
+            elif field == _PARENT_HEADER:
                 self._parents.append(value)
-            elif field == AUTHOR_ID:
+            elif field == _AUTHOR_HEADER:
                 self._author, timetext, timezonetext = value.rsplit(" ", 2)
                 self._author_time = int(timetext)
-                self._author_timezone = parse_timezone(timezonetext)
-            elif field == COMMITTER_ID:
+                self._author_timezone, self._author_timezone_neg_utc =\
+                    parse_timezone(timezonetext)
+            elif field == _COMMITTER_HEADER:
                 self._committer, timetext, timezonetext = value.rsplit(" ", 2)
                 self._commit_time = int(timetext)
-                self._commit_timezone = parse_timezone(timezonetext)
-            elif field == ENCODING_ID:
+                self._commit_timezone, self._commit_timezone_neg_utc =\
+                    parse_timezone(timezonetext)
+            elif field == _ENCODING_HEADER:
                 self._encoding = value
+            elif field is None:
+                self._message = value
             else:
                 self._extra.append((field, value))
-        self._message = f.read()
+
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Commit, self).check()
+        self._check_has_member("_tree", "missing tree")
+        self._check_has_member("_author", "missing author")
+        self._check_has_member("_committer", "missing committer")
+        # times are currently checked when set
+
+        for parent in self._parents:
+            check_hexsha(parent, "invalid parent sha")
+        check_hexsha(self._tree, "invalid tree sha")
+
+        check_identity(self._author, "invalid author")
+        check_identity(self._committer, "invalid committer")
+
+        last = None
+        for field, _ in parse_commit("".join(self._chunked_text)):
+            if field == _TREE_HEADER and last is not None:
+                raise ObjectFormatException("unexpected tree")
+            elif field == _PARENT_HEADER and last not in (_PARENT_HEADER,
+                                                          _TREE_HEADER):
+                raise ObjectFormatException("unexpected parent")
+            elif field == _AUTHOR_HEADER and last not in (_TREE_HEADER,
+                                                          _PARENT_HEADER):
+                raise ObjectFormatException("unexpected author")
+            elif field == _COMMITTER_HEADER and last != _AUTHOR_HEADER:
+                raise ObjectFormatException("unexpected committer")
+            elif field == _ENCODING_HEADER and last != _COMMITTER_HEADER:
+                raise ObjectFormatException("unexpected encoding")
+            last = field
+
+        # TODO: optionally check for duplicate parents
 
     def _serialize(self):
         chunks = []
-        chunks.append("%s %s\n" % (TREE_ID, self._tree))
+        chunks.append("%s %s\n" % (_TREE_HEADER, self._tree))
         for p in self._parents:
-            chunks.append("%s %s\n" % (PARENT_ID, p))
-        chunks.append("%s %s %s %s\n" % (AUTHOR_ID, self._author, str(self._author_time), format_timezone(self._author_timezone)))
-        chunks.append("%s %s %s %s\n" % (COMMITTER_ID, self._committer, str(self._commit_time), format_timezone(self._commit_timezone)))
+            chunks.append("%s %s\n" % (_PARENT_HEADER, p))
+        chunks.append("%s %s %s %s\n" % (
+          _AUTHOR_HEADER, self._author, str(self._author_time),
+          format_timezone(self._author_timezone,
+                          self._author_timezone_neg_utc)))
+        chunks.append("%s %s %s %s\n" % (
+          _COMMITTER_HEADER, self._committer, str(self._commit_time),
+          format_timezone(self._commit_timezone,
+                          self._commit_timezone_neg_utc)))
         if self.encoding:
-            chunks.append("%s %s\n" % (ENCODING_ID, self.encoding))
+            chunks.append("%s %s\n" % (_ENCODING_HEADER, self.encoding))
         for k, v in self.extra:
             if "\n" in k or "\n" in v:
                 raise AssertionError("newline in extra data: %r -> %r" % (k, v))
@@ -685,22 +993,24 @@ class Commit(ShaFile):
         "Encoding of the commit message.")
 
 
-type_map = {
-    BLOB_ID : Blob,
-    TREE_ID : Tree,
-    COMMIT_ID : Commit,
-    TAG_ID: Tag,
-}
+OBJECT_CLASSES = (
+    Commit,
+    Tree,
+    Blob,
+    Tag,
+    )
+
+_TYPE_MAP = {}
+
+for cls in OBJECT_CLASSES:
+    _TYPE_MAP[cls.type_name] = cls
+    _TYPE_MAP[cls.type_num] = cls
+
 
-num_type_map = {
-    0: None,
-    1: Commit,
-    2: Tree,
-    3: Blob,
-    4: Tag,
-    # 5 Is reserved for further expansion
-}
 
+# Hold on to the pure-python implementations for testing
+_parse_tree_py = parse_tree
+_sorted_tree_items_py = sorted_tree_items
 try:
     # Try to import C versions
     from dulwich._objects import parse_tree, sorted_tree_items

+ 70 - 55
dulwich/pack.py

@@ -36,6 +36,7 @@ except ImportError:
     from misc import defaultdict
 
 import difflib
+import errno
 from itertools import (
     chain,
     imap,
@@ -124,22 +125,41 @@ def load_pack_index(path):
     return load_pack_index_file(path, f)
 
 
+def _load_file_contents(f, size=None):
+    fileno = getattr(f, 'fileno', None)
+    # Attempt to use mmap if possible
+    if fileno is not None:
+        fd = f.fileno()
+        if size is None:
+            size = os.fstat(fd).st_size
+        try:
+            contents = mmap.mmap(fd, size, access=mmap.ACCESS_READ)
+        except mmap.error:
+            # Perhaps a socket?
+            pass
+        else:
+            return contents, size
+    contents = f.read()
+    size = len(contents)
+    return contents, size
+
+
 def load_pack_index_file(path, f):
     """Load an index file from a file-like object.
 
     :param path: Path for the index file
     :param f: File-like object
     """
-    if f.read(4) == '\377tOc':
-        version = struct.unpack(">L", f.read(4))[0]
+    contents, size = _load_file_contents(f)
+    if contents[:4] == '\377tOc':
+        version = struct.unpack(">L", contents[4:8])[0]
         if version == 2:
-            f.seek(0)
-            return PackIndex2(path, file=f)
+            return PackIndex2(path, file=f, contents=contents,
+                size=size)
         else:
             raise KeyError("Unknown pack index format %d" % version)
     else:
-        f.seek(0)
-        return PackIndex1(path, file=f)
+        return PackIndex1(path, file=f, contents=contents, size=size)
 
 
 def bisect_find_sha(start, end, sha, unpack_name):
@@ -179,7 +199,7 @@ class PackIndex(object):
     the start and end offset and then bisect in to find if the value is present.
     """
   
-    def __init__(self, filename, file=None, size=None):
+    def __init__(self, filename, file=None, contents=None, size=None):
         """Create a pack index object.
     
         Provide it with the name of the index file to consider, and it will map
@@ -192,19 +212,10 @@ class PackIndex(object):
             self._file = GitFile(filename, 'rb')
         else:
             self._file = file
-        fileno = getattr(self._file, 'fileno', None)
-        if fileno is not None:
-            fd = self._file.fileno()
-            if size is None:
-                self._size = os.fstat(fd).st_size
-            else:
-                self._size = size
-            self._contents = mmap.mmap(fd, self._size,
-                access=mmap.ACCESS_READ)
+        if contents is None:
+            self._contents, self._size = _load_file_contents(self._file, size)
         else:
-            self._file.seek(0)
-            self._contents = self._file.read()
-            self._size = len(self._contents)
+            self._contents, self._size = (contents, size)
   
     def __eq__(self, other):
         if not isinstance(other, PackIndex):
@@ -213,7 +224,8 @@ class PackIndex(object):
         if self._fan_out_table != other._fan_out_table:
             return False
     
-        for (name1, _, _), (name2, _, _) in izip(self.iterentries(), other.iterentries()):
+        for (name1, _, _), (name2, _, _) in izip(self.iterentries(),
+                                                 other.iterentries()):
             if name1 != name2:
                 return False
         return True
@@ -265,7 +277,8 @@ class PackIndex(object):
     def iterentries(self):
         """Iterate over the entries in this pack index.
        
-        Will yield tuples with object name, offset in packfile and crc32 checksum.
+        Will yield tuples with object name, offset in packfile and crc32
+        checksum.
         """
         for i in range(len(self)):
             yield self._unpack_entry(i)
@@ -273,7 +286,8 @@ class PackIndex(object):
     def _read_fan_out_table(self, start_offset):
         ret = []
         for i in range(0x100):
-            ret.append(struct.unpack(">L", self._contents[start_offset+i*4:start_offset+(i+1)*4])[0])
+            ret.append(struct.unpack(">L",
+                self._contents[start_offset+i*4:start_offset+(i+1)*4])[0])
         return ret
   
     def check(self):
@@ -305,9 +319,9 @@ class PackIndex(object):
     def object_index(self, sha):
         """Return the index in to the corresponding packfile for the object.
     
-        Given the name of an object it will return the offset that object lives
-        at within the corresponding pack file. If the pack file doesn't have the
-        object then None will be returned.
+        Given the name of an object it will return the offset that object
+        lives at within the corresponding pack file. If the pack file doesn't
+        have the object then None will be returned.
         """
         if len(sha) == 40:
             sha = hex_to_sha(sha)
@@ -335,8 +349,8 @@ class PackIndex(object):
 class PackIndex1(PackIndex):
     """Version 1 Pack Index."""
 
-    def __init__(self, filename, file=None, size=None):
-        PackIndex.__init__(self, filename, file, size)
+    def __init__(self, filename, file=None, contents=None, size=None):
+        PackIndex.__init__(self, filename, file, contents, size)
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
 
@@ -361,8 +375,8 @@ class PackIndex1(PackIndex):
 class PackIndex2(PackIndex):
     """Version 2 Pack Index."""
 
-    def __init__(self, filename, file=None, size=None):
-        PackIndex.__init__(self, filename, file, size)
+    def __init__(self, filename, file=None, contents=None, size=None):
+        PackIndex.__init__(self, filename, file, contents, size)
         assert self._contents[:4] == '\377tOc', "Not a v2 pack index file"
         (self.version, ) = unpack_from(">L", self._contents, 4)
         assert self.version == 2, "Version was %d" % self.version
@@ -427,17 +441,17 @@ def unpack_object(read):
             delta_base_offset += 1
             delta_base_offset <<= 7
             delta_base_offset += (byte & 0x7f)
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read)
         assert size == chunks_length(uncomp)
         return type, (delta_base_offset, uncomp), comp_len+raw_base, unused
     elif type == 7: # ref delta
         basename = read(20)
         raw_base += 20
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read)
         assert size == chunks_length(uncomp)
         return type, (basename, uncomp), comp_len+raw_base, unused
     else:
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read)
         assert chunks_length(uncomp) == size
         return type, uncomp, comp_len+raw_base, unused
 
@@ -472,16 +486,17 @@ class PackData(object):
     buffer from the start of the deflated object on. This is bad, but until I
     get mmap sorted out it will have to do.
   
-    Currently there are no integrity checks done. Also no attempt is made to try
-    and detect the delta case, or a request for an object at the wrong position.
-    It will all just throw a zlib or KeyError.
+    Currently there are no integrity checks done. Also no attempt is made to
+    try and detect the delta case, or a request for an object at the wrong
+    position.  It will all just throw a zlib or KeyError.
     """
   
     def __init__(self, filename, file=None, size=None):
-        """Create a PackData object that represents the pack in the given filename.
+        """Create a PackData object that represents the pack in the given
+        filename.
     
-        The file must exist and stay readable until the object is disposed of. It
-        must also stay the same size. It will be mapped whenever needed.
+        The file must exist and stay readable until the object is disposed of.
+        It must also stay the same size. It will be mapped whenever needed.
     
         Currently there is a restriction on the size of the pack as the python
         mmap implementation is flawed.
@@ -625,9 +640,9 @@ class PackData(object):
         for (offset, type, obj, crc32) in todo:
             assert isinstance(offset, int)
             assert isinstance(type, int)
-            assert isinstance(obj, list) or isinstance(obj, str)
             try:
-                type, obj = self.resolve_object(offset, type, obj, get_ref_text)
+                type, obj = self.resolve_object(offset, type, obj,
+                    get_ref_text)
             except Postpone, (sha, ):
                 postponed[sha].append((offset, type, obj))
             else:
@@ -656,8 +671,8 @@ class PackData(object):
         """Create a version 1 file for this data file.
 
         :param filename: Index filename.
-        :param resolve_ext_ref: Function to use for resolving externally referenced
-            SHA1s (for thin packs)
+        :param resolve_ext_ref: Function to use for resolving externally
+            referenced SHA1s (for thin packs)
         :param progress: Progress report function
         """
         entries = self.sorted_entries(resolve_ext_ref, progress=progress)
@@ -667,8 +682,8 @@ class PackData(object):
         """Create a version 2 index file for this data file.
 
         :param filename: Index filename.
-        :param resolve_ext_ref: Function to use for resolving externally referenced
-            SHA1s (for thin packs)
+        :param resolve_ext_ref: Function to use for resolving externally
+            referenced SHA1s (for thin packs)
         :param progress: Progress report function
         """
         entries = self.sorted_entries(resolve_ext_ref, progress=progress)
@@ -679,8 +694,8 @@ class PackData(object):
         """Create an  index file for this data file.
 
         :param filename: Index filename.
-        :param resolve_ext_ref: Function to use for resolving externally referenced
-            SHA1s (for thin packs)
+        :param resolve_ext_ref: Function to use for resolving externally
+            referenced SHA1s (for thin packs)
         :param progress: Progress report function
         """
         if version == 1:
@@ -702,8 +717,8 @@ class PackData(object):
     def get_object_at(self, offset):
         """Given an offset in to the packfile return the object that is there.
     
-        Using the associated index the location of an object can be looked up, and
-        then the packfile can be asked directly for that object using this
+        Using the associated index the location of an object can be looked up,
+        and then the packfile can be asked directly for that object using this
         function.
         """
         if offset in self._offset_cache:
@@ -834,7 +849,7 @@ def write_pack_data(f, objects, num_objects, window=10):
     # This helps us find good objects to diff against us
     magic = []
     for obj, path in recency:
-        magic.append( (obj.type, path, 1, -obj.raw_length(), obj) )
+        magic.append( (obj.type_num, path, 1, -obj.raw_length(), obj) )
     magic.sort()
     # Build a map of objects and their index in magic - so we can find preceeding objects
     # to diff against
@@ -849,14 +864,14 @@ def write_pack_data(f, objects, num_objects, window=10):
     f.write(struct.pack(">L", num_objects)) # Number of objects in pack
     for o, path in recency:
         sha1 = o.sha().digest()
-        orig_t = o.type
+        orig_t = o.type_num
         raw = o.as_raw_string()
         winner = raw
         t = orig_t
         #for i in range(offs[o]-window, window):
         #    if i < 0 or i >= len(offs): continue
         #    b = magic[i][4]
-        #    if b.type != orig_t: continue
+        #    if b.type_num != orig_t: continue
         #    base = b.as_raw_string()
         #    delta = create_delta(base, raw)
         #    if len(delta) < len(winner):
@@ -871,8 +886,8 @@ def write_pack_index_v1(filename, entries, pack_checksum):
     """Write a new pack index file.
 
     :param filename: The filename of the new pack index file.
-    :param entries: List of tuples with object name (sha), offset_in_pack,  and
-            crc32_checksum.
+    :param entries: List of tuples with object name (sha), offset_in_pack,
+        and crc32_checksum.
     :param pack_checksum: Checksum of the pack file.
     """
     f = GitFile(filename, 'wb')
@@ -1020,8 +1035,8 @@ def write_pack_index_v2(filename, entries, pack_checksum):
     """Write a new pack index file.
 
     :param filename: The filename of the new pack index file.
-    :param entries: List of tuples with object name (sha), offset_in_pack,  and
-            crc32_checksum.
+    :param entries: List of tuples with object name (sha), offset_in_pack, and
+        crc32_checksum.
     :param pack_checksum: Checksum of the pack file.
     """
     f = GitFile(filename, 'wb')

+ 33 - 20
dulwich/repo.py

@@ -49,7 +49,7 @@ from dulwich.objects import (
     Tag,
     Tree,
     hex_to_sha,
-    num_type_map,
+    object_class,
     )
 import warnings
 
@@ -62,9 +62,6 @@ REFSDIR_HEADS = 'heads'
 INDEX_FILENAME = "index"
 
 BASE_DIRECTORIES = [
-    [OBJECTDIR], 
-    [OBJECTDIR, "info"], 
-    [OBJECTDIR, "pack"],
     ["branches"],
     [REFSDIR],
     [REFSDIR, REFSDIR_TAGS],
@@ -77,7 +74,7 @@ BASE_DIRECTORIES = [
 def read_info_refs(f):
     ret = {}
     for l in f.readlines():
-        (sha, name) = l.rstrip("\n").split("\t", 1)
+        (sha, name) = l.rstrip("\r\n").split("\t", 1)
         ret[name] = sha
     return ret
 
@@ -118,6 +115,12 @@ class RefsContainer(object):
     """A container for refs."""
 
     def set_ref(self, name, other):
+        warnings.warn("RefsContainer.set_ref() is deprecated. "
+            "Use set_symbolic_ref instead.",
+            category=DeprecationWarning, stacklevel=2)
+        return self.set_symbolic_ref(name, other)
+
+    def set_symbolic_ref(self, name, other):
         """Make a ref point at another ref.
 
         :param name: Name of the ref to set
@@ -200,6 +203,18 @@ class RefsContainer(object):
         if not name.startswith('refs/') or not check_ref_format(name[5:]):
             raise KeyError(name)
 
+    def read_ref(self, refname):
+        """Read a reference without following any references.
+
+        :param refname: The name of the reference
+        :return: The contents of the ref file, or None if it does 
+            not exist.
+        """
+        contents = self.read_loose_ref(refname)
+        if not contents:
+            contents = self.get_packed_refs().get(refname, None)
+        return contents
+
     def read_loose_ref(self, name):
         """Read a loose reference and return its contents.
 
@@ -220,20 +235,16 @@ class RefsContainer(object):
         depth = 0
         while contents.startswith(SYMREF):
             refname = contents[len(SYMREF):]
-            contents = self.read_loose_ref(refname)
+            contents = self.read_ref(refname)
             if not contents:
-                contents = self.get_packed_refs().get(refname, None)
-                if not contents:
-                    break
+                break
             depth += 1
             if depth > 5:
                 raise KeyError(name)
         return refname, contents
 
     def __contains__(self, refname):
-        if self.read_loose_ref(refname):
-            return True
-        if self.get_packed_refs().get(refname, None):
+        if self.read_ref(refname):
             return True
         return False
 
@@ -380,7 +391,7 @@ class DiskRefsContainer(RefsContainer):
                 header = f.read(len(SYMREF))
                 if header == SYMREF:
                     # Read only the first line
-                    return header + iter(f).next().rstrip("\n")
+                    return header + iter(f).next().rstrip("\r\n")
                 else:
                     # Read only the first 40 bytes
                     return header + f.read(40-len(SYMREF))
@@ -572,7 +583,7 @@ def read_packed_refs_with_peeled(f):
     for l in f:
         if l[0] == "#":
             continue
-        l = l.rstrip("\n")
+        l = l.rstrip("\r\n")
         if l[0] == "^":
             if not last:
                 raise PackedRefsException("unexpected peeled ref line")
@@ -698,7 +709,7 @@ class BaseRepo(object):
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
         ret = self.get_object(sha)
-        if ret._type != cls._type:
+        if not isinstance(ret, cls):
             if cls is Commit:
                 raise NotCommitError(ret)
             elif cls is Blob:
@@ -708,7 +719,8 @@ class BaseRepo(object):
             elif cls is Tag:
                 raise NotTagError(ret)
             else:
-                raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
+                raise Exception("Type invalid: %r != %r" % (
+                  ret.type_name, cls.type_name))
         return ret
 
     def get_object(self, sha):
@@ -784,9 +796,9 @@ class BaseRepo(object):
         if cached is not None:
             return cached
         obj = self[ref]
-        obj_type = num_type_map[obj.type]
-        while obj_type == Tag:
-            obj_type, sha = obj.object
+        obj_class = object_class(obj.type_name)
+        while obj_class is Tag:
+            obj_class, sha = obj.object
             obj = self.get_object(sha)
         return obj.id
 
@@ -1001,8 +1013,9 @@ class Repo(BaseRepo):
     def init_bare(cls, path, mkdir=True):
         for d in BASE_DIRECTORIES:
             os.mkdir(os.path.join(path, *d))
+        DiskObjectStore.init(os.path.join(path, OBJECTDIR))
         ret = cls(path)
-        ret.refs.set_ref("HEAD", "refs/heads/master")
+        ret.refs.set_symbolic_ref("HEAD", "refs/heads/master")
         ret._put_named_file('description', "Unnamed repository")
         ret._put_named_file('config', """[core]
     repositoryformatversion = 0

+ 106 - 91
dulwich/server.py

@@ -49,14 +49,27 @@ from dulwich.protocol import (
     MULTI_ACK_DETAILED,
     ack_type,
     )
-from dulwich.repo import (
-    Repo,
-    )
 from dulwich.pack import (
     write_pack_data,
     )
 
 class Backend(object):
+    """A backend for the Git smart server implementation."""
+
+    def open_repository(self, path):
+        """Open the repository at a path."""
+        raise NotImplementedError(self.open_repository)
+
+
+class BackendRepo(object):
+    """Repository abstraction used by the Git server.
+    
+    Please note that the methods required here are a 
+    subset of those provided by dulwich.repo.Repo.
+    """
+
+    object_store = None
+    refs = None
 
     def get_refs(self):
         """
@@ -66,14 +79,16 @@ class Backend(object):
         """
         raise NotImplementedError
 
-    def apply_pack(self, refs, read, delete_refs=True):
-        """ Import a set of changes into a repository and update the refs
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
 
-        :param refs: list of tuple(name, sha)
-        :param read: callback to read from the incoming pack
-        :param delete_refs: whether to allow deleting refs
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to
+            point to a tag, this will be the SHA the ref refers to. If no
+            cached information about a tag is available, this method may
+            return None, but it should attempt to peel the tag if possible.
         """
-        raise NotImplementedError
+        return None
 
     def fetch_objects(self, determine_wants, graph_walker, progress,
                       get_tagged=None):
@@ -87,71 +102,15 @@ class Backend(object):
         raise NotImplementedError
 
 
-class GitBackend(Backend):
+class DictBackend(Backend):
+    """Trivial backend that looks up Git repositories in a dictionary."""
 
-    def __init__(self, repo=None):
-        if repo is None:
-            repo = Repo(tmpfile.mkdtemp())
-        self.repo = repo
-        self.refs = self.repo.refs
-        self.object_store = self.repo.object_store
-        self.fetch_objects = self.repo.fetch_objects
-        self.get_refs = self.repo.get_refs
-
-    def apply_pack(self, refs, read, delete_refs=True):
-        f, commit = self.repo.object_store.add_thin_pack()
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
-        status = []
-        unpack_error = None
-        # TODO: more informative error messages than just the exception string
-        try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
-        except all_exceptions, e:
-            unpack_error = str(e).replace('\n', '')
-        try:
-            commit()
-        except all_exceptions, e:
-            if not unpack_error:
-                unpack_error = str(e).replace('\n', '')
+    def __init__(self, repos):
+        self.repos = repos
 
-        if unpack_error:
-            status.append(('unpack', unpack_error))
-        else:
-            status.append(('unpack', 'ok'))
-
-        for oldsha, sha, ref in refs:
-            ref_error = None
-            try:
-                if sha == ZERO_SHA:
-                    if not delete_refs:
-                        raise GitProtocolError(
-                          'Attempted to delete refs without delete-refs '
-                          'capability.')
-                    try:
-                        del self.repo.refs[ref]
-                    except all_exceptions:
-                        ref_error = 'failed to delete'
-                else:
-                    try:
-                        self.repo.refs[ref] = sha
-                    except all_exceptions:
-                        ref_error = 'failed to write'
-            except KeyError, e:
-                ref_error = 'bad ref'
-            if ref_error:
-                status.append((ref, ref_error))
-            else:
-                status.append((ref, 'ok'))
-
-
-        print "pack applied"
-        return status
+    def open_repository(self, path):
+        # FIXME: What to do in case there is no repo ?
+        return self.repos[path]
 
 
 class Handler(object):
@@ -198,9 +157,10 @@ class Handler(object):
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the client."""
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, read, write,
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
+        self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
@@ -230,14 +190,14 @@ class UploadPackHandler(Handler):
         if not self.has_capability("include-tag"):
             return {}
         if refs is None:
-            refs = self.backend.get_refs()
+            refs = self.repo.get_refs()
         if repo is None:
-            repo = getattr(self.backend, "repo", None)
+            repo = getattr(self.repo, "repo", None)
             if repo is None:
                 # Bail if we don't have a Repo available; this is ok since
                 # clients must be able to handle if the server doesn't include
                 # all relevant tags.
-                # TODO: either guarantee a Repo, or fix behavior when missing
+                # TODO: fix behavior when missing
                 return {}
         tagged = {}
         for name, sha in refs.iteritems():
@@ -249,8 +209,9 @@ class UploadPackHandler(Handler):
     def handle(self):
         write = lambda x: self.proto.write_sideband(1, x)
 
-        graph_walker = ProtocolGraphWalker(self)
-        objects_iter = self.backend.fetch_objects(
+        graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
+            self.repo.get_peeled)
+        objects_iter = self.repo.fetch_objects(
           graph_walker.determine_wants, graph_walker, self.progress,
           get_tagged=self.get_tagged)
 
@@ -270,9 +231,9 @@ class UploadPackHandler(Handler):
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
 
-    As a graph walker, this class implements ack(), next(), and reset(). It also
-    contains some base methods for interacting with the wire and walking the
-    commit tree.
+    As a graph walker, this class implements ack(), next(), and reset(). It
+    also contains some base methods for interacting with the wire and walking
+    the commit tree.
 
     The work of determining which acks to send is passed on to the
     implementation instance stored in _impl. The reason for this is that we do
@@ -280,9 +241,10 @@ class ProtocolGraphWalker(object):
     call to set_ack_level() is required to set up the implementation, before any
     calls to next() or ack() are made.
     """
-    def __init__(self, handler):
+    def __init__(self, handler, object_store, get_peeled):
         self.handler = handler
-        self.store = handler.backend.object_store
+        self.store = object_store
+        self.get_peeled = get_peeled
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
@@ -312,7 +274,7 @@ class ProtocolGraphWalker(object):
                 if not i:
                     line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
-                peeled_sha = self.handler.backend.repo.get_peeled(ref)
+                peeled_sha = self.get_peeled(ref)
                 if peeled_sha != sha:
                     self.proto.write_pkt_line('%s %s^{}\n' %
                                               (peeled_sha, ref))
@@ -421,10 +383,10 @@ class ProtocolGraphWalker(object):
             commit = pending.popleft()
             if commit.id in haves:
                 return True
-            if not getattr(commit, 'get_parents', None):
+            if commit.type_name != "commit":
                 # non-commit wants are assumed to be satisfied
                 continue
-            for parent in commit.get_parents():
+            for parent in commit.parents:
                 parent_obj = self.store[parent]
                 # TODO: handle parents with later commit times than children
                 if parent_obj.commit_time >= earliest:
@@ -559,17 +521,71 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(Handler):
     """Protocol handler for downloading a pack from the client."""
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, read, write,
                  stateless_rpc=False, advertise_refs=False):
         Handler.__init__(self, backend, read, write)
+        self.repo = backend.open_repository(args[0])
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
     def capabilities(self):
         return ("report-status", "delete-refs")
 
+    def _apply_pack(self, refs, read):
+        f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
+        status = []
+        unpack_error = None
+        # TODO: more informative error messages than just the exception string
+        try:
+            # TODO: decode the pack as we stream to avoid blocking reads beyond
+            # the end of data (when using HTTP/1.1 chunked encoding)
+            while True:
+                data = read(10240)
+                if not data:
+                    break
+                f.write(data)
+        except all_exceptions, e:
+            unpack_error = str(e).replace('\n', '')
+        try:
+            commit()
+        except all_exceptions, e:
+            if not unpack_error:
+                unpack_error = str(e).replace('\n', '')
+
+        if unpack_error:
+            status.append(('unpack', unpack_error))
+        else:
+            status.append(('unpack', 'ok'))
+
+        for oldsha, sha, ref in refs:
+            ref_error = None
+            try:
+                if sha == ZERO_SHA:
+                    if not self.has_capability('delete-refs'):
+                        raise GitProtocolError(
+                          'Attempted to delete refs without delete-refs '
+                          'capability.')
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_error = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_error = 'failed to write'
+            except KeyError, e:
+                ref_error = 'bad ref'
+            if ref_error:
+                status.append((ref, ref_error))
+            else:
+                status.append((ref, 'ok'))
+
+        return status
+
     def handle(self):
-        refs = self.backend.get_refs().items()
+        refs = self.repo.get_refs().items()
 
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
@@ -603,8 +619,7 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
 
         # backend can now deal with these refs and read a pack using self.read
-        status = self.backend.apply_pack(client_refs, self.proto.read,
-                                         self.has_capability('delete-refs'))
+        status = self.repo._apply_pack(client_refs, self.proto.read)
 
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
@@ -633,7 +648,7 @@ class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
         else:
             return
 
-        h = cls(self.server.backend, self.rfile.read, self.wfile.write)
+        h = cls(self.server.backend, args, self.rfile.read, self.wfile.write)
         h.handle()
 
 

+ 4 - 0
dulwich/tests/__init__.py

@@ -21,6 +21,10 @@
 
 import unittest
 
+# XXX: Ideally we should allow other test runners as well, 
+# but unfortunately unittest doesn't have a SkipTest/TestSkipped
+# exception.
+from nose import SkipTest as TestSkipped
 
 def test_suite():
     names = [

+ 15 - 9
dulwich/tests/compat/test_server.py

@@ -26,27 +26,32 @@ On *nix, you can kill the tests with Ctrl-Z, "kill %".
 
 import threading
 
-from dulwich import server
+from dulwich.server import (
+    DictBackend,
+    TCPGitServer,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
 from server_utils import (
     ServerTests,
     ShutdownServerMixIn,
     )
 from utils import (
     CompatTestCase,
-    SkipTest,
     )
 
 
-if getattr(server.TCPGitServer, 'shutdown', None):
-    TCPGitServer = server.TCPGitServer
-else:
-    class TCPGitServer(ShutdownServerMixIn, server.TCPGitServer):
+if not getattr(TCPGitServer, 'shutdown', None):
+    _TCPGitServer = TCPGitServer
+
+    class TCPGitServer(ShutdownServerMixIn, TCPGitServer):
         """Subclass of TCPGitServer that can be shut down."""
 
         def __init__(self, *args, **kwargs):
             # BaseServer is old-style so we have to call both __init__s
             ShutdownServerMixIn.__init__(self)
-            server.TCPGitServer.__init__(self, *args, **kwargs)
+            _TCPGitServer.__init__(self, *args, **kwargs)
 
         serve = ShutdownServerMixIn.serve_forever
 
@@ -65,11 +70,12 @@ class GitServerTestCase(ServerTests, CompatTestCase):
         CompatTestCase.tearDown(self)
 
     def _start_server(self, repo):
-        dul_server = TCPGitServer(server.GitBackend(repo), 'localhost', 0)
+        backend = DictBackend({'/': repo})
+        dul_server = TCPGitServer(backend, 'localhost', 0)
         threading.Thread(target=dul_server.serve).start()
         self._server = dul_server
         _, port = self._server.socket.getsockname()
         return port
 
     def test_push_to_dulwich(self):
-        raise SkipTest('Skipping push test due to known deadlock bug.')
+        raise TestSkipped('Skipping push test due to known deadlock bug.')

+ 8 - 5
dulwich/tests/compat/test_web.py

@@ -28,7 +28,10 @@ import threading
 from wsgiref import simple_server
 
 from dulwich.server import (
-    GitBackend,
+    DictBackend,
+    )
+from dulwich.tests import (
+    TestSkipped,
     )
 from dulwich.web import (
     HTTPGitApplication,
@@ -40,7 +43,6 @@ from server_utils import (
     )
 from utils import (
     CompatTestCase,
-    SkipTest,
     )
 
 
@@ -68,7 +70,8 @@ class WebTests(ServerTests):
     protocol = 'http'
 
     def _start_server(self, repo):
-        app = self._make_app(GitBackend(repo))
+        backend = DictBackend({'/': repo})
+        app = self._make_app(backend)
         dul_server = simple_server.make_server('localhost', 0, app,
                                                server_class=WSGIServer)
         threading.Thread(target=dul_server.serve_forever).start()
@@ -95,7 +98,7 @@ class SmartWebTestCase(WebTests, CompatTestCase):
 
     def test_push_to_dulwich(self):
         # TODO(dborowitz): enable after merging thin pack fixes.
-        raise SkipTest('Skipping push test due to known pack bug.')
+        raise TestSkipped('Skipping push test due to known pack bug.')
 
 
 class DumbWebTestCase(WebTests, CompatTestCase):
@@ -114,4 +117,4 @@ class DumbWebTestCase(WebTests, CompatTestCase):
 
     def test_push_to_dulwich(self):
         # Note: remove this if dumb pushing is supported
-        raise SkipTest('Dumb web pushing not supported.')
+        raise TestSkipped('Dumb web pushing not supported.')

+ 5 - 5
dulwich/tests/compat/utils.py

@@ -24,11 +24,11 @@ import subprocess
 import tempfile
 import unittest
 
-# XXX: Ideally we shouldn't depend on nose but allow other testrunners as well.
-from nose import SkipTest
-
 from dulwich.repo import Repo
 
+from dulwich.tests import (
+    TestSkipped,
+    )
 
 _DEFAULT_GIT = 'git'
 
@@ -67,8 +67,8 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     if found_version < required_version:
         required_version = '.'.join(map(str, required_version))
         found_version = '.'.join(map(str, found_version))
-        raise SkipTest('Test requires git >= %s, found %s' %
-                            (required_version, found_version))
+        raise TestSkipped('Test requires git >= %s, found %s' %
+                         (required_version, found_version))
 
 
 def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,

BIN
dulwich/tests/data/blobs/11/11111111111111111111111111111111111111


+ 0 - 0
dulwich/tests/data/blobs/6f670c0fb53f9463760b7295fbb814e965fb20c8 → dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8


+ 0 - 0
dulwich/tests/data/blobs/954a536f7819d40e6f637f849ee187dd10066349 → dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349


+ 0 - 0
dulwich/tests/data/blobs/e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 → dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391


+ 0 - 0
dulwich/tests/data/commits/0d89f20333fbb1d2f3a94da77f4981373d8f4310 → dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310


+ 0 - 0
dulwich/tests/data/commits/5dac377bdded4c9aeb8dff595f0faeebcc8498cc → dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc


+ 0 - 0
dulwich/tests/data/commits/60dacdc733de308bb77bb76ce0fb0f9b44c9769e → dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e


+ 0 - 0
dulwich/tests/data/tags/71033db03a03c6a36721efcf1968dd8f8e0cf023 → dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023


+ 0 - 0
dulwich/tests/data/trees/70c190eb48fa8bbb50ddc692a17b44cb781af7f6 → dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6


+ 16 - 7
dulwich/tests/test_object_store.py

@@ -31,6 +31,7 @@ from dulwich.object_store import (
     )
 import os
 import shutil
+import tempfile
 
 
 testobject = Blob()
@@ -39,12 +40,18 @@ testobject.data = "yummy data"
 
 class SpecificDiskObjectStoreTests(TestCase):
 
+    def setUp(self):
+        self.store_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.store_dir)
+
     def test_pack_dir(self):
-        o = DiskObjectStore("foo")
-        self.assertEquals(os.path.join("foo", "pack"), o.pack_dir)
+        o = DiskObjectStore(self.store_dir)
+        self.assertEquals(os.path.join(self.store_dir, "pack"), o.pack_dir)
 
     def test_empty_packs(self):
-        o = DiskObjectStore("foo")
+        o = DiskObjectStore(self.store_dir)
         self.assertEquals([], o.packs)
 
 
@@ -95,10 +102,12 @@ class DiskObjectStoreTests(ObjectStoreTests,TestCase):
 
     def setUp(self):
         TestCase.setUp(self)
-        if os.path.exists("foo"):
-            shutil.rmtree("foo")
-        os.makedirs(os.path.join("foo", "pack"))
-        self.store = DiskObjectStore("foo")
+        self.store_dir = tempfile.mkdtemp()
+        self.store = DiskObjectStore.init(self.store_dir)
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+        shutil.rmtree(self.store_dir)
 
 
 # TODO: MissingObjectFinderTests

+ 349 - 64
dulwich/tests/test_objects.py

@@ -20,11 +20,18 @@
 
 """Tests for git base objects."""
 
+# TODO: Round-trip parse-serialize-parse and serialize-parse-serialize tests.
 
+
+import datetime
 import os
 import stat
 import unittest
 
+from dulwich.errors import (
+    ChecksumMismatch,
+    ObjectFormatException,
+    )
 from dulwich.objects import (
     Blob,
     Tree,
@@ -32,7 +39,15 @@ from dulwich.objects import (
     Tag,
     format_timezone,
     hex_to_sha,
+    hex_to_filename,
+    check_hexsha,
+    check_identity,
     parse_timezone,
+    parse_tree,
+    _parse_tree_py,
+    )
+from dulwich.tests import (
+    TestSkipped,
     )
 
 a_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
@@ -41,13 +56,46 @@ c_sha = '954a536f7819d40e6f637f849ee187dd10066349'
 tree_sha = '70c190eb48fa8bbb50ddc692a17b44cb781af7f6'
 tag_sha = '71033db03a03c6a36721efcf1968dd8f8e0cf023'
 
+
+try:
+    from itertools import permutations
+except ImportError:
+    # Implementation of permutations from Python 2.6 documentation:
+    # http://docs.python.org/2.6/library/itertools.html#itertools.permutations
+    # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+    def permutations(iterable, r=None):
+        # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
+        # permutations(range(3)) --> 012 021 102 120 201 210
+        pool = tuple(iterable)
+        n = len(pool)
+        r = n if r is None else r
+        if r > n:
+            return
+        indices = range(n)
+        cycles = range(n, n-r, -1)
+        yield tuple(pool[i] for i in indices[:r])
+        while n:
+            for i in reversed(range(r)):
+                cycles[i] -= 1
+                if cycles[i] == 0:
+                    indices[i:] = indices[i+1:] + indices[i:i+1]
+                    cycles[i] = n - i
+                else:
+                    j = cycles[i]
+                    indices[i], indices[-j] = indices[-j], indices[i]
+                    yield tuple(pool[i] for i in indices[:r])
+                    break
+            else:
+                return
+
+
 class BlobReadTests(unittest.TestCase):
     """Test decompression of blobs"""
-  
-    def get_sha_file(self, obj, base, sha):
-        return obj.from_file(os.path.join(os.path.dirname(__file__),
-                                          'data', base, sha))
-  
+
+    def get_sha_file(self, cls, base, sha):
+        dir = os.path.join(os.path.dirname(__file__), 'data', base)
+        return cls.from_file(hex_to_filename(dir, sha))
+
     def get_blob(self, sha):
         """Return the blob named sha from the test data dir"""
         return self.get_sha_file(Blob, 'blobs', sha)
@@ -162,7 +210,28 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(c.commit_timezone, 0)
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
-  
+
+    def test_check_id(self):
+        wrong_sha = '1' * 40
+        b = self.get_blob(wrong_sha)
+        self.assertEqual(wrong_sha, b.id)
+        self.assertRaises(ChecksumMismatch, b.check)
+        self.assertEqual('742b386350576589175e374a5706505cbd17680c', b.id)
+
+
+class ShaFileCheckTests(unittest.TestCase):
+
+    def assertCheckFails(self, cls, data):
+        obj = cls()
+        def do_check():
+            obj.set_raw_string(data)
+            obj.check()
+        self.assertRaises(ObjectFormatException, do_check)
+
+    def assertCheckSucceeds(self, cls, data):
+        obj = cls()
+        obj.set_raw_string(data)
+        self.assertEqual(None, obj.check())
 
 
 class CommitSerializationTests(unittest.TestCase):
@@ -219,42 +288,115 @@ class CommitSerializationTests(unittest.TestCase):
         self.assertTrue(" -0100\n" in c.as_raw_string())
 
 
-class CommitDeserializationTests(unittest.TestCase):
+default_committer = 'James Westby <jw+debian@jameswestby.net> 1174773719 +0000'
+
+class CommitParseTests(ShaFileCheckTests):
+
+    def make_commit_lines(self,
+                          tree='d80c186a03f423a81b39df39dc87fd269736ca86',
+                          parents=['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                                   '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          author=default_committer,
+                          committer=default_committer,
+                          encoding=None,
+                          message='Merge ../b\n',
+                          extra=None):
+        lines = []
+        if tree is not None:
+            lines.append('tree %s' % tree)
+        if parents is not None:
+            lines.extend('parent %s' % p for p in parents)
+        if author is not None:
+            lines.append('author %s' % author)
+        if committer is not None:
+            lines.append('committer %s' % committer)
+        if encoding is not None:
+            lines.append('encoding %s' % encoding)
+        if extra is not None:
+            for name, value in sorted(extra.iteritems()):
+                lines.append('%s %s' % (name, value))
+        lines.append('')
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_commit_text(self, **kwargs):
+        return '\n'.join(self.make_commit_lines(**kwargs))
 
     def test_simple(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text())
         self.assertEquals('Merge ../b\n', c.message)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>', c.author)
         self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.author)
-        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.committer)
-        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86',
-            c.tree)
+                          c.committer)
+        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86', c.tree)
         self.assertEquals(['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
-                          '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
-            c.parents)
+                           '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          c.parents)
+        expected_time = datetime.datetime(2007, 3, 24, 22, 1, 59)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.commit_time))
+        self.assertEquals(0, c.commit_timezone)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.author_time))
+        self.assertEquals(0, c.author_timezone)
+        self.assertEquals(None, c.encoding)
 
     def test_custom(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'extra-field data\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text(
+          extra={'extra-field': 'data'}))
         self.assertEquals([('extra-field', 'data')], c.extra)
 
-
-class TreeSerializationTests(unittest.TestCase):
+    def test_encoding(self):
+        c = Commit.from_string(self.make_commit_text(encoding='UTF-8'))
+        self.assertEquals('UTF-8', c.encoding)
+
+    def test_check(self):
+        self.assertCheckSucceeds(Commit, self.make_commit_text())
+        self.assertCheckSucceeds(Commit, self.make_commit_text(parents=None))
+        self.assertCheckSucceeds(Commit,
+                                 self.make_commit_text(encoding='UTF-8'))
+
+        self.assertCheckFails(Commit, self.make_commit_text(tree='xxx'))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          parents=[a_sha, 'xxx']))
+        bad_committer = "some guy without an email address 1174773719 +0000"
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(committer=bad_committer))
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(author=bad_committer))
+        self.assertCheckFails(Commit, self.make_commit_text(author=None))
+        self.assertCheckFails(Commit, self.make_commit_text(committer=None))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          author=None, committer=None))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(5):
+            lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+            lines.insert(i, lines[i])
+            text = '\n'.join(lines)
+            if lines[i].startswith('parent'):
+                # duplicate parents are ok for now
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+    def test_check_order(self):
+        lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+        headers = lines[:5]
+        rest = lines[5:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+
+class TreeTests(ShaFileCheckTests):
 
     def test_simple(self):
         myhexsha = "d80c186a03f423a81b39df39dc87fd269736ca86"
@@ -270,6 +412,57 @@ class TreeSerializationTests(unittest.TestCase):
         x["a/c"] = (stat.S_IFDIR, "d80c186a03f423a81b39df39dc87fd269736ca86")
         self.assertEquals(["a.c", "a", "a/c"], [p[0] for p in x.iteritems()])
 
+    def _do_test_parse_tree(self, parse_tree):
+        dir = os.path.join(os.path.dirname(__file__), 'data', 'trees')
+        o = Tree.from_file(hex_to_filename(dir, tree_sha))
+        o._parse_file()
+        self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
+                          list(parse_tree(o.as_raw_string())))
+
+    def test_parse_tree(self):
+        self._do_test_parse_tree(_parse_tree_py)
+
+    def test_parse_tree_extension(self):
+        if parse_tree is _parse_tree_py:
+            raise TestSkipped('parse_tree extension not found')
+        self._do_test_parse_tree(parse_tree)
+
+    def test_check(self):
+        t = Tree
+        sha = hex_to_sha(a_sha)
+
+        # filenames
+        self.assertCheckSucceeds(t, '100644 .a\0%s' % sha)
+        self.assertCheckFails(t, '100644 \0%s' % sha)
+        self.assertCheckFails(t, '100644 .\0%s' % sha)
+        self.assertCheckFails(t, '100644 a/a\0%s' % sha)
+        self.assertCheckFails(t, '100644 ..\0%s' % sha)
+
+        # modes
+        self.assertCheckSucceeds(t, '100644 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '100755 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '160000 a\0%s' % sha)
+        # TODO more whitelisted modes
+        self.assertCheckFails(t, '123456 a\0%s' % sha)
+        self.assertCheckFails(t, '123abc a\0%s' % sha)
+
+        # shas
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 5))
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 18 + '\0'))
+        self.assertCheckFails(t, '100644 a\0%s\n100644 b\0%s' % ('x' * 21, sha))
+
+        # ordering
+        sha2 = hex_to_sha(b_sha)
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha))
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 a\0%s\n100755 a\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 b\0%s\n100644 a\0%s' % (sha2, sha))
+
+    def test_iter(self):
+        t = Tree()
+        t["foo"] = (0100644, a_sha)
+        self.assertEquals(set(["foo"]), set(t))
+
 
 class TagSerializeTests(unittest.TestCase):
 
@@ -278,7 +471,7 @@ class TagSerializeTests(unittest.TestCase):
         x.tagger = "Jelmer Vernooij <jelmer@samba.org>"
         x.name = "0.1"
         x.message = "Tag 0.1"
-        x.object = (3, "d80c186a03f423a81b39df39dc87fd269736ca86")
+        x.object = (Blob, "d80c186a03f423a81b39df39dc87fd269736ca86")
         x.tag_time = 423423423
         x.tag_timezone = 0
         self.assertEquals("""object d80c186a03f423a81b39df39dc87fd269736ca86
@@ -289,16 +482,9 @@ tagger Jelmer Vernooij <jelmer@samba.org> 423423423 +0000
 Tag 0.1""", x.as_raw_string())
 
 
-class TagParseTests(unittest.TestCase):
-
-    def test_parse_ctime(self):
-        x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-tagger Linus Torvalds <torvalds@woody.linux-foundation.org> Sun Jul 1 12:54:34 2007 -0700
-
-Linux 2.6.22-rc7
+default_tagger = ('Linus Torvalds <torvalds@woody.linux-foundation.org> '
+                  '1183319674 -0700')
+default_message = """Linux 2.6.22-rc7
 -----BEGIN PGP SIGNATURE-----
 Version: GnuPG v1.4.7 (GNU/Linux)
 
@@ -306,39 +492,136 @@ iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
 OK2XeQOiEeXtT76rV4t2WR4=
 =ivrA
 -----END PGP SIGNATURE-----
-""")
-        self.assertEquals("Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
+"""
+
+
+class TagParseTests(ShaFileCheckTests):
+    def make_tag_lines(self,
+                       object_sha="a38d6181ff27824c79fc7df825164a212eff6a3f",
+                       object_type_name="commit",
+                       name="v2.6.22-rc7",
+                       tagger=default_tagger,
+                       message=default_message):
+        lines = []
+        if object_sha is not None:
+            lines.append("object %s" % object_sha)
+        if object_type_name is not None:
+            lines.append("type %s" % object_type_name)
+        if name is not None:
+            lines.append("tag %s" % name)
+        if tagger is not None:
+            lines.append("tagger %s" % tagger)
+        lines.append("")
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_tag_text(self, **kwargs):
+        return "\n".join(self.make_tag_lines(**kwargs))
+
+    def test_parse(self):
+        x = Tag()
+        x.set_raw_string(self.make_tag_text())
+        self.assertEquals(
+            "Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
+        object_type, object_sha = x.object
+        self.assertEquals("a38d6181ff27824c79fc7df825164a212eff6a3f",
+                          object_sha)
+        self.assertEquals(Commit, object_type)
+        self.assertEquals(datetime.datetime.utcfromtimestamp(x.tag_time),
+                          datetime.datetime(2007, 7, 1, 19, 54, 34))
+        self.assertEquals(-25200, x.tag_timezone)
 
     def test_parse_no_tagger(self):
         x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-
-Linux 2.6.22-rc7
------BEGIN PGP SIGNATURE-----
-Version: GnuPG v1.4.7 (GNU/Linux)
-
-iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
-OK2XeQOiEeXtT76rV4t2WR4=
-=ivrA
------END PGP SIGNATURE-----
-""")
+        x.set_raw_string(self.make_tag_text(tagger=None))
         self.assertEquals(None, x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
 
+    def test_check(self):
+        self.assertCheckSucceeds(Tag, self.make_tag_text())
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha=None))
+        self.assertCheckFails(Tag, self.make_tag_text(object_type_name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=''))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          object_type_name="foobar"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger="some guy without an email address 1183319674 -0700"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger=("Linus Torvalds <torvalds@woody.linux-foundation.org> "
+                  "Sun 7 Jul 2007 12:54:34 +0700")))
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha="xxx"))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(4):
+            lines = self.make_tag_lines()
+            lines.insert(i, lines[i])
+            self.assertCheckFails(Tag, '\n'.join(lines))
+
+    def test_check_order(self):
+        lines = self.make_tag_lines()
+        headers = lines[:4]
+        rest = lines[4:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Tag, text)
+            else:
+                self.assertCheckFails(Tag, text)
+
+
+class CheckTests(unittest.TestCase):
+
+    def test_check_hexsha(self):
+        check_hexsha(a_sha, "failed to check good sha")
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 39,
+                          'sha too short')
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 41,
+                          'sha too long')
+        self.assertRaises(ObjectFormatException, check_hexsha, 'x' * 40,
+                          'invalid characters')
+
+    def test_check_identity(self):
+        check_identity("Dave Borowitz <dborowitz@google.com>",
+                       "failed to check good identity")
+        check_identity("<dborowitz@google.com>",
+                       "failed to check good identity")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz", "no email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "dborowitz@google.com>", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <<dborowitz@google.com>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>xxx",
+                          "trailing characters")
+
 
 class TimezoneTests(unittest.TestCase):
 
     def test_parse_timezone_utc(self):
-        self.assertEquals(0, parse_timezone("+0000"))
+        self.assertEquals((0, False), parse_timezone("+0000"))
+
+    def test_parse_timezone_utc_negative(self):
+        self.assertEquals((0, True), parse_timezone("-0000"))
 
     def test_generate_timezone_utc(self):
         self.assertEquals("+0000", format_timezone(0))
 
+    def test_generate_timezone_utc_negative(self):
+        self.assertEquals("-0000", format_timezone(0, True))
+
     def test_parse_timezone_cet(self):
-        self.assertEquals(60 * 60, parse_timezone("+0100"))
+        self.assertEquals((60 * 60, False), parse_timezone("+0100"))
 
     def test_format_timezone_cet(self):
         self.assertEquals("+0100", format_timezone(60 * 60))
@@ -347,10 +630,12 @@ class TimezoneTests(unittest.TestCase):
         self.assertEquals("-0400", format_timezone(-4 * 60 * 60))
 
     def test_parse_timezone_pdt(self):
-        self.assertEquals(-4 * 60 * 60, parse_timezone("-0400"))
+        self.assertEquals((-4 * 60 * 60, False), parse_timezone("-0400"))
 
     def test_format_timezone_pdt_half(self):
-        self.assertEquals("-0440", format_timezone(int(((-4 * 60) - 40) * 60)))
+        self.assertEquals("-0440",
+            format_timezone(int(((-4 * 60) - 40) * 60)))
 
     def test_parse_timezone_pdt_half(self):
-        self.assertEquals(((-4 * 60) - 40) * 60, parse_timezone("-0440"))
+        self.assertEquals((((-4 * 60) - 40) * 60, False),
+            parse_timezone("-0440"))

+ 3 - 4
dulwich/tests/test_pack.py

@@ -183,13 +183,13 @@ class TestPack(PackTests):
         """Tests random access for non-delta objects"""
         p = self.get_pack(pack1_sha)
         obj = p[a_sha]
-        self.assertEqual(obj._type, 'blob')
+        self.assertEqual(obj.type_name, 'blob')
         self.assertEqual(obj.sha().hexdigest(), a_sha)
         obj = p[tree_sha]
-        self.assertEqual(obj._type, 'tree')
+        self.assertEqual(obj.type_name, 'tree')
         self.assertEqual(obj.sha().hexdigest(), tree_sha)
         obj = p[commit_sha]
-        self.assertEqual(obj._type, 'commit')
+        self.assertEqual(obj.type_name, 'commit')
         self.assertEqual(obj.sha().hexdigest(), commit_sha)
 
     def test_copy(self):
@@ -285,4 +285,3 @@ class ZlibTests(unittest.TestCase):
     def test_simple_decompress(self):
         self.assertEquals((["tree 4ada885c9196b6b6fa08744b5862bf92896fc002\nparent None\nauthor Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\ncommitter Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n\nProvide replacement for mmap()'s offset argument."], 158, 'Z'), 
         read_zlib_chunks(StringIO(TEST_COMP1).read, 229))
-

+ 22 - 14
dulwich/tests/test_repository.py

@@ -92,16 +92,16 @@ class RepositoryTests(unittest.TestCase):
     def test_head(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
-  
+
     def test_get_object(self):
         r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
-        self.assertEqual(obj._type, 'commit')
-  
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_get_object_non_existant(self):
         r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
-  
+
     def test_commit(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -109,8 +109,8 @@ class RepositoryTests(unittest.TestCase):
             obj = r.commit(r.head())
         finally:
             warnings.resetwarnings()
-        self.assertEqual(obj._type, 'commit')
-  
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_commit_not_commit(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -119,7 +119,7 @@ class RepositoryTests(unittest.TestCase):
                 r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
         finally:
             warnings.resetwarnings()
-  
+
     def test_tree(self):
         r = self._repo = open_repo('a.git')
         commit = r[r.head()]
@@ -128,9 +128,9 @@ class RepositoryTests(unittest.TestCase):
             tree = r.tree(commit.tree)
         finally:
             warnings.resetwarnings()
-        self.assertEqual(tree._type, 'tree')
+        self.assertEqual(tree.type_name, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
-  
+
     def test_tree_not_tree(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -147,10 +147,10 @@ class RepositoryTests(unittest.TestCase):
             tag = r.tag(tag_sha)
         finally:
             warnings.resetwarnings()
-        self.assertEqual(tag._type, 'tag')
+        self.assertEqual(tag.type_name, 'tag')
         self.assertEqual(tag.sha().hexdigest(), tag_sha)
-        obj_type, obj_sha = tag.object
-        self.assertEqual(obj_type, objects.Commit)
+        obj_class, obj_sha = tag.object
+        self.assertEqual(obj_class, objects.Commit)
         self.assertEqual(obj_sha, r.head())
 
     def test_tag_not_tag(self):
@@ -190,9 +190,9 @@ class RepositoryTests(unittest.TestCase):
             blob = r.get_blob(blob_sha)
         finally:
             warnings.resetwarnings()
-        self.assertEqual(blob._type, 'blob')
+        self.assertEqual(blob.type_name, 'blob')
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
-  
+
     def test_get_blob_notblob(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -556,3 +556,11 @@ class RefsContainerTests(unittest.TestCase):
             self._refs.remove_if_equals('refs/tags/refs-0.1',
             'df6800012397fb85c56e7418dd4eb9405dee075c'))
         self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])
+
+    def test_read_ref(self):
+        self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec', 
+            self._refs.read_ref("refs/heads/packed"))
+        self.assertEqual(None,
+            self._refs.read_ref("nonexistant"))
+

+ 16 - 12
dulwich/tests/test_server.py

@@ -26,12 +26,15 @@ from dulwich.errors import (
     GitProtocolError,
     )
 from dulwich.server import (
-    UploadPackHandler,
+    Backend,
+    DictBackend,
+    BackendRepo,
     Handler,
-    ProtocolGraphWalker,
-    SingleAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    ProtocolGraphWalker,
+    SingleAckGraphWalkerImpl,
+    UploadPackHandler,
     )
 
 
@@ -76,7 +79,7 @@ class TestProto(object):
 class HandlerTestCase(TestCase):
 
     def setUp(self):
-        self._handler = Handler(None, None, None)
+        self._handler = Handler(Backend(), None, None)
         self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
         self._handler.required_capabilities = lambda: ('cap2',)
 
@@ -119,7 +122,9 @@ class HandlerTestCase(TestCase):
 class UploadPackHandlerTestCase(TestCase):
 
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._backend = DictBackend({"/": BackendRepo()})
+        self._handler = UploadPackHandler(self._backend,
+                ["/", "host=lolcathost"], None, None)
         self._handler.proto = TestProto()
 
     def test_progress(self):
@@ -170,11 +175,9 @@ class TestCommit(object):
 
     def __init__(self, sha, parents, commit_time):
         self.id = sha
-        self._parents = parents
+        self.parents = parents
         self.commit_time = commit_time
-
-    def get_parents(self):
-        return self._parents
+        self.type_name = "commit"
 
     def __repr__(self):
         return '%s(%s)' % (self.__class__.__name__, self._sha)
@@ -223,7 +226,8 @@ class ProtocolGraphWalkerTestCase(TestCase):
             }
 
         self._walker = ProtocolGraphWalker(
-            TestUploadPackHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()),
+            self._objects, None)
 
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
@@ -275,7 +279,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             'want %s' % TWO,
             ])
         heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
-        self._walker.handler.backend.repo.peeled = heads
+        self._walker.get_peeled = heads.get
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
 
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
@@ -295,7 +299,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
         # advertise branch tips plus tag
         heads = {'ref4': FOUR, 'ref5': FIVE, 'tag6': SIX}
         peeled = {'ref4': FOUR, 'ref5': FIVE, 'tag6': FIVE}
-        self._walker.handler.backend.repo.peeled = peeled
+        self._walker.get_peeled = peeled.get
         self._walker.determine_wants(heads)
         lines = []
         while True:

+ 21 - 18
dulwich/tests/test_web.py

@@ -23,7 +23,6 @@ import re
 from unittest import TestCase
 
 from dulwich.objects import (
-    Tag,
     Blob,
     )
 from dulwich.web import (
@@ -96,15 +95,11 @@ class DumbHandlersTestCase(WebTestCase):
         self._environ['QUERY_STRING'] = ''
 
         class TestTag(object):
-            type = Tag().type
-
-            def __init__(self, sha, obj_type, obj_sha):
+            def __init__(self, sha, obj_class, obj_sha):
                 self.sha = lambda: sha
-                self.object = (obj_type, obj_sha)
+                self.object = (obj_class, obj_sha)
 
         class TestBlob(object):
-            type = Blob().type
-
             def __init__(self, sha):
                 self.sha = lambda: sha
 
@@ -112,9 +107,10 @@ class DumbHandlersTestCase(WebTestCase):
         blob2 = TestBlob('222')
         blob3 = TestBlob('333')
 
-        tag1 = TestTag('aaa', TestBlob.type, '222')
+        tag1 = TestTag('aaa', Blob, '222')
 
         class TestRepo(object):
+
             def __init__(self, objects, peeled):
                 self._objects = dict((o.sha(), o) for o in objects)
                 self._peeled = peeled
@@ -125,6 +121,14 @@ class DumbHandlersTestCase(WebTestCase):
             def __getitem__(self, sha):
                 return self._objects[sha]
 
+            def get_refs(self):
+                return {
+                    'HEAD': '000',
+                    'refs/heads/master': blob1.sha(),
+                    'refs/tags/tag-tag': tag1.sha(),
+                    'refs/tags/blob-tag': blob3.sha(),
+                    }
+
         class TestBackend(object):
             def __init__(self):
                 objects = [blob1, blob2, blob3, tag1]
@@ -135,19 +139,16 @@ class DumbHandlersTestCase(WebTestCase):
                     'refs/tags/blob-tag': blob3.sha(),
                     })
 
-            def get_refs(self):
-                return {
-                    'HEAD': '000',
-                    'refs/heads/master': blob1.sha(),
-                    'refs/tags/tag-tag': tag1.sha(),
-                    'refs/tags/blob-tag': blob3.sha(),
-                    }
+            def open_repository(self, path):
+                assert path == '/'
+                return self.repo
 
+        mat = re.search('.*', '//info/refs')
         self.assertEquals(['111\trefs/heads/master\n',
                            '333\trefs/tags/blob-tag\n',
                            'aaa\trefs/tags/tag-tag\n',
                            '222\trefs/tags/tag-tag^{}\n'],
-                          list(get_info_refs(self._req, TestBackend(), None)))
+                          list(get_info_refs(self._req, TestBackend(), mat)))
 
 
 class SmartHandlersTestCase(WebTestCase):
@@ -163,8 +164,9 @@ class SmartHandlersTestCase(WebTestCase):
                 self._handler.write('pkt-line: %s' % line)
 
     class _TestUploadPackHandler(object):
-        def __init__(self, backend, read, write, stateless_rpc=False,
+        def __init__(self, backend, args, read, write, stateless_rpc=False,
                      advertise_refs=False):
+            self.args = args
             self.read = read
             self.write = write
             self.proto = SmartHandlersTestCase.TestProtocol(self)
@@ -217,7 +219,8 @@ class SmartHandlersTestCase(WebTestCase):
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'
 
-        output = ''.join(get_info_refs(self._req, 'backend', None,
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(get_info_refs(self._req, 'backend', mat,
                                        services=self.services()))
         self.assertEquals(('pkt-line: # service=git-upload-pack\n'
                            'flush-pkt\n'

+ 35 - 13
dulwich/web.py

@@ -21,8 +21,11 @@
 from cStringIO import StringIO
 import re
 import time
-import urlparse
 
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from dulwich.misc import parse_qs
 from dulwich.server import (
     ReceivePackHandler,
     UploadPackHandler,
@@ -33,7 +36,7 @@ HTTP_NOT_FOUND = '404 Not Found'
 HTTP_FORBIDDEN = '403 Forbidden'
 
 
-def date_time_string(self, timestamp=None):
+def date_time_string(timestamp=None):
     # Based on BaseHTTPServer.py in python2.5
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     months = [None,
@@ -46,6 +49,22 @@ def date_time_string(self, timestamp=None):
             weekdays[wd], day, months[month], year, hh, mm, ss)
 
 
+def url_prefix(mat):
+    """Extract the URL prefix from a regex match.
+
+    :param mat: A regex match object.
+    :returns: The URL prefix, defined as the text before the match in the
+        original string. Normalized to start with exactly one leading slash
+        and to have no trailing slash (an empty prefix yields '/').
+    """
+    return '/' + mat.string[:mat.start()].strip('/')
+
+
+def get_repo(backend, mat):
+    """Get a Repo instance for the given backend and URL regex match."""
+    return backend.open_repository(url_prefix(mat))
+
+
 def send_file(req, f, content_type):
     """Send a file-like object to the request output.
 
@@ -73,13 +92,13 @@ def send_file(req, f, content_type):
 
 def get_text_file(req, backend, mat):
     req.nocache()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                      'text/plain')
 
 
 def get_loose_object(req, backend, mat):
     sha = mat.group(1) + mat.group(2)
-    object_store = backend.object_store
+    object_store = get_repo(backend, mat).object_store
     if not object_store.contains_loose(sha):
         yield req.not_found('Object not found')
         return
@@ -94,13 +113,13 @@ def get_loose_object(req, backend, mat):
 
 def get_pack_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                      'application/x-git-packed-objects')
 
 
 def get_idx_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                      'application/x-git-packed-objects-toc')
 
 
@@ -109,7 +128,7 @@ default_services = {'git-upload-pack': UploadPackHandler,
 def get_info_refs(req, backend, mat, services=None):
     if services is None:
         services = default_services
-    params = urlparse.parse_qs(req.environ['QUERY_STRING'])
+    params = parse_qs(req.environ['QUERY_STRING'])
     service = params.get('service', [None])[0]
     if service and not req.dumb:
         handler_cls = services.get(service, None)
@@ -120,7 +139,8 @@ def get_info_refs(req, backend, mat, services=None):
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         output = StringIO()
         dummy_input = StringIO()  # GET request, handler doesn't need to read
-        handler = handler_cls(backend, dummy_input.read, output.write,
+        handler = handler_cls(backend, [url_prefix(mat)],
+                              dummy_input.read, output.write,
                               stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
@@ -131,18 +151,19 @@ def get_info_refs(req, backend, mat, services=None):
         # TODO: select_getanyfile() (see http-backend.c)
         req.nocache()
         req.respond(HTTP_OK, 'text/plain')
-        refs = backend.get_refs()
+        repo = get_repo(backend, mat)
+        refs = repo.get_refs()
         for name in sorted(refs.iterkeys()):
             # get_refs() includes HEAD as a special case, but we don't want to
             # advertise it
             if name == 'HEAD':
                 continue
             sha = refs[name]
-            o = backend.repo[sha]
+            o = repo[sha]
             if not o:
                 continue
             yield '%s\t%s\n' % (sha, name)
-            peeled_sha = backend.repo.get_peeled(name)
+            peeled_sha = repo.get_peeled(name)
             if peeled_sha != sha:
                 yield '%s\t%s^{}\n' % (peeled_sha, name)
 
@@ -150,7 +171,7 @@ def get_info_refs(req, backend, mat, services=None):
 def get_info_packs(req, backend, mat):
     req.nocache()
     req.respond(HTTP_OK, 'text/plain')
-    for pack in backend.object_store.packs:
+    for pack in get_repo(backend, mat).object_store.packs:
         yield 'P pack-%s.pack\n' % pack.name()
 
 
@@ -195,7 +216,8 @@ def handle_service_request(req, backend, mat, services=None):
     # content-length
     if 'CONTENT_LENGTH' in req.environ:
         input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    handler = handler_cls(backend, input.read, output.write, stateless_rpc=True)
+    handler = handler_cls(backend, [url_prefix(mat)], input.read, output.write,
+                          stateless_rpc=True)
     handler.handle()
     yield output.getvalue()