Jelmer Vernooij, 15 years ago
parent
commit
520c83293a
34 changed files with 1329 additions and 578 deletions
   1. + 4   - 2     Makefile
   2. + 6   - 0     NEWS
   3. + 4   - 0     README
   4. + 4   - 3     bin/dul-daemon
   5. + 2   - 2     bin/dul-web
   6. + 17  - 7     dulwich/_objects.c
   7. + 11  - 4     dulwich/client.py
   8. + 16  - 13    dulwich/errors.py
   9. + 8   - 3     dulwich/misc.py
  10. + 49  - 32    dulwich/object_store.py
  11. + 505 - 195   dulwich/objects.py
  12. + 70  - 55    dulwich/pack.py
  13. + 33  - 20    dulwich/repo.py
  14. + 106 - 91    dulwich/server.py
  15. + 4   - 0     dulwich/tests/__init__.py
  16. + 15  - 9     dulwich/tests/compat/test_server.py
  17. + 8   - 5     dulwich/tests/compat/test_web.py
  18. + 5   - 5     dulwich/tests/compat/utils.py
  19. binary        dulwich/tests/data/blobs/11/11111111111111111111111111111111111111
  20. + 0   - 0     dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8
  21. + 0   - 0     dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349
  22. + 0   - 0     dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391
  23. + 0   - 0     dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310
  24. + 0   - 0     dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc
  25. + 0   - 0     dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e
  26. + 0   - 0     dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023
  27. + 0   - 0     dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6
  28. + 16  - 7     dulwich/tests/test_object_store.py
  29. + 349 - 64    dulwich/tests/test_objects.py
  30. + 3   - 4     dulwich/tests/test_pack.py
  31. + 22  - 14    dulwich/tests/test_repository.py
  32. + 16  - 12    dulwich/tests/test_server.py
  33. + 21  - 18    dulwich/tests/test_web.py
  34. + 35  - 13    dulwich/web.py

+ 4 - 2
Makefile

@@ -2,6 +2,7 @@ PYTHON = python
 SETUP = $(PYTHON) setup.py
 PYDOCTOR ?= pydoctor
 TESTRUNNER = $(shell which nosetests)
+TESTFLAGS =
 
 all: build
 
@@ -19,12 +20,13 @@ install::
 
 check:: build
 	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	which git > /dev/null && PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 check-noextensions:: clean
-	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) dulwich
 
 check-compat:: build
-	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) -i compat
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 clean::
 	$(SETUP) clean --all

+ 6 - 0
NEWS

@@ -13,6 +13,9 @@
 
   * Implement RefsContainer.__contains__. (Jelmer Vernooij)
 
+  * Cope with \r in ref files on Windows. (
+	http://github.com/jelmer/dulwich/issues/#issue/13, Jelmer Vernooij)
+
  FEATURES
 
   * Add include-tag capability to server. (Dave Borowitz)
@@ -29,6 +32,9 @@
   * Repo.get_blob, Repo.commit, Repo.tag and Repo.tree are now deprecated.
     (Jelmer Vernooij)
 
+  * RefsContainer.set_ref() was renamed to RefsContainer.set_symbolic_ref(),
+    for clarity. (Jelmer Vernooij)
+
  API CHANGES
 
   * Blob.chunked was added. (Jelmer Vernooij)
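
Note: with the rename above, code that used RefsContainer.set_ref() now calls
set_symbolic_ref(); a minimal sketch (the repository path is hypothetical):

    from dulwich.repo import Repo

    repo = Repo("/path/to/repo")  # hypothetical path
    # was: repo.refs.set_ref("HEAD", "refs/heads/master")
    repo.refs.set_symbolic_ref("HEAD", "refs/heads/master")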

+ 4 - 0
README

@@ -21,3 +21,7 @@ The project is named after the part of London that Mr. and Mrs. Git live in
 in the particular Monty Python sketch. It is based on the Python-Git module 
 that James Westby <jw+debian@jameswestby.net> released in 2007 and now 
 maintained by Jelmer Vernooij and John Carr.
+
+Please file bugs in the Dulwich project on Launchpad: 
+
+https://bugs.launchpad.net/dulwich/+filebug

+ 4 - 3
bin/dul-daemon

@@ -19,13 +19,14 @@
 
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend, TCPGitServer
+from dulwich.server import DictBackend, TCPGitServer
 
 if __name__ == "__main__":
-    gitdir = None
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
+    else:
+        gitdir = "."
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     server = TCPGitServer(backend, 'localhost')
     server.serve_forever()

+ 2 - 2
bin/dul-web

@@ -20,7 +20,7 @@
 import os
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend
+from dulwich.server import DictBackend
 from dulwich.web import HTTPGitApplication
 from wsgiref.simple_server import make_server
 
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     else:
         gitdir = os.getcwd()
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     app = HTTPGitApplication(backend)
     # TODO: allow serving on other ports via command-line flag
     server = make_server('', 8000, app)
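
Note: DictBackend replaces GitBackend's single implicit repository with an
explicit mapping from request paths to repositories, so one process can serve
several. A sketch, assuming hypothetical repository paths:

    from wsgiref.simple_server import make_server

    from dulwich.repo import Repo
    from dulwich.server import DictBackend
    from dulwich.web import HTTPGitApplication

    # hypothetical repositories; clients would use e.g.
    #   git clone http://localhost:8000/alpha
    backend = DictBackend({
        "/alpha": Repo("/srv/git/alpha"),
        "/beta": Repo("/srv/git/beta"),
    })
    make_server('', 8000, HTTPGitApplication(backend)).serve_forever()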

+ 17 - 7
dulwich/_objects.c

@@ -37,18 +37,22 @@ static PyObject *sha_to_pyhex(const unsigned char *sha)
 
 static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 {
-	char *text, *end;
+	char *text, *start, *end;
 	int len, namelen;
 	PyObject *ret, *item, *name;
 
 	if (!PyArg_ParseTuple(args, "s#", &text, &len))
 		return NULL;
 
-	ret = PyDict_New();
+	/* TODO: currently this returns a list; if memory usage is a concern,
+	* consider rewriting as a custom iterator object */
+	ret = PyList_New(0);
+
 	if (ret == NULL) {
 		return NULL;
 	}
 
+	start = text;
 	end = text + len;
 
 	while (text < end) {
@@ -56,14 +60,14 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 		mode = strtol(text, &text, 8);
 
 		if (*text != ' ') {
-			PyErr_SetString(PyExc_RuntimeError, "Expected space");
+			PyErr_SetString(PyExc_ValueError, "Expected space");
 			Py_DECREF(ret);
 			return NULL;
 		}
 
 		text++;
 
-		namelen = strlen(text);
+		namelen = strnlen(text, len - (text - start));
 
 		name = PyString_FromStringAndSize(text, namelen);
 		if (name == NULL) {
@@ -71,19 +75,25 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 			return NULL;
 		}
 
-		item = Py_BuildValue("(lN)", mode,
+		if (text + namelen + 20 >= end) {
+			PyErr_SetString(PyExc_ValueError, "SHA truncated");
+			Py_DECREF(ret);
+			Py_DECREF(name);
+			return NULL;
+		}
+
+		item = Py_BuildValue("(NlN)", name, mode,
 							 sha_to_pyhex((unsigned char *)text+namelen+1));
 		if (item == NULL) {
 			Py_DECREF(ret);
 			Py_DECREF(name);
 			return NULL;
 		}
-		if (PyDict_SetItem(ret, name, item) == -1) {
+		if (PyList_Append(ret, item) == -1) {
 			Py_DECREF(ret);
 			Py_DECREF(item);
 			return NULL;
 		}
-		Py_DECREF(name);
 		Py_DECREF(item);
 
 		text += namelen+21;
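
Note: the C routine now matches the pure-Python parse_tree changed later in
this commit: callers receive (name, mode, hexsha) tuples in tree order rather
than a name-keyed dict. A sketch of the new calling convention, assuming the
extension is built and exports parse_tree:

    from dulwich._objects import parse_tree  # assumption: compiled extension

    # one tree entry: "<octal mode> <name>\0" followed by a 20-byte binary sha
    raw = "100644 foo\x00" + "\x01" * 20
    for name, mode, hexsha in parse_tree(raw):
        print name, oct(mode), hexsha  # foo 0100644 0101...01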

+ 11 - 4
dulwich/client.py

@@ -28,6 +28,7 @@ import subprocess
 
 from dulwich.errors import (
     ChecksumMismatch,
+    HangupException,
     )
 from dulwich.protocol import (
     Protocol,
@@ -119,10 +120,16 @@ class GitClient(object):
                                          len(objects))
         
         # read the final confirmation sha
-        client_sha = self.proto.read(20)
-        if not client_sha in (None, "", sha):
-            raise ChecksumMismatch(sha, client_sha)
-
+        try:
+            client_sha = self.proto.read_pkt_line()
+        except HangupException:
+            # for git-daemon versions before v1.6.6.1-26-g38a81b4, there is
+            # nothing to read; catch this and hide from the user.
+            pass
+        else:
+            if not client_sha in (None, "", sha):
+                raise ChecksumMismatch(sha, client_sha)
+
         return new_refs
 
     def fetch(self, path, target, determine_wants=None, progress=None):
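
Note: the defensive final read can be sketched on its own; read_confirmation
is a hypothetical helper, while read_pkt_line() and HangupException are the
dulwich names used in the diff:

    from dulwich.errors import HangupException

    def read_confirmation(proto, expected_sha):
        # hypothetical helper mirroring the change above
        try:
            client_sha = proto.read_pkt_line()
        except HangupException:
            # pre-v1.6.6.1-26-g38a81b4 git-daemon sends nothing here
            return None
        if client_sha not in (None, "", expected_sha):
            raise ValueError("unexpected confirmation: %r" % client_sha)
        return client_sha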

+ 16 - 13
dulwich/errors.py

@@ -37,40 +37,39 @@ class ChecksumMismatch(Exception):
 
 class WrongObjectException(Exception):
     """Baseclass for all the _ is not a _ exceptions on objects.
-  
+
     Do not instantiate directly.
-  
+
-    Subclasses should define a _type attribute that indicates what
+    Subclasses should define a type_name attribute that indicates what
     was expected if they were raised.
     """
-  
+
     def __init__(self, sha, *args, **kwargs):
-        string = "%s is not a %s" % (sha, self._type)
-        Exception.__init__(self, string)
+        Exception.__init__(self, "%s is not a %s" % (sha, self.type_name))
 
 
 class NotCommitError(WrongObjectException):
     """Indicates that the sha requested does not point to a commit."""
-  
-    _type = 'commit'
+
+    type_name = 'commit'
 
 
 class NotTreeError(WrongObjectException):
     """Indicates that the sha requested does not point to a tree."""
-  
-    _type = 'tree'
+
+    type_name = 'tree'
 
 
 class NotTagError(WrongObjectException):
     """Indicates that the sha requested does not point to a tag."""
 
-    _type = 'tag'
+    type_name = 'tag'
 
 
 class NotBlobError(WrongObjectException):
     """Indicates that the sha requested does not point to a blob."""
-  
-    _type = 'blob'
+
+    type_name = 'blob'
 
 
 class MissingCommitError(Exception):
@@ -124,5 +123,9 @@ class PackedRefsException(FileFormatException):
     """Indicates an error parsing a packed-refs file."""
 
 
+class ObjectFormatException(FileFormatException):
+    """Indicates an error parsing an object."""
+
+
 class NoIndexPresent(Exception):
     """No index is present."""

+ 8 - 3
dulwich/misc.py

@@ -15,15 +15,21 @@
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
-"""Misc utilities to work with python2.4.
+"""Misc utilities to work with python <2.6.
 
 These utilities can all be deleted when dulwich decides it wants to stop
-support for python 2.4.
+support for python <2.6.
 """
 try:
     import hashlib
 except ImportError:
     import sha
+
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from cgi import parse_qs
+
 import struct
 
 
@@ -87,4 +93,3 @@ def unpack_from(fmt, buf, offset=0):
     except AttributeError:
         b = buf[offset:offset+struct.calcsize(fmt)]
         return struct.unpack(fmt, b)
-
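
Note: whichever module provides it, parse_qs behaves the same, so callers can
import it from dulwich.misc without caring about the Python version:

    from dulwich.misc import parse_qs

    print parse_qs("service=git-upload-pack&x=1")
    # {'x': ['1'], 'service': ['git-upload-pack']}  (key order may vary)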

+ 49 - 32
dulwich/object_store.py

@@ -23,6 +23,7 @@
 import errno
 import itertools
 import os
+import posixpath
 import stat
 import tempfile
 import urllib2
@@ -38,6 +39,7 @@ from dulwich.objects import (
     Tree,
     hex_to_sha,
     sha_to_hex,
+    hex_to_filename,
     S_ISGITLINK,
     )
 from dulwich.pack import (
@@ -91,14 +93,14 @@ class BaseObjectStore(object):
         """Obtain the raw text for an object.
 
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         raise NotImplementedError(self.get_raw)
 
     def __getitem__(self, sha):
         """Obtain an object by SHA1."""
-        type, uncomp = self.get_raw(sha)
-        return ShaFile.from_raw_string(type, uncomp)
+        type_num, uncomp = self.get_raw(sha)
+        return ShaFile.from_raw_string(type_num, uncomp)
 
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
@@ -137,10 +139,7 @@ class BaseObjectStore(object):
             else:
                 ttree = {}
             for name, oldmode, oldhexsha in stree.iteritems():
-                if path == "":
-                    oldchildpath = name
-                else:
-                    oldchildpath = "%s/%s" % (path, name)
+                oldchildpath = posixpath.join(path, name)
                 try:
                     (newmode, newhexsha) = ttree[name]
                     newchildpath = oldchildpath
@@ -166,10 +165,7 @@ class BaseObjectStore(object):
                             yield ((oldchildpath, newchildpath), (oldmode, newmode), (oldhexsha, newhexsha))
 
             for name, newmode, newhexsha in ttree.iteritems():
-                if path == "":
-                    childpath = name
-                else:
-                    childpath = "%s/%s" % (path, name)
+                childpath = posixpath.join(path, name)
                 if not name in stree:
                     if not stat.S_ISDIR(newmode):
                         yield ((None, childpath), (None, newmode), (None, newhexsha))
@@ -186,10 +182,7 @@ class BaseObjectStore(object):
             (tid, tpath) = todo.pop()
             tree = self[tid]
             for name, mode, hexsha in tree.iteritems(): 
-                if tpath == "":
-                    path = name
-                else:
-                    path = "%s/%s" % (tpath, name)
+                path = posixpath.join(tpath, name)
                 if stat.S_ISDIR(mode):
                     todo.add((hexsha, path))
                 else:
@@ -233,13 +226,14 @@
         """
         return ObjectStoreGraphWalker(heads, lambda sha: self[sha].parents)
 
-    def generate_pack_contents(self, have, want):
+    def generate_pack_contents(self, have, want, progress=None):
         """Iterate over the contents of a pack file.
 
         :param have: List of SHA1s of objects that should not be sent
         :param want: List of SHA1s of objects that should be sent
+        :param progress: Optional progress reporting method
         """
-        return self.iter_shas(self.find_missing_objects(have, want))
+        return self.iter_shas(self.find_missing_objects(have, want, progress))
 
 
 class PackBasedObjectStore(BaseObjectStore):
@@ -292,9 +286,9 @@ class PackBasedObjectStore(BaseObjectStore):
 
     def get_raw(self, name):
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         if len(name) == 40:
             sha = hex_to_sha(name)
@@ -313,7 +307,7 @@ class PackBasedObjectStore(BaseObjectStore):
             hexsha = sha_to_hex(name)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
-            return ret.type, ret.as_raw_string()
+            return ret.type_num, ret.as_raw_string()
         raise KeyError(hexsha)
 
     def add_objects(self, objects):
@@ -369,10 +363,8 @@ class DiskObjectStore(PackBasedObjectStore):
             raise
 
     def _get_shafile_path(self, sha):
-        dir = sha[:2]
-        file = sha[2:]
         # Check from object dir
-        return os.path.join(self.path, dir, file)
+        return hex_to_filename(self.path, sha)
 
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
@@ -385,7 +377,7 @@ class DiskObjectStore(PackBasedObjectStore):
         path = self._get_shafile_path(sha)
         try:
             return ShaFile.from_file(path)
-        except OSError, e:
+        except (OSError, IOError), e:
             if e.errno == errno.ENOENT:
                 return None
             raise
@@ -484,6 +476,17 @@ class DiskObjectStore(PackBasedObjectStore):
         finally:
             f.close()
 
+    @classmethod
+    def init(cls, path):
+        try:
+            os.mkdir(path)
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
+        os.mkdir(os.path.join(path, "info"))
+        os.mkdir(os.path.join(path, PACKDIR))
+        return cls(path)
+
 
 class MemoryObjectStore(BaseObjectStore):
     """Object store that keeps all objects in memory."""
@@ -511,9 +514,9 @@ class MemoryObjectStore(BaseObjectStore):
 
     def get_raw(self, name):
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         return self[name].as_raw_string()
 
@@ -629,7 +632,7 @@ def tree_lookup_path(lookup_obj, root_sha, path):
     mode = None
     for p in parts:
         obj = lookup_obj(sha)
-        if type(obj) is not Tree:
+        if not isinstance(obj, Tree):
            raise NotTreeError(sha)
         if p == '':
             continue
@@ -713,11 +716,25 @@ class ObjectStoreGraphWalker(object):
 
     def ack(self, sha):
         """Ack that a revision and its ancestors are present in the source."""
-        if sha in self.heads:
-            self.heads.remove(sha)
-        if sha in self.parents:
-            for p in self.parents[sha]:
-                self.ack(p)
+        ancestors = set([sha])
+
+        # stop if we run out of heads to remove
+        while self.heads:
+            for a in ancestors:
+                if a in self.heads:
+                    self.heads.remove(a)
+
+            # collect all ancestors
+            new_ancestors = set()
+            for a in ancestors:
+                if a in self.parents:
+                    new_ancestors.update(self.parents[a])
+
+            # no more ancestors; stop
+            if not new_ancestors:
+                break
+
+            ancestors = new_ancestors
 
     def next(self):
         """Iterate over ancestors of heads in the target."""

+ 505 - 195
dulwich/objects.py

@@ -28,36 +28,43 @@ from cStringIO import (
 import mmap
 import mmap
 import os
 import os
 import stat
 import stat
-import time
 import zlib
 import zlib
 
 
 from dulwich.errors import (
 from dulwich.errors import (
+    ChecksumMismatch,
     NotBlobError,
     NotBlobError,
     NotCommitError,
     NotCommitError,
+    NotTagError,
     NotTreeError,
     NotTreeError,
+    ObjectFormatException,
     )
     )
 from dulwich.file import GitFile
 from dulwich.file import GitFile
 from dulwich.misc import (
 from dulwich.misc import (
     make_sha,
     make_sha,
     )
     )
 
 
-BLOB_ID = "blob"
-TAG_ID = "tag"
-TREE_ID = "tree"
-COMMIT_ID = "commit"
-PARENT_ID = "parent"
-AUTHOR_ID = "author"
-COMMITTER_ID = "committer"
-OBJECT_ID = "object"
-TYPE_ID = "type"
-TAGGER_ID = "tagger"
-ENCODING_ID = "encoding"
+
+# Header fields for commits
+_TREE_HEADER = "tree"
+_PARENT_HEADER = "parent"
+_AUTHOR_HEADER = "author"
+_COMMITTER_HEADER = "committer"
+_ENCODING_HEADER = "encoding"
+
+
+# Header fields for objects
+_OBJECT_HEADER = "object"
+_TYPE_HEADER = "type"
+_TAG_HEADER = "tag"
+_TAGGER_HEADER = "tagger"
+
 
 
 S_IFGITLINK = 0160000
 S_IFGITLINK = 0160000
 
 
 def S_ISGITLINK(m):
 def S_ISGITLINK(m):
     return (stat.S_IFMT(m) == S_IFGITLINK)
     return (stat.S_IFMT(m) == S_IFGITLINK)
 
 
+
 def _decompress(string):
 def _decompress(string):
     dcomp = zlib.decompressobj()
     dcomp = zlib.decompressobj()
     dcomped = dcomp.decompress(string)
     dcomped = dcomp.decompress(string)
@@ -78,6 +85,27 @@ def hex_to_sha(hex):
     return binascii.unhexlify(hex)
     return binascii.unhexlify(hex)
 
 
 
 
+def hex_to_filename(path, hex):
+    """Takes a hex sha and returns its filename relative to the given path."""
+    dir = hex[:2]
+    file = hex[2:]
+    # Check from object dir
+    return os.path.join(path, dir, file)
+
+
+def filename_to_hex(filename):
+    """Takes an object filename and returns its corresponding hex sha."""
+    # grab the last (up to) two path components
+    names = filename.rsplit(os.path.sep, 2)[-2:]
+    errmsg = "Invalid object filename: %s" % filename
+    assert len(names) == 2, errmsg
+    base, rest = names
+    assert len(base) == 2 and len(rest) == 38, errmsg
+    hex = base + rest
+    hex_to_sha(hex)
+    return hex
+
+
 def serializable_property(name, docstring=None):
 def serializable_property(name, docstring=None):
     def set(obj, value):
     def set(obj, value):
         obj._ensure_parsed()
         obj._ensure_parsed()
@@ -89,44 +117,100 @@ def serializable_property(name, docstring=None):
     return property(get, set, doc=docstring)
     return property(get, set, doc=docstring)
 
 
 
 
+def object_class(type):
+    """Get the object class corresponding to the given type.
+
+    :param type: Either a type name string or a numeric type.
+    :return: The ShaFile subclass corresponding to the given type, or None if
+        type is not a valid type name/number.
+    """
+    return _TYPE_MAP.get(type, None)
+
+
+def check_hexsha(hex, error_msg):
+    try:
+        hex_to_sha(hex)
+    except (TypeError, AssertionError):
+        raise ObjectFormatException("%s %s" % (error_msg, hex))
+
+
+def check_identity(identity, error_msg):
+    email_start = identity.find("<")
+    email_end = identity.find(">")
+    if (email_start < 0 or email_end < 0 or email_end <= email_start
+        or identity.find("<", email_start + 1) >= 0
+        or identity.find(">", email_end + 1) >= 0
+        or not identity.endswith(">")):
+        raise ObjectFormatException(error_msg)
+
+
+class FixedSha(object):
+    """SHA object that behaves like hashlib's but is given a fixed value."""
+
+    def __init__(self, hexsha):
+        self._hexsha = hexsha
+        self._sha = hex_to_sha(hexsha)
+
+    def digest(self):
+        return self._sha
+
+    def hexdigest(self):
+        return self._hexsha
+
+
 class ShaFile(object):
 class ShaFile(object):
     """A git SHA file."""
     """A git SHA file."""
 
 
-    @classmethod
-    def _parse_legacy_object(cls, map):
-        """Parse a legacy object, creating it and setting object._text"""
-        text = _decompress(map)
-        object = None
-        for posstype in type_map.keys():
-            if text.startswith(posstype):
-                object = type_map[posstype]()
-                text = text[len(posstype):]
-                break
-        assert object is not None, "%s is not a known object type" % text[:9]
-        assert text[0] == ' ', "%s is not a space" % text[0]
-        text = text[1:]
-        size = 0
-        i = 0
-        while text[0] >= '0' and text[0] <= '9':
-            if i > 0 and size == 0:
-                raise AssertionError("Size is not in canonical format")
-            size = (size * 10) + int(text[0])
-            text = text[1:]
-            i += 1
-        object._size = size
-        assert text[0] == "\0", "Size not followed by null"
-        text = text[1:]
-        object.set_raw_string(text)
-        return object
+    @staticmethod
+    def _parse_legacy_object_header(magic, f):
+        """Parse a legacy object, creating it but not reading the file."""
+        bufsize = 1024
+        decomp = zlib.decompressobj()
+        header = decomp.decompress(magic)
+        start = 0
+        end = -1
+        while end < 0:
+            header += decomp.decompress(f.read(bufsize))
+            end = header.find("\0", start)
+            start = len(header)
+        header = header[:end]
+        type_name, size = header.split(" ", 1)
+        size = int(size)  # sanity check
+        obj_class = object_class(type_name)
+        if not obj_class:
+            raise ObjectFormatException("Not a known type: %s" % type_name)
+        obj = obj_class()
+        obj._filename = f.name
+        return obj
+
+    def _parse_legacy_object(self, f):
+        """Parse a legacy object, setting the raw string."""
+        size = os.path.getsize(f.name)
+        map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
+        try:
+            text = _decompress(map)
+        finally:
+            map.close()
+        header_end = text.find('\0')
+        if header_end < 0:
+            raise ObjectFormatException("Invalid object header")
+        self.set_raw_string(text[header_end+1:])
+
+    def as_legacy_object_chunks(self):
+        compobj = zlib.compressobj()
+        yield compobj.compress(self._header())
+        for chunk in self.as_raw_chunks():
+            yield compobj.compress(chunk)
+        yield compobj.flush()
 
 
     def as_legacy_object(self):
     def as_legacy_object(self):
-        text = self.as_raw_string()
-        return zlib.compress("%s %d\0%s" % (self._type, len(text), text))
+        return "".join(self.as_legacy_object_chunks())
 
 
     def as_raw_chunks(self):
     def as_raw_chunks(self):
-        if self._needs_serialization:
+        if self._needs_parsing:
+            self._ensure_parsed()
+        elif self._needs_serialization:
             self._chunked_text = self._serialize()
             self._chunked_text = self._serialize()
-            self._needs_serialization = False
         return self._chunked_text
         return self._chunked_text
 
 
     def as_raw_string(self):
     def as_raw_string(self):
@@ -143,6 +227,9 @@ class ShaFile(object):
 
 
     def _ensure_parsed(self):
     def _ensure_parsed(self):
         if self._needs_parsing:
         if self._needs_parsing:
+            if not self._chunked_text:
+                assert self._filename, "ShaFile needs either text or filename"
+                self._parse_file()
             self._deserialize(self._chunked_text)
             self._deserialize(self._chunked_text)
             self._needs_parsing = False
             self._needs_parsing = False
 
 
@@ -153,39 +240,60 @@ class ShaFile(object):
 
 
     def set_raw_chunks(self, chunks):
     def set_raw_chunks(self, chunks):
         self._chunked_text = chunks
         self._chunked_text = chunks
+        self._deserialize(chunks)
         self._sha = None
         self._sha = None
-        self._needs_parsing = True
+        self._needs_parsing = False
         self._needs_serialization = False
         self._needs_serialization = False
 
 
-    @classmethod
-    def _parse_object(cls, map):
-        """Parse a new style object , creating it and setting object._text"""
-        used = 0
-        byte = ord(map[used])
-        used += 1
-        num_type = (byte >> 4) & 7
+    @staticmethod
+    def _parse_object_header(magic, f):
+        """Parse a new style object, creating it but not reading the file."""
+        num_type = (ord(magic[0]) >> 4) & 7
+        obj_class = object_class(num_type)
+        if not obj_class:
+            raise ObjectFormatError("Not a known type: %d" % num_type)
+        obj = obj_class()
+        obj._filename = f.name
+        return obj
+
+    def _parse_object(self, f):
+        """Parse a new style object, setting self._text."""
+        size = os.path.getsize(f.name)
+        map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
         try:
         try:
-            object = num_type_map[num_type]()
-        except KeyError:
-            raise AssertionError("Not a known type: %d" % num_type)
-        while (byte & 0x80) != 0:
-            byte = ord(map[used])
-            used += 1
-        raw = map[used:]
-        object.set_raw_string(_decompress(raw))
-        return object
+            # skip type and size; type must have already been determined, and we
+            # trust zlib to fail if it's otherwise corrupted
+            byte = ord(map[0])
+            used = 1
+            while (byte & 0x80) != 0:
+                byte = ord(map[used])
+                used += 1
+            raw = map[used:]
+            self.set_raw_string(_decompress(raw))
+        finally:
+            map.close()
+
+    @classmethod
+    def _is_legacy_object(cls, magic):
+        b0, b1 = map(ord, magic)
+        word = (b0 << 8) + b1
+        return b0 == 0x78 and (word % 31) == 0
 
 
     @classmethod
     @classmethod
-    def _parse_file(cls, map):
-        word = (ord(map[0]) << 8) + ord(map[1])
-        if ord(map[0]) == 0x78 and (word % 31) == 0:
-            return cls._parse_legacy_object(map)
+    def _parse_file_header(cls, f):
+        magic = f.read(2)
+        if cls._is_legacy_object(magic):
+            return cls._parse_legacy_object_header(magic, f)
         else:
         else:
-            return cls._parse_object(map)
+            return cls._parse_object_header(magic, f)
 
 
     def __init__(self):
     def __init__(self):
         """Don't call this directly"""
         """Don't call this directly"""
         self._sha = None
         self._sha = None
+        self._filename = None
+        self._chunked_text = []
+        self._needs_parsing = False
+        self._needs_serialization = True
 
 
     def _deserialize(self, chunks):
     def _deserialize(self, chunks):
         raise NotImplementedError(self._deserialize)
         raise NotImplementedError(self._deserialize)
@@ -193,53 +301,96 @@ class ShaFile(object):
     def _serialize(self):
     def _serialize(self):
         raise NotImplementedError(self._serialize)
         raise NotImplementedError(self._serialize)
 
 
+    def _parse_file(self):
+        f = GitFile(self._filename, 'rb')
+        try:
+            magic = f.read(2)
+            if self._is_legacy_object(magic):
+                self._parse_legacy_object(f)
+            else:
+                self._parse_object(f)
+        finally:
+            f.close()
+
     @classmethod
     @classmethod
     def from_file(cls, filename):
     def from_file(cls, filename):
-        """Get the contents of a SHA file on disk"""
-        size = os.path.getsize(filename)
+        """Get the contents of a SHA file on disk."""
         f = GitFile(filename, 'rb')
         f = GitFile(filename, 'rb')
         try:
         try:
-            map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
-            shafile = cls._parse_file(map)
-            return shafile
+            try:
+                obj = cls._parse_file_header(f)
+                obj._sha = FixedSha(filename_to_hex(filename))
+                obj._needs_parsing = True
+                obj._needs_serialization = True
+                return obj
+            except (IndexError, ValueError), e:
+                raise ObjectFormatException("invalid object header")
         finally:
         finally:
             f.close()
             f.close()
 
 
-    @classmethod
-    def from_raw_string(cls, type, string):
+    @staticmethod
+    def from_raw_string(type_num, string):
         """Creates an object of the indicated type from the raw string given.
         """Creates an object of the indicated type from the raw string given.
 
 
-        Type is the numeric type of an object. String is the raw uncompressed
-        contents.
+        :param type_num: The numeric type of the object.
+        :param string: The raw uncompressed contents.
         """
         """
-        real_class = num_type_map[type]
-        obj = real_class()
-        obj.type = type
+        obj = object_class(type_num)()
         obj.set_raw_string(string)
         obj.set_raw_string(string)
         return obj
         return obj
 
 
-    @classmethod
-    def from_raw_chunks(cls, type, chunks):
+    @staticmethod
+    def from_raw_chunks(type_num, chunks):
         """Creates an object of the indicated type from the raw chunks given.
         """Creates an object of the indicated type from the raw chunks given.
 
 
-        Type is the numeric type of an object. Chunks is a sequence of the raw 
-        uncompressed contents.
+        :param type_num: The numeric type of the object.
+        :param chunks: An iterable of the raw uncompressed contents.
         """
         """
-        real_class = num_type_map[type]
-        obj = real_class()
-        obj.type = type
+        obj = object_class(type_num)()
         obj.set_raw_chunks(chunks)
         obj.set_raw_chunks(chunks)
         return obj
         return obj
 
 
     @classmethod
     @classmethod
     def from_string(cls, string):
     def from_string(cls, string):
-        """Create a blob from a string."""
-        shafile = cls()
-        shafile.set_raw_string(string)
-        return shafile
+        """Create a ShaFile from a string."""
+        obj = cls()
+        obj.set_raw_string(string)
+        return obj
+
+    def _check_has_member(self, member, error_msg):
+        """Check that the object has a given member variable.
+
+        :param member: the member variable to check for
+        :param error_msg: the message for an error if the member is missing
+        :raise ObjectFormatException: with the given error_msg if member is
+            missing or is None
+        """
+        if getattr(self, member, None) is None:
+            raise ObjectFormatException(error_msg)
+
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        :raise ChecksumMismatch: if the object was created with a SHA that does
+            not match its contents
+        """
+        # TODO: if we find that error-checking during object parsing is a
+        # performance bottleneck, those checks should be moved to the class's
+        # check() method during optimization so we can still check the object
+        # when necessary.
+        old_sha = self.id
+        try:
+            self._deserialize(self.as_raw_chunks())
+            self._sha = None
+            new_sha = self.id
+        except Exception, e:
+            raise ObjectFormatException(e)
+        if old_sha != new_sha:
+            raise ChecksumMismatch(new_sha, old_sha)
 
 
     def _header(self):
     def _header(self):
-        return "%s %lu\0" % (self._type, self.raw_length())
+        return "%s %lu\0" % (self.type_name, self.raw_length())
 
 
     def raw_length(self):
     def raw_length(self):
         """Returns the length of the raw string of this object."""
         """Returns the length of the raw string of this object."""
@@ -257,8 +408,13 @@ class ShaFile(object):
 
 
     def sha(self):
     def sha(self):
         """The SHA1 object that is the name of this object."""
         """The SHA1 object that is the name of this object."""
-        if self._needs_serialization or self._sha is None:
-            self._sha = self._make_sha()
+        if self._sha is None:
+            # this is a local because as_raw_chunks() overwrites self._sha
+            new_sha = make_sha()
+            new_sha.update(self._header())
+            for chunk in self.as_raw_chunks():
+                new_sha.update(chunk)
+            self._sha = new_sha
         return self._sha
         return self._sha
 
 
     @property
     @property
@@ -266,11 +422,12 @@ class ShaFile(object):
         return self.sha().hexdigest()
         return self.sha().hexdigest()
 
 
     def get_type(self):
     def get_type(self):
-        return self._num_type
+        return self.type_num
 
 
     def set_type(self, type):
     def set_type(self, type):
-        self._num_type = type
+        self.type_num = type
 
 
+    # DEPRECATED: use type_num or type_name as needed.
     type = property(get_type, set_type)
     type = property(get_type, set_type)
 
 
     def __repr__(self):
     def __repr__(self):
@@ -291,8 +448,8 @@ class ShaFile(object):
 class Blob(ShaFile):
 class Blob(ShaFile):
     """A Git Blob object."""
     """A Git Blob object."""
 
 
-    _type = BLOB_ID
-    _num_type = 3
+    type_name = 'blob'
+    type_num = 3
 
 
     def __init__(self):
     def __init__(self):
         super(Blob, self).__init__()
         super(Blob, self).__init__()
@@ -307,60 +464,125 @@ class Blob(ShaFile):
         self.set_raw_string(data)
         self.set_raw_string(data)
 
 
     data = property(_get_data, _set_data,
     data = property(_get_data, _set_data,
-            "The text contained within the blob object.")
+                    "The text contained within the blob object.")
 
 
     def _get_chunked(self):
     def _get_chunked(self):
+        self._ensure_parsed()
         return self._chunked_text
         return self._chunked_text
 
 
     def _set_chunked(self, chunks):
     def _set_chunked(self, chunks):
         self._chunked_text = chunks
         self._chunked_text = chunks
 
 
+    def _serialize(self):
+        if not self._chunked_text:
+            self._ensure_parsed()
+        self._needs_serialization = False
+        return self._chunked_text
+
+    def _deserialize(self, chunks):
+        self._chunked_text = chunks
+
     chunked = property(_get_chunked, _set_chunked,
     chunked = property(_get_chunked, _set_chunked,
         "The text within the blob object, as chunks (not necessarily lines).")
         "The text within the blob object, as chunks (not necessarily lines).")
 
 
     @classmethod
     @classmethod
     def from_file(cls, filename):
     def from_file(cls, filename):
         blob = ShaFile.from_file(filename)
         blob = ShaFile.from_file(filename)
-        if blob._type != cls._type:
+        if not isinstance(blob, cls):
             raise NotBlobError(filename)
             raise NotBlobError(filename)
         return blob
         return blob
 
 
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Blob, self).check()
+
+
+def _parse_tag_or_commit(text):
+    """Parse tag or commit text.
+
+    :param text: the raw text of the tag or commit object.
+    :yield: tuples of (field, value), one per header line, in the order read
+        from the text, possibly including duplicates. Includes a field named
+        None for the freeform tag/commit text.
+    """
+    f = StringIO(text)
+    for l in f:
+        l = l.rstrip("\n")
+        if l == "":
+            # Empty line indicates end of headers
+            break
+        yield l.split(" ", 1)
+    yield (None, f.read())
+    f.close()
+
+
+def parse_tag(text):
+    return _parse_tag_or_commit(text)
+
 
 
 class Tag(ShaFile):
 class Tag(ShaFile):
     """A Git Tag object."""
     """A Git Tag object."""
 
 
-    _type = TAG_ID
-    _num_type = 4
+    type_name = 'tag'
+    type_num = 4
 
 
     def __init__(self):
     def __init__(self):
         super(Tag, self).__init__()
         super(Tag, self).__init__()
-        self._needs_parsing = False
-        self._needs_serialization = True
+        self._tag_timezone_neg_utc = False
 
 
     @classmethod
     @classmethod
     def from_file(cls, filename):
     def from_file(cls, filename):
-        blob = ShaFile.from_file(filename)
-        if blob._type != cls._type:
-            raise NotBlobError(filename)
-        return blob
+        tag = ShaFile.from_file(filename)
+        if not isinstance(tag, cls):
+            raise NotTagError(filename)
+        return tag
 
 
-    @classmethod
-    def from_string(cls, string):
-        """Create a blob from a string."""
-        shafile = cls()
-        shafile.set_raw_string(string)
-        return shafile
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Tag, self).check()
+        self._check_has_member("_object_sha", "missing object sha")
+        self._check_has_member("_object_class", "missing object type")
+        self._check_has_member("_name", "missing tag name")
+
+        if not self._name:
+            raise ObjectFormatException("empty tag name")
+
+        check_hexsha(self._object_sha, "invalid object sha")
+
+        if getattr(self, "_tagger", None):
+            check_identity(self._tagger, "invalid tagger")
+
+        last = None
+        for field, _ in parse_tag("".join(self._chunked_text)):
+            if field == _OBJECT_HEADER and last is not None:
+                raise ObjectFormatException("unexpected object")
+            elif field == _TYPE_HEADER and last != _OBJECT_HEADER:
+                raise ObjectFormatException("unexpected type")
+            elif field == _TAG_HEADER and last != _TYPE_HEADER:
+                raise ObjectFormatException("unexpected tag name")
+            elif field == _TAGGER_HEADER and last != _TAG_HEADER:
+                raise ObjectFormatException("unexpected tagger")
+            last = field
 
 
     def _serialize(self):
     def _serialize(self):
         chunks = []
         chunks = []
-        chunks.append("%s %s\n" % (OBJECT_ID, self._object_sha))
-        chunks.append("%s %s\n" % (TYPE_ID, num_type_map[self._object_type]._type))
-        chunks.append("%s %s\n" % (TAG_ID, self._name))
+        chunks.append("%s %s\n" % (_OBJECT_HEADER, self._object_sha))
+        chunks.append("%s %s\n" % (_TYPE_HEADER, self._object_class.type_name))
+        chunks.append("%s %s\n" % (_TAG_HEADER, self._name))
         if self._tagger:
         if self._tagger:
             if self._tag_time is None:
             if self._tag_time is None:
-                chunks.append("%s %s\n" % (TAGGER_ID, self._tagger))
+                chunks.append("%s %s\n" % (_TAGGER_HEADER, self._tagger))
             else:
             else:
-                chunks.append("%s %s %d %s\n" % (TAGGER_ID, self._tagger, self._tag_time, format_timezone(self._tag_timezone)))
+                chunks.append("%s %s %d %s\n" % (
+                  _TAGGER_HEADER, self._tagger, self._tag_time,
+                  format_timezone(self._tag_timezone,
+                    self._tag_timezone_neg_utc)))
         chunks.append("\n") # To close headers
         chunks.append("\n") # To close headers
         chunks.append(self._message)
         chunks.append(self._message)
         return chunks
         return chunks
@@ -368,45 +590,49 @@ class Tag(ShaFile):
     def _deserialize(self, chunks):
     def _deserialize(self, chunks):
         """Grab the metadata attached to the tag"""
         """Grab the metadata attached to the tag"""
         self._tagger = None
         self._tagger = None
-        f = StringIO("".join(chunks))
-        for l in f:
-            l = l.rstrip("\n")
-            if l == "":
-                break # empty line indicates end of headers
-            (field, value) = l.split(" ", 1)
-            if field == OBJECT_ID:
+        for field, value in parse_tag("".join(chunks)):
+            if field == _OBJECT_HEADER:
                 self._object_sha = value
                 self._object_sha = value
-            elif field == TYPE_ID:
-                self._object_type = type_map[value]
-            elif field == TAG_ID:
+            elif field == _TYPE_HEADER:
+                obj_class = object_class(value)
+                if not obj_class:
+                    raise ObjectFormatException("Not a known type: %s" % value)
+                self._object_class = obj_class
+            elif field == _TAG_HEADER:
                 self._name = value
                 self._name = value
-            elif field == TAGGER_ID:
+            elif field == _TAGGER_HEADER:
                 try:
                 try:
                     sep = value.index("> ")
                     sep = value.index("> ")
                 except ValueError:
                 except ValueError:
                     self._tagger = value
                     self._tagger = value
                     self._tag_time = None
                     self._tag_time = None
                     self._tag_timezone = None
                     self._tag_timezone = None
+                    self._tag_timezone_neg_utc = False
                 else:
                 else:
                     self._tagger = value[0:sep+1]
                     self._tagger = value[0:sep+1]
-                    (timetext, timezonetext) = value[sep+2:].rsplit(" ", 1)
                     try:
                     try:
+                        (timetext, timezonetext) = value[sep+2:].rsplit(" ", 1)
                         self._tag_time = int(timetext)
                         self._tag_time = int(timetext)
-                    except ValueError: #Not a unix timestamp
-                        self._tag_time = time.strptime(timetext)
-                    self._tag_timezone = parse_timezone(timezonetext)
+                        self._tag_timezone, self._tag_timezone_neg_utc = \
+                                parse_timezone(timezonetext)
+                    except ValueError, e:
+                        raise ObjectFormatException(e)
+            elif field is None:
+                self._message = value
             else:
             else:
-                raise AssertionError("Unknown field %s" % field)
-        self._message = f.read()
+                raise ObjectFormatError("Unknown field %s" % field)
 
 
     def _get_object(self):
     def _get_object(self):
-        """Returns the object pointed by this tag, represented as a tuple(type, sha)"""
+        """Get the object pointed to by this tag.
+
+        :return: tuple of (object class, sha).
+        """
         self._ensure_parsed()
         self._ensure_parsed()
-        return (self._object_type, self._object_sha)
+        return (self._object_class, self._object_sha)
 
 
     def _set_object(self, value):
     def _set_object(self, value):
         self._ensure_parsed()
         self._ensure_parsed()
-        (self._object_type, self._object_sha) = value
+        (self._object_class, self._object_sha) = value
         self._needs_serialization = True
         self._needs_serialization = True
 
 
     object = property(_get_object, _set_object)
     object = property(_get_object, _set_object)
@@ -425,9 +651,8 @@ def parse_tree(text):
     """Parse a tree text.
     """Parse a tree text.
 
 
     :param text: Serialized text to parse
     :param text: Serialized text to parse
-    :return: Dictionary with names as keys, (mode, sha) tuples as values
+    :yields: tuples of (name, mode, sha)
     """
     """
-    ret = {}
     count = 0
     count = 0
     l = len(text)
     l = len(text)
     while count < l:
     while count < l:
@@ -437,8 +662,7 @@ def parse_tree(text):
         name = text[mode_end+1:name_end]
         name = text[mode_end+1:name_end]
         count = name_end+21
         count = name_end+21
         sha = text[name_end+1:count]
         sha = text[name_end+1:count]
-        ret[name] = (mode, sha_to_hex(sha))
-    return ret
+        yield (name, mode, sha_to_hex(sha))
 
 
 
 
 def serialize_tree(items):
 def serialize_tree(items):
@@ -458,32 +682,33 @@ def sorted_tree_items(entries):
     :param entries: Dictionary mapping names to (mode, sha) tuples
     :param entries: Dictionary mapping names to (mode, sha) tuples
     :return: Iterator over (name, mode, sha)
     :return: Iterator over (name, mode, sha)
     """
     """
-    def cmp_entry((name1, value1), (name2, value2)):
-        if stat.S_ISDIR(value1[0]):
-            name1 += "/"
-        if stat.S_ISDIR(value2[0]):
-            name2 += "/"
-        return cmp(name1, name2)
     for name, entry in sorted(entries.iteritems(), cmp=cmp_entry):
     for name, entry in sorted(entries.iteritems(), cmp=cmp_entry):
         yield name, entry[0], entry[1]
         yield name, entry[0], entry[1]
 
 
 
 
+def cmp_entry((name1, value1), (name2, value2)):
+    """Compare two tree entries."""
+    if stat.S_ISDIR(value1[0]):
+        name1 += "/"
+    if stat.S_ISDIR(value2[0]):
+        name2 += "/"
+    return cmp(name1, name2)
+
+
 class Tree(ShaFile):
 class Tree(ShaFile):
     """A Git tree object"""
     """A Git tree object"""
 
 
-    _type = TREE_ID
-    _num_type = 2
+    type_name = 'tree'
+    type_num = 2
 
 
     def __init__(self):
     def __init__(self):
         super(Tree, self).__init__()
         super(Tree, self).__init__()
         self._entries = {}
         self._entries = {}
-        self._needs_parsing = False
-        self._needs_serialization = True
 
 
     @classmethod
     @classmethod
     def from_file(cls, filename):
     def from_file(cls, filename):
         tree = ShaFile.from_file(filename)
         tree = ShaFile.from_file(filename)
-        if tree._type != cls._type:
+        if not isinstance(tree, cls):
             raise NotTreeError(filename)
             raise NotTreeError(filename)
         return tree
         return tree
 
 
@@ -511,6 +736,10 @@ class Tree(ShaFile):
         self._ensure_parsed()
         self._ensure_parsed()
         return len(self._entries)
         return len(self._entries)
 
 
+    def __iter__(self):
+        self._ensure_parsed()
+        return iter(self._entries)
+
     def add(self, mode, name, hexsha):
     def add(self, mode, name, hexsha):
         assert type(mode) == int
         assert type(mode) == int
         assert type(name) == str
         assert type(name) == str
@@ -528,8 +757,7 @@ class Tree(ShaFile):
             (mode, name, hexsha) for (name, mode, hexsha) in self.iteritems()]
             (mode, name, hexsha) for (name, mode, hexsha) in self.iteritems()]
 
 
     def iteritems(self):
     def iteritems(self):
-        """Iterate over all entries in the order in which they would be
-        serialized.
+        """Iterate over entries in the order in which they would be serialized.
 
 
         :return: Iterator over (name, mode, sha) tuples
         :return: Iterator over (name, mode, sha) tuples
         """
         """
@@ -538,7 +766,40 @@ class Tree(ShaFile):

    def _deserialize(self, chunks):
        """Grab the entries in the tree"""
-        self._entries = parse_tree("".join(chunks))
+        try:
+            parsed_entries = parse_tree("".join(chunks))
+        except ValueError, e:
+            raise ObjectFormatException(e)
+        # TODO: list comprehension is for efficiency in the common (small) case;
+        # if memory efficiency in the large case is a concern, use a genexp.
+        self._entries = dict([(n, (m, s)) for n, m, s in parsed_entries])
+
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Tree, self).check()
+        last = None
+        allowed_modes = (stat.S_IFREG | 0755, stat.S_IFREG | 0644,
+                         stat.S_IFLNK, stat.S_IFDIR, S_IFGITLINK,
+                         # TODO: optionally exclude as in git fsck --strict
+                         stat.S_IFREG | 0664)
+        for name, mode, sha in parse_tree("".join(self._chunked_text)):
+            check_hexsha(sha, 'invalid sha %s' % sha)
+            if '/' in name or name in ('', '.', '..'):
+                raise ObjectFormatException('invalid name %s' % name)
+
+            if mode not in allowed_modes:
+                raise ObjectFormatException('invalid mode %06o' % mode)
+
+            entry = (name, (mode, sha))
+            if last:
+                if cmp_entry(last, entry) > 0:
+                    raise ObjectFormatException('entries not sorted')
+                if name == last[0]:
+                    raise ObjectFormatException('duplicate entry %s' % name)
+            last = entry

    def _serialize(self):
        return list(serialize_tree(self.iteritems()))
@@ -556,39 +817,47 @@ class Tree(ShaFile):

def parse_timezone(text):
    offset = int(text)
+    negative_utc = (offset == 0 and text[0] == '-')
    signum = (offset < 0) and -1 or 1
    offset = abs(offset)
    hours = int(offset / 100)
    minutes = (offset % 100)
-    return signum * (hours * 3600 + minutes * 60)
+    return signum * (hours * 3600 + minutes * 60), negative_utc


-def format_timezone(offset):
+def format_timezone(offset, negative_utc=False):
    if offset % 60 != 0:
        raise ValueError("Unable to handle non-minute offset.")
-    sign = (offset < 0) and '-' or '+'
+    if offset < 0 or (offset == 0 and negative_utc):
+        sign = '-'
+    else:
+        sign = '+'
    offset = abs(offset)
    return '%c%02d%02d' % (sign, offset / 3600, (offset / 60) % 60)


+def parse_commit(text):
+    return _parse_tag_or_commit(text)
+
+
class Commit(ShaFile):
    """A git commit object"""

-    _type = COMMIT_ID
-    _num_type = 1
+    type_name = 'commit'
+    type_num = 1

    def __init__(self):
        super(Commit, self).__init__()
        self._parents = []
        self._encoding = None
-        self._needs_parsing = False
-        self._needs_serialization = True
        self._extra = {}
+        self._author_timezone_neg_utc = False
+        self._commit_timezone_neg_utc = False

    @classmethod
    def from_file(cls, filename):
        commit = ShaFile.from_file(filename)
-        if commit._type != cls._type:
+        if not isinstance(commit, cls):
            raise NotCommitError(filename)
        return commit

@@ -596,40 +865,79 @@ class Commit(ShaFile):
        self._parents = []
        self._extra = []
        self._author = None
-        f = StringIO("".join(chunks))
-        for l in f:
-            l = l.rstrip("\n")
-            if l == "":
-                # Empty line indicates end of headers
-                break
-            (field, value) = l.split(" ", 1)
-            if field == TREE_ID:
+        for field, value in parse_commit("".join(self._chunked_text)):
+            if field == _TREE_HEADER:
                self._tree = value
-            elif field == PARENT_ID:
+            elif field == _PARENT_HEADER:
                self._parents.append(value)
-            elif field == AUTHOR_ID:
+            elif field == _AUTHOR_HEADER:
                self._author, timetext, timezonetext = value.rsplit(" ", 2)
                self._author_time = int(timetext)
-                self._author_timezone = parse_timezone(timezonetext)
-            elif field == COMMITTER_ID:
+                self._author_timezone, self._author_timezone_neg_utc =\
+                    parse_timezone(timezonetext)
+            elif field == _COMMITTER_HEADER:
                self._committer, timetext, timezonetext = value.rsplit(" ", 2)
                self._commit_time = int(timetext)
-                self._commit_timezone = parse_timezone(timezonetext)
-            elif field == ENCODING_ID:
+                self._commit_timezone, self._commit_timezone_neg_utc =\
+                    parse_timezone(timezonetext)
+            elif field == _ENCODING_HEADER:
                self._encoding = value
+            elif field is None:
+                self._message = value
            else:
                self._extra.append((field, value))
-        self._message = f.read()
+
+    def check(self):
+        """Check this object for internal consistency.
+
+        :raise ObjectFormatException: if the object is malformed in some way
+        """
+        super(Commit, self).check()
+        self._check_has_member("_tree", "missing tree")
+        self._check_has_member("_author", "missing author")
+        self._check_has_member("_committer", "missing committer")
+        # times are currently checked when set
+
+        for parent in self._parents:
+            check_hexsha(parent, "invalid parent sha")
+        check_hexsha(self._tree, "invalid tree sha")
+
+        check_identity(self._author, "invalid author")
+        check_identity(self._committer, "invalid committer")
+
+        last = None
+        for field, _ in parse_commit("".join(self._chunked_text)):
+            if field == _TREE_HEADER and last is not None:
+                raise ObjectFormatException("unexpected tree")
+            elif field == _PARENT_HEADER and last not in (_PARENT_HEADER,
+                                                          _TREE_HEADER):
+                raise ObjectFormatException("unexpected parent")
+            elif field == _AUTHOR_HEADER and last not in (_TREE_HEADER,
+                                                          _PARENT_HEADER):
+                raise ObjectFormatException("unexpected author")
+            elif field == _COMMITTER_HEADER and last != _AUTHOR_HEADER:
+                raise ObjectFormatException("unexpected committer")
+            elif field == _ENCODING_HEADER and last != _COMMITTER_HEADER:
+                raise ObjectFormatException("unexpected encoding")
+            last = field
+
+        # TODO: optionally check for duplicate parents

    def _serialize(self):
        chunks = []
-        chunks.append("%s %s\n" % (TREE_ID, self._tree))
+        chunks.append("%s %s\n" % (_TREE_HEADER, self._tree))
        for p in self._parents:
-            chunks.append("%s %s\n" % (PARENT_ID, p))
-        chunks.append("%s %s %s %s\n" % (AUTHOR_ID, self._author, str(self._author_time), format_timezone(self._author_timezone)))
-        chunks.append("%s %s %s %s\n" % (COMMITTER_ID, self._committer, str(self._commit_time), format_timezone(self._commit_timezone)))
+            chunks.append("%s %s\n" % (_PARENT_HEADER, p))
+        chunks.append("%s %s %s %s\n" % (
+          _AUTHOR_HEADER, self._author, str(self._author_time),
+          format_timezone(self._author_timezone,
+                          self._author_timezone_neg_utc)))
+        chunks.append("%s %s %s %s\n" % (
+          _COMMITTER_HEADER, self._committer, str(self._commit_time),
+          format_timezone(self._commit_timezone,
+                          self._commit_timezone_neg_utc)))
        if self.encoding:
-            chunks.append("%s %s\n" % (ENCODING_ID, self.encoding))
+            chunks.append("%s %s\n" % (_ENCODING_HEADER, self.encoding))
        for k, v in self.extra:
            if "\n" in k or "\n" in v:
                raise AssertionError("newline in extra data: %r -> %r" % (k, v))
@@ -685,22 +993,24 @@ class Commit(ShaFile):
        "Encoding of the commit message.")


-type_map = {
-    BLOB_ID : Blob,
-    TREE_ID : Tree,
-    COMMIT_ID : Commit,
-    TAG_ID: Tag,
-}
+OBJECT_CLASSES = (
+    Commit,
+    Tree,
+    Blob,
+    Tag,
+    )
+
+_TYPE_MAP = {}
+
+for cls in OBJECT_CLASSES:
+    _TYPE_MAP[cls.type_name] = cls
+    _TYPE_MAP[cls.type_num] = cls
+

-num_type_map = {
-    0: None,
-    1: Commit,
-    2: Tree,
-    3: Blob,
-    4: Tag,
-    # 5 Is reserved for further expansion
-}

+# Hold on to the pure-python implementations for testing
+_parse_tree_py = parse_tree
+_sorted_tree_items_py = sorted_tree_items
try:
    # Try to import C versions
    from dulwich._objects import parse_tree, sorted_tree_items

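For reference, the parse_timezone/format_timezone pair above now round-trips
the "-0000" timezone: parse_timezone returns an (offset, negative_utc) tuple
and format_timezone accepts the flag back. A small usage sketch (assuming this
revision of dulwich.objects is importable):

    >>> from dulwich.objects import parse_timezone, format_timezone
    >>> parse_timezone('+0200')   # two hours east of UTC
    (7200, False)
    >>> parse_timezone('-0000')   # offset 0, but explicitly negative
    (0, True)
    >>> offset, neg_utc = parse_timezone('-0000')
    >>> format_timezone(offset, neg_utc)
    '-0000'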
+ 70 - 55
dulwich/pack.py

@@ -36,6 +36,7 @@ except ImportError:
    from misc import defaultdict

import difflib
+import errno
from itertools import (
    chain,
    imap,
@@ -124,22 +125,41 @@ def load_pack_index(path):
    return load_pack_index_file(path, f)


+def _load_file_contents(f, size=None):
+    fileno = getattr(f, 'fileno', None)
+    # Attempt to use mmap if possible
+    if fileno is not None:
+        fd = f.fileno()
+        if size is None:
+            size = os.fstat(fd).st_size
+        try:
+            contents = mmap.mmap(fd, size, access=mmap.ACCESS_READ)
+        except mmap.error:
+            # Perhaps a socket?
+            pass
+        else:
+            return contents, size
+    contents = f.read()
+    size = len(contents)
+    return contents, size
+
+
def load_pack_index_file(path, f):
    """Load an index file from a file-like object.

    :param path: Path for the index file
    :param f: File-like object
    """
-    if f.read(4) == '\377tOc':
-        version = struct.unpack(">L", f.read(4))[0]
+    contents, size = _load_file_contents(f)
+    if contents[:4] == '\377tOc':
+        version = struct.unpack(">L", contents[4:8])[0]
        if version == 2:
-            f.seek(0)
-            return PackIndex2(path, file=f)
+            return PackIndex2(path, file=f, contents=contents,
+                size=size)
        else:
            raise KeyError("Unknown pack index format %d" % version)
    else:
-        f.seek(0)
-        return PackIndex1(path, file=f)
+        return PackIndex1(path, file=f, contents=contents, size=size)


def bisect_find_sha(start, end, sha, unpack_name):
@@ -179,7 +199,7 @@ class PackIndex(object):
    the start and end offset and then bisect in to find if the value is present.
    """

-    def __init__(self, filename, file=None, size=None):
+    def __init__(self, filename, file=None, contents=None, size=None):
        """Create a pack index object.

        Provide it with the name of the index file to consider, and it will map
@@ -192,19 +212,10 @@ class PackIndex(object):
            self._file = GitFile(filename, 'rb')
        else:
            self._file = file
-        fileno = getattr(self._file, 'fileno', None)
-        if fileno is not None:
-            fd = self._file.fileno()
-            if size is None:
-                self._size = os.fstat(fd).st_size
-            else:
-                self._size = size
-            self._contents = mmap.mmap(fd, self._size,
-                access=mmap.ACCESS_READ)
+        if contents is None:
+            self._contents, self._size = _load_file_contents(file, size)
        else:
-            self._file.seek(0)
-            self._contents = self._file.read()
-            self._size = len(self._contents)
+            self._contents, self._size = (contents, size)

    def __eq__(self, other):
        if not isinstance(other, PackIndex):
@@ -213,7 +224,8 @@ class PackIndex(object):
        if self._fan_out_table != other._fan_out_table:
            return False

-        for (name1, _, _), (name2, _, _) in izip(self.iterentries(), other.iterentries()):
+        for (name1, _, _), (name2, _, _) in izip(self.iterentries(),
+                                                 other.iterentries()):
            if name1 != name2:
                return False
        return True
@@ -265,7 +277,8 @@ class PackIndex(object):
    def iterentries(self):
        """Iterate over the entries in this pack index.

-        Will yield tuples with object name, offset in packfile and crc32 checksum.
+        Will yield tuples with object name, offset in packfile and crc32
+        checksum.
        """
        for i in range(len(self)):
            yield self._unpack_entry(i)
@@ -273,7 +286,8 @@ class PackIndex(object):
    def _read_fan_out_table(self, start_offset):
        ret = []
        for i in range(0x100):
-            ret.append(struct.unpack(">L", self._contents[start_offset+i*4:start_offset+(i+1)*4])[0])
+            ret.append(struct.unpack(">L",
+                self._contents[start_offset+i*4:start_offset+(i+1)*4])[0])
        return ret

    def check(self):
@@ -305,9 +319,9 @@ class PackIndex(object):
    def object_index(self, sha):
        """Return the index in to the corresponding packfile for the object.

-        Given the name of an object it will return the offset that object lives
-        at within the corresponding pack file. If the pack file doesn't have the
-        object then None will be returned.
+        Given the name of an object it will return the offset that object
+        lives at within the corresponding pack file. If the pack file doesn't
+        have the object then None will be returned.
        """
        if len(sha) == 40:
            sha = hex_to_sha(sha)
@@ -335,8 +349,8 @@ class PackIndex(object):
class PackIndex1(PackIndex):
    """Version 1 Pack Index."""

-    def __init__(self, filename, file=None, size=None):
-        PackIndex.__init__(self, filename, file, size)
+    def __init__(self, filename, file=None, contents=None, size=None):
+        PackIndex.__init__(self, filename, file, contents, size)
        self.version = 1
        self._fan_out_table = self._read_fan_out_table(0)

@@ -361,8 +375,8 @@ class PackIndex1(PackIndex):
class PackIndex2(PackIndex):
    """Version 2 Pack Index."""

-    def __init__(self, filename, file=None, size=None):
-        PackIndex.__init__(self, filename, file, size)
+    def __init__(self, filename, file=None, contents=None, size=None):
+        PackIndex.__init__(self, filename, file, contents, size)
        assert self._contents[:4] == '\377tOc', "Not a v2 pack index file"
        (self.version, ) = unpack_from(">L", self._contents, 4)
        assert self.version == 2, "Version was %d" % self.version
@@ -427,17 +441,17 @@ def unpack_object(read):
            delta_base_offset += 1
            delta_base_offset <<= 7
            delta_base_offset += (byte & 0x7f)
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read)
        assert size == chunks_length(uncomp)
        return type, (delta_base_offset, uncomp), comp_len+raw_base, unused
    elif type == 7: # ref delta
        basename = read(20)
        raw_base += 20
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read)
        assert size == chunks_length(uncomp)
        return type, (basename, uncomp), comp_len+raw_base, unused
    else:
-        uncomp, comp_len, unused = read_zlib_chunks(read, size)
+        uncomp, comp_len, unused = read_zlib_chunks(read)
        assert chunks_length(uncomp) == size
        return type, uncomp, comp_len+raw_base, unused

@@ -472,16 +486,17 @@ class PackData(object):
    buffer from the start of the deflated object on. This is bad, but until I
    get mmap sorted out it will have to do.

-    Currently there are no integrity checks done. Also no attempt is made to try
-    and detect the delta case, or a request for an object at the wrong position.
-    It will all just throw a zlib or KeyError.
+    Currently there are no integrity checks done. Also no attempt is made to
+    try and detect the delta case, or a request for an object at the wrong
+    position.  It will all just throw a zlib or KeyError.
    """

    def __init__(self, filename, file=None, size=None):
-        """Create a PackData object that represents the pack in the given filename.
+        """Create a PackData object that represents the pack in the given
+        filename.

-        The file must exist and stay readable until the object is disposed of. It
-        must also stay the same size. It will be mapped whenever needed.
+        The file must exist and stay readable until the object is disposed of.
+        It must also stay the same size. It will be mapped whenever needed.

        Currently there is a restriction on the size of the pack as the python
        mmap implementation is flawed.
@@ -625,9 +640,9 @@ class PackData(object):
        for (offset, type, obj, crc32) in todo:
            assert isinstance(offset, int)
            assert isinstance(type, int)
-            assert isinstance(obj, list) or isinstance(obj, str)
            try:
-                type, obj = self.resolve_object(offset, type, obj, get_ref_text)
+                type, obj = self.resolve_object(offset, type, obj,
+                    get_ref_text)
            except Postpone, (sha, ):
                postponed[sha].append((offset, type, obj))
            else:
@@ -656,8 +671,8 @@ class PackData(object):
        """Create a version 1 index file for this data file.

        :param filename: Index filename.
-        :param resolve_ext_ref: Function to use for resolving externally referenced
-            SHA1s (for thin packs)
+        :param resolve_ext_ref: Function to use for resolving externally
+            referenced SHA1s (for thin packs)
        :param progress: Progress report function
        """
        entries = self.sorted_entries(resolve_ext_ref, progress=progress)
@@ -667,8 +682,8 @@ class PackData(object):
        """Create a version 2 index file for this data file.

        :param filename: Index filename.
-        :param resolve_ext_ref: Function to use for resolving externally referenced
-            SHA1s (for thin packs)
+        :param resolve_ext_ref: Function to use for resolving externally
+            referenced SHA1s (for thin packs)
        :param progress: Progress report function
        """
        entries = self.sorted_entries(resolve_ext_ref, progress=progress)
@@ -679,8 +694,8 @@ class PackData(object):
        """Create an index file for this data file.

        :param filename: Index filename.
-        :param resolve_ext_ref: Function to use for resolving externally referenced
-            SHA1s (for thin packs)
+        :param resolve_ext_ref: Function to use for resolving externally
+            referenced SHA1s (for thin packs)
        :param progress: Progress report function
        """
        if version == 1:
@@ -702,8 +717,8 @@ class PackData(object):
    def get_object_at(self, offset):
        """Given an offset in to the packfile return the object that is there.

-        Using the associated index the location of an object can be looked up, and
-        then the packfile can be asked directly for that object using this
+        Using the associated index the location of an object can be looked up,
+        and then the packfile can be asked directly for that object using this
        function.
        """
        if offset in self._offset_cache:
@@ -834,7 +849,7 @@ def write_pack_data(f, objects, num_objects, window=10):
    # This helps us find good objects to diff against us
    magic = []
    for obj, path in recency:
-        magic.append( (obj.type, path, 1, -obj.raw_length(), obj) )
+        magic.append( (obj.type_num, path, 1, -obj.raw_length(), obj) )
    magic.sort()
    # Build a map of objects and their index in magic - so we can find preceding objects
    # to diff against
@@ -849,14 +864,14 @@ def write_pack_data(f, objects, num_objects, window=10):
    f.write(struct.pack(">L", num_objects)) # Number of objects in pack
    for o, path in recency:
        sha1 = o.sha().digest()
-        orig_t = o.type
+        orig_t = o.type_num
        raw = o.as_raw_string()
        winner = raw
        t = orig_t
        #for i in range(offs[o]-window, window):
        #    if i < 0 or i >= len(offs): continue
        #    b = magic[i][4]
-        #    if b.type != orig_t: continue
+        #    if b.type_num != orig_t: continue
        #    base = b.as_raw_string()
        #    delta = create_delta(base, raw)
        #    if len(delta) < len(winner):
@@ -871,8 +886,8 @@ def write_pack_index_v1(filename, entries, pack_checksum):
    """Write a new pack index file.

    :param filename: The filename of the new pack index file.
-    :param entries: List of tuples with object name (sha), offset_in_pack,  and
-            crc32_checksum.
+    :param entries: List of tuples with object name (sha), offset_in_pack,
+        and crc32_checksum.
    :param pack_checksum: Checksum of the pack file.
    """
    f = GitFile(filename, 'wb')
@@ -1020,8 +1035,8 @@ def write_pack_index_v2(filename, entries, pack_checksum):
    """Write a new pack index file.

    :param filename: The filename of the new pack index file.
-    :param entries: List of tuples with object name (sha), offset_in_pack,  and
-            crc32_checksum.
+    :param entries: List of tuples with object name (sha), offset_in_pack, and
+        crc32_checksum.
    :param pack_checksum: Checksum of the pack file.
    """
    f = GitFile(filename, 'wb')

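The new _load_file_contents helper centralizes the mmap-with-fallback logic
that PackIndex previously carried inline. A rough sketch of the same pattern
(not the committed code itself):

    import mmap
    import os

    def load_contents(f, size=None):
        # Map the file when it exposes a real file descriptor ...
        if getattr(f, 'fileno', None) is not None:
            fd = f.fileno()
            if size is None:
                size = os.fstat(fd).st_size
            try:
                return mmap.mmap(fd, size, access=mmap.ACCESS_READ), size
            except mmap.error:
                pass  # perhaps a socket; fall through to a plain read
        # ... otherwise buffer the whole thing in memory.
        data = f.read()
        return data, len(data)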
+ 33 - 20
dulwich/repo.py

@@ -49,7 +49,7 @@ from dulwich.objects import (
    Tag,
    Tree,
    hex_to_sha,
-    num_type_map,
+    object_class,
    )
import warnings

@@ -62,9 +62,6 @@ REFSDIR_HEADS = 'heads'
INDEX_FILENAME = "index"

BASE_DIRECTORIES = [
-    [OBJECTDIR], 
-    [OBJECTDIR, "info"], 
-    [OBJECTDIR, "pack"],
    ["branches"],
    [REFSDIR],
    [REFSDIR, REFSDIR_TAGS],
@@ -77,7 +74,7 @@ BASE_DIRECTORIES = [
def read_info_refs(f):
    ret = {}
    for l in f.readlines():
-        (sha, name) = l.rstrip("\n").split("\t", 1)
+        (sha, name) = l.rstrip("\r\n").split("\t", 1)
        ret[name] = sha
    return ret

@@ -118,6 +115,12 @@ class RefsContainer(object):
    """A container for refs."""

    def set_ref(self, name, other):
+        warnings.warn("RefsContainer.set_ref() is deprecated. "
+            "Use set_symbolic_ref instead.",
+            category=DeprecationWarning, stacklevel=2)
+        return self.set_symbolic_ref(name, other)
+
+    def set_symbolic_ref(self, name, other):
        """Make a ref point at another ref.

        :param name: Name of the ref to set
@@ -200,6 +203,18 @@ class RefsContainer(object):
        if not name.startswith('refs/') or not check_ref_format(name[5:]):
            raise KeyError(name)

+    def read_ref(self, refname):
+        """Read a reference without following any symbolic references.
+
+        :param refname: The name of the reference
+        :return: The contents of the ref file, or None if it does
+            not exist.
+        """
+        contents = self.read_loose_ref(refname)
+        if not contents:
+            contents = self.get_packed_refs().get(refname, None)
+        return contents
+
    def read_loose_ref(self, name):
        """Read a loose reference and return its contents.

@@ -220,20 +235,16 @@
        depth = 0
        while contents.startswith(SYMREF):
            refname = contents[len(SYMREF):]
-            contents = self.read_loose_ref(refname)
+            contents = self.read_ref(refname)
            if not contents:
-                contents = self.get_packed_refs().get(refname, None)
-                if not contents:
-                    break
+                break
            depth += 1
            if depth > 5:
                raise KeyError(name)
        return refname, contents

    def __contains__(self, refname):
-        if self.read_loose_ref(refname):
-            return True
-        if self.get_packed_refs().get(refname, None):
+        if self.read_ref(refname):
            return True
        return False

@@ -380,7 +391,7 @@ class DiskRefsContainer(RefsContainer):
                header = f.read(len(SYMREF))
                if header == SYMREF:
                    # Read only the first line
-                    return header + iter(f).next().rstrip("\n")
+                    return header + iter(f).next().rstrip("\r\n")
                else:
                    # Read only the first 40 bytes
                    return header + f.read(40-len(SYMREF))
@@ -572,7 +583,7 @@ def read_packed_refs_with_peeled(f):
    for l in f:
        if l[0] == "#":
            continue
-        l = l.rstrip("\n")
+        l = l.rstrip("\r\n")
        if l[0] == "^":
            if not last:
                raise PackedRefsException("unexpected peeled ref line")
@@ -698,7 +709,7 @@ class BaseRepo(object):
    def _get_object(self, sha, cls):
        assert len(sha) in (20, 40)
        ret = self.get_object(sha)
-        if ret._type != cls._type:
+        if not isinstance(ret, cls):
            if cls is Commit:
                raise NotCommitError(ret)
            elif cls is Blob:
@@ -708,7 +719,8 @@ class BaseRepo(object):
            elif cls is Tag:
                raise NotTagError(ret)
            else:
-                raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
+                raise Exception("Type invalid: %r != %r" % (
+                  ret.type_name, cls.type_name))
        return ret

    def get_object(self, sha):
@@ -784,9 +796,9 @@ class BaseRepo(object):
        if cached is not None:
            return cached
        obj = self[ref]
-        obj_type = num_type_map[obj.type]
-        while obj_type == Tag:
-            obj_type, sha = obj.object
+        obj_class = object_class(obj.type_name)
+        while obj_class is Tag:
+            obj_class, sha = obj.object
            obj = self.get_object(sha)
        return obj.id

@@ -1001,8 +1013,9 @@ class Repo(BaseRepo):
    def init_bare(cls, path, mkdir=True):
        for d in BASE_DIRECTORIES:
            os.mkdir(os.path.join(path, *d))
+        DiskObjectStore.init(os.path.join(path, OBJECTDIR))
        ret = cls(path)
-        ret.refs.set_ref("HEAD", "refs/heads/master")
+        ret.refs.set_symbolic_ref("HEAD", "refs/heads/master")
        ret._put_named_file('description', "Unnamed repository")
        ret._put_named_file('config', """[core]
    repositoryformatversion = 0

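With set_ref() deprecated in favour of set_symbolic_ref(), creating a bare
repository and pointing HEAD at a branch looks like this (a sketch; the
temporary directory is only for illustration):

    import tempfile

    from dulwich.repo import Repo

    path = tempfile.mkdtemp()   # any empty directory will do
    repo = Repo.init_bare(path)
    repo.refs.set_symbolic_ref('HEAD', 'refs/heads/master')
    # The old spelling still works but now emits a DeprecationWarning:
    repo.refs.set_ref('HEAD', 'refs/heads/master')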
+ 106 - 91
dulwich/server.py

@@ -49,14 +49,27 @@ from dulwich.protocol import (
    MULTI_ACK_DETAILED,
    ack_type,
    )
-from dulwich.repo import (
-    Repo,
-    )
from dulwich.pack import (
    write_pack_data,
    )

class Backend(object):
+    """A backend for the Git smart server implementation."""
+
+    def open_repository(self, path):
+        """Open the repository at a path."""
+        raise NotImplementedError(self.open_repository)
+
+
+class BackendRepo(object):
+    """Repository abstraction used by the Git server.
+
+    Please note that the methods required here are a
+    subset of those provided by dulwich.repo.Repo.
+    """
+
+    object_store = None
+    refs = None

    def get_refs(self):
        """
@@ -66,14 +79,16 @@ class Backend(object):
        """
        raise NotImplementedError

-    def apply_pack(self, refs, read, delete_refs=True):
-        """ Import a set of changes into a repository and update the refs
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.

-        :param refs: list of tuple(name, sha)
-        :param read: callback to read from the incoming pack
-        :param delete_refs: whether to allow deleting refs
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to point
+            to a tag, this will be the SHA the ref refers to. If no cached
+            information about a tag is available, this method may return None,
+            but it should attempt to peel the tag if possible.
        """
-        raise NotImplementedError
+        return None

    def fetch_objects(self, determine_wants, graph_walker, progress,
                      get_tagged=None):
@@ -87,71 +102,15 @@ class Backend(object):
        raise NotImplementedError


-class GitBackend(Backend):
+class DictBackend(Backend):
+    """Trivial backend that looks up Git repositories in a dictionary."""

-    def __init__(self, repo=None):
-        if repo is None:
-            repo = Repo(tmpfile.mkdtemp())
-        self.repo = repo
-        self.refs = self.repo.refs
-        self.object_store = self.repo.object_store
-        self.fetch_objects = self.repo.fetch_objects
-        self.get_refs = self.repo.get_refs
-
-    def apply_pack(self, refs, read, delete_refs=True):
-        f, commit = self.repo.object_store.add_thin_pack()
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
-        status = []
-        unpack_error = None
-        # TODO: more informative error messages than just the exception string
-        try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
-        except all_exceptions, e:
-            unpack_error = str(e).replace('\n', '')
-        try:
-            commit()
-        except all_exceptions, e:
-            if not unpack_error:
-                unpack_error = str(e).replace('\n', '')
+    def __init__(self, repos):
+        self.repos = repos

-        if unpack_error:
-            status.append(('unpack', unpack_error))
-        else:
-            status.append(('unpack', 'ok'))
-
-        for oldsha, sha, ref in refs:
-            ref_error = None
-            try:
-                if sha == ZERO_SHA:
-                    if not delete_refs:
-                        raise GitProtocolError(
-                          'Attempted to delete refs without delete-refs '
-                          'capability.')
-                    try:
-                        del self.repo.refs[ref]
-                    except all_exceptions:
-                        ref_error = 'failed to delete'
-                else:
-                    try:
-                        self.repo.refs[ref] = sha
-                    except all_exceptions:
-                        ref_error = 'failed to write'
-            except KeyError, e:
-                ref_error = 'bad ref'
-            if ref_error:
-                status.append((ref, ref_error))
-            else:
-                status.append((ref, 'ok'))
-
-
-        print "pack applied"
-        return status
+    def open_repository(self, path):
+        # FIXME: What to do in case there is no repo?
+        return self.repos[path]


class Handler(object):
@@ -198,9 +157,10 @@ class Handler(object):
class UploadPackHandler(Handler):
    """Protocol handler for uploading a pack to the server."""

-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, read, write,
                 stateless_rpc=False, advertise_refs=False):
        Handler.__init__(self, backend, read, write)
+        self.repo = backend.open_repository(args[0])
        self._graph_walker = None
        self.stateless_rpc = stateless_rpc
        self.advertise_refs = advertise_refs
@@ -230,14 +190,14 @@ class UploadPackHandler(Handler):
        if not self.has_capability("include-tag"):
            return {}
        if refs is None:
-            refs = self.backend.get_refs()
+            refs = self.repo.get_refs()
        if repo is None:
-            repo = getattr(self.backend, "repo", None)
+            repo = getattr(self.repo, "repo", None)
            if repo is None:
                # Bail if we don't have a Repo available; this is ok since
                # clients must be able to handle if the server doesn't include
                # all relevant tags.
-                # TODO: either guarantee a Repo, or fix behavior when missing
+                # TODO: fix behavior when missing
                return {}
        tagged = {}
        for name, sha in refs.iteritems():
@@ -249,8 +209,9 @@ class UploadPackHandler(Handler):
    def handle(self):
        write = lambda x: self.proto.write_sideband(1, x)

-        graph_walker = ProtocolGraphWalker(self)
-        objects_iter = self.backend.fetch_objects(
+        graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
+            self.repo.get_peeled)
+        objects_iter = self.repo.fetch_objects(
          graph_walker.determine_wants, graph_walker, self.progress,
          get_tagged=self.get_tagged)

@@ -270,9 +231,9 @@ class UploadPackHandler(Handler):
class ProtocolGraphWalker(object):
    """A graph walker that knows the git protocol.

-    As a graph walker, this class implements ack(), next(), and reset(). It also
-    contains some base methods for interacting with the wire and walking the
-    commit tree.
+    As a graph walker, this class implements ack(), next(), and reset(). It
+    also contains some base methods for interacting with the wire and walking
+    the commit tree.

    The work of determining which acks to send is passed on to the
    implementation instance stored in _impl. The reason for this is that we do
@@ -280,9 +241,10 @@ class ProtocolGraphWalker(object):
    call to set_ack_level() is required to set up the implementation, before any
    calls to next() or ack() are made.
    """
-    def __init__(self, handler):
+    def __init__(self, handler, object_store, get_peeled):
        self.handler = handler
-        self.store = handler.backend.object_store
+        self.store = object_store
+        self.get_peeled = get_peeled
        self.proto = handler.proto
        self.stateless_rpc = handler.stateless_rpc
        self.advertise_refs = handler.advertise_refs
@@ -312,7 +274,7 @@ class ProtocolGraphWalker(object):
                if not i:
                    line = "%s\x00%s" % (line, self.handler.capability_line())
                self.proto.write_pkt_line("%s\n" % line)
-                peeled_sha = self.handler.backend.repo.get_peeled(ref)
+                peeled_sha = self.get_peeled(ref)
                if peeled_sha != sha:
                    self.proto.write_pkt_line('%s %s^{}\n' %
                                              (peeled_sha, ref))
@@ -421,10 +383,10 @@ class ProtocolGraphWalker(object):
            commit = pending.popleft()
            if commit.id in haves:
                return True
-            if not getattr(commit, 'get_parents', None):
+            if commit.type_name != "commit":
                # non-commit wants are assumed to be satisfied
                continue
-            for parent in commit.get_parents():
+            for parent in commit.parents:
                parent_obj = self.store[parent]
                # TODO: handle parents with later commit times than children
                if parent_obj.commit_time >= earliest:
@@ -559,17 +521,71 @@ class MultiAckDetailedGraphWalkerImpl(object):
class ReceivePackHandler(Handler):
    """Protocol handler for downloading a pack from the client."""

-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, read, write,
                 stateless_rpc=False, advertise_refs=False):
        Handler.__init__(self, backend, read, write)
+        self.repo = backend.open_repository(args[0])
        self.stateless_rpc = stateless_rpc
        self.advertise_refs = advertise_refs

    def capabilities(self):
        return ("report-status", "delete-refs")

+    def _apply_pack(self, refs, read):
+        f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
+        status = []
+        unpack_error = None
+        # TODO: more informative error messages than just the exception string
+        try:
+            # TODO: decode the pack as we stream to avoid blocking reads beyond
+            # the end of data (when using HTTP/1.1 chunked encoding)
+            while True:
+                data = read(10240)
+                if not data:
+                    break
+                f.write(data)
+        except all_exceptions, e:
+            unpack_error = str(e).replace('\n', '')
+        try:
+            commit()
+        except all_exceptions, e:
+            if not unpack_error:
+                unpack_error = str(e).replace('\n', '')
+
+        if unpack_error:
+            status.append(('unpack', unpack_error))
+        else:
+            status.append(('unpack', 'ok'))
+
+        for oldsha, sha, ref in refs:
+            ref_error = None
+            try:
+                if sha == ZERO_SHA:
+                    if not self.has_capability('delete-refs'):
+                        raise GitProtocolError(
+                          'Attempted to delete refs without delete-refs '
+                          'capability.')
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_error = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_error = 'failed to write'
+            except KeyError, e:
+                ref_error = 'bad ref'
+            if ref_error:
+                status.append((ref, ref_error))
+            else:
+                status.append((ref, 'ok'))
+
+        return status
+
    def handle(self):
-        refs = self.backend.get_refs().items()
+        refs = self.repo.get_refs().items()

        if self.advertise_refs or not self.stateless_rpc:
            if refs:
@@ -603,8 +619,7 @@ class ReceivePackHandler(Handler):
            ref = self.proto.read_pkt_line()

        # backend can now deal with these refs and read a pack using self.read
-        status = self.backend.apply_pack(client_refs, self.proto.read,
-                                         self.has_capability('delete-refs'))
+        status = self.repo._apply_pack(client_refs, self.proto.read)

        # when we have read all the pack from the client, send a status report
        # if the client asked for it
@@ -633,7 +648,7 @@ class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
        else:
            return

-        h = cls(self.server.backend, self.rfile.read, self.wfile.write)
+        h = cls(self.server.backend, args, self.rfile.read, self.wfile.write)
        h.handle()



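With GitBackend replaced by DictBackend, a server no longer wraps a single
repository; it maps request paths to repositories. A minimal sketch of serving
one repository over git:// (the repository path is hypothetical):

    from dulwich.repo import Repo
    from dulwich.server import DictBackend, TCPGitServer

    repo = Repo('/path/to/repo.git')      # assumed to exist already
    backend = DictBackend({'/': repo})    # clients request the path '/'
    server = TCPGitServer(backend, 'localhost', 9418)
    server.serve_forever()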
+ 4 - 0
dulwich/tests/__init__.py

@@ -21,6 +21,10 @@

import unittest

+# XXX: Ideally we should allow other test runners as well, 
+# but unfortunately unittest doesn't have a SkipTest/TestSkipped
+# exception.
+from nose import SkipTest as TestSkipped

def test_suite():
    names = [

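Test modules can now raise the shared alias instead of importing nose
directly; a sketch of how a compat test might skip itself (the test name and
version are illustrative):

    from dulwich.tests import TestSkipped

    def test_requires_recent_git(self):
        raise TestSkipped('Test requires git >= 1.6.6')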
+ 15 - 9
dulwich/tests/compat/test_server.py

@@ -26,27 +26,32 @@ On *nix, you can kill the tests with Ctrl-Z, "kill %".

import threading

-from dulwich import server
+from dulwich.server import (
+    DictBackend,
+    TCPGitServer,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
from server_utils import (
    ServerTests,
    ShutdownServerMixIn,
    )
from utils import (
    CompatTestCase,
-    SkipTest,
    )


-if getattr(server.TCPGitServer, 'shutdown', None):
-    TCPGitServer = server.TCPGitServer
-else:
-    class TCPGitServer(ShutdownServerMixIn, server.TCPGitServer):
+if not getattr(TCPGitServer, 'shutdown', None):
+    _TCPGitServer = TCPGitServer
+
+    class TCPGitServer(ShutdownServerMixIn, TCPGitServer):
        """Subclass of TCPGitServer that can be shut down."""

        def __init__(self, *args, **kwargs):
            # BaseServer is old-style so we have to call both __init__s
            ShutdownServerMixIn.__init__(self)
-            server.TCPGitServer.__init__(self, *args, **kwargs)
+            _TCPGitServer.__init__(self, *args, **kwargs)

        serve = ShutdownServerMixIn.serve_forever

@@ -65,11 +70,12 @@ class GitServerTestCase(ServerTests, CompatTestCase):
        CompatTestCase.tearDown(self)

    def _start_server(self, repo):
-        dul_server = TCPGitServer(server.GitBackend(repo), 'localhost', 0)
+        backend = DictBackend({'/': repo})
+        dul_server = TCPGitServer(backend, 'localhost', 0)
        threading.Thread(target=dul_server.serve).start()
        self._server = dul_server
        _, port = self._server.socket.getsockname()
        return port

    def test_push_to_dulwich(self):
-        raise SkipTest('Skipping push test due to known deadlock bug.')
+        raise TestSkipped('Skipping push test due to known deadlock bug.')

+ 8 - 5
dulwich/tests/compat/test_web.py

@@ -28,7 +28,10 @@ import threading
from wsgiref import simple_server

from dulwich.server import (
-    GitBackend,
+    DictBackend,
+    )
+from dulwich.tests import (
+    TestSkipped,
    )
from dulwich.web import (
    HTTPGitApplication,
@@ -40,7 +43,6 @@ from server_utils import (
    )
from utils import (
    CompatTestCase,
-    SkipTest,
    )


@@ -68,7 +70,8 @@ class WebTests(ServerTests):
    protocol = 'http'

    def _start_server(self, repo):
-        app = self._make_app(GitBackend(repo))
+        backend = DictBackend({'/': repo})
+        app = self._make_app(backend)
        dul_server = simple_server.make_server('localhost', 0, app,
                                               server_class=WSGIServer)
        threading.Thread(target=dul_server.serve_forever).start()
@@ -95,7 +98,7 @@ class SmartWebTestCase(WebTests, CompatTestCase):

    def test_push_to_dulwich(self):
        # TODO(dborowitz): enable after merging thin pack fixes.
-        raise SkipTest('Skipping push test due to known pack bug.')
+        raise TestSkipped('Skipping push test due to known pack bug.')


class DumbWebTestCase(WebTests, CompatTestCase):
@@ -114,4 +117,4 @@ class DumbWebTestCase(WebTests, CompatTestCase):

    def test_push_to_dulwich(self):
        # Note: remove this if dumb pushing is supported
-        raise SkipTest('Dumb web pushing not supported.')
+        raise TestSkipped('Dumb web pushing not supported.')

+ 5 - 5
dulwich/tests/compat/utils.py

@@ -24,11 +24,11 @@ import subprocess
 import tempfile
 import unittest
 
-# XXX: Ideally we shouldn't depend on nose but allow other testrunners as well.
-from nose import SkipTest
-
 from dulwich.repo import Repo
 
+from dulwich.tests import (
+    TestSkipped,
+    )
 
 _DEFAULT_GIT = 'git'
 
@@ -67,8 +67,8 @@ def require_git_version(required_version, git_path=_DEFAULT_GIT):
     if found_version < required_version:
         required_version = '.'.join(map(str, required_version))
         found_version = '.'.join(map(str, found_version))
-        raise SkipTest('Test requires git >= %s, found %s' %
-                            (required_version, found_version))
+        raise TestSkipped('Test requires git >= %s, found %s' %
+                         (required_version, found_version))
 
 
 def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
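
With the nose dependency dropped, require_git_version() now raises dulwich's
own TestSkipped. A quick sketch of the resulting behaviour (the version tuple
is an arbitrary example, not from this commit):

    from dulwich.tests import TestSkipped
    from dulwich.tests.compat.utils import require_git_version

    try:
        require_git_version((1, 6, 6))  # skip unless local git >= 1.6.6
    except TestSkipped, e:
        print 'skipped: %s' % e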

binary
dulwich/tests/data/blobs/11/11111111111111111111111111111111111111


+ 0 - 0
dulwich/tests/data/blobs/6f670c0fb53f9463760b7295fbb814e965fb20c8 → dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8


+ 0 - 0
dulwich/tests/data/blobs/954a536f7819d40e6f637f849ee187dd10066349 → dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349


+ 0 - 0
dulwich/tests/data/blobs/e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 → dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391


+ 0 - 0
dulwich/tests/data/commits/0d89f20333fbb1d2f3a94da77f4981373d8f4310 → dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310


+ 0 - 0
dulwich/tests/data/commits/5dac377bdded4c9aeb8dff595f0faeebcc8498cc → dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc


+ 0 - 0
dulwich/tests/data/commits/60dacdc733de308bb77bb76ce0fb0f9b44c9769e → dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e


+ 0 - 0
dulwich/tests/data/tags/71033db03a03c6a36721efcf1968dd8f8e0cf023 → dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023


+ 0 - 0
dulwich/tests/data/trees/70c190eb48fa8bbb50ddc692a17b44cb781af7f6 → dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6
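
The renames above move the test data into the two-level fan-out layout git
uses for loose objects: the first two hex digits of the SHA become a
subdirectory. This is exactly the mapping computed by the new
hex_to_filename() helper that test_objects.py imports below:

    from dulwich.objects import hex_to_filename

    path = hex_to_filename('dulwich/tests/data/blobs',
                           '6f670c0fb53f9463760b7295fbb814e965fb20c8')
    # 'dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8'
    print path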


+ 16 - 7
dulwich/tests/test_object_store.py

@@ -31,6 +31,7 @@ from dulwich.object_store import (
     )
 import os
 import shutil
+import tempfile
 
 
 testobject = Blob()
@@ -39,12 +40,18 @@ testobject.data = "yummy data"
 
 class SpecificDiskObjectStoreTests(TestCase):
 
+    def setUp(self):
+        self.store_dir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.store_dir)
+
     def test_pack_dir(self):
-        o = DiskObjectStore("foo")
-        self.assertEquals(os.path.join("foo", "pack"), o.pack_dir)
+        o = DiskObjectStore(self.store_dir)
+        self.assertEquals(os.path.join(self.store_dir, "pack"), o.pack_dir)
 
     def test_empty_packs(self):
-        o = DiskObjectStore("foo")
+        o = DiskObjectStore(self.store_dir)
         self.assertEquals([], o.packs)
 
 
@@ -95,10 +102,12 @@ class DiskObjectStoreTests(ObjectStoreTests,TestCase):
 
     def setUp(self):
         TestCase.setUp(self)
-        if os.path.exists("foo"):
-            shutil.rmtree("foo")
-        os.makedirs(os.path.join("foo", "pack"))
-        self.store = DiskObjectStore("foo")
+        self.store_dir = tempfile.mkdtemp()
+        self.store = DiskObjectStore.init(self.store_dir)
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+        shutil.rmtree(self.store_dir)
 
 
 # TODO: MissingObjectFinderTests
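
The tests now build stores with the new DiskObjectStore.init() classmethod in
a throwaway temporary directory instead of a hardcoded "foo" path. A minimal
sketch of the same pattern outside the test suite (assuming, as the removed
os.makedirs() call suggests, that init() lays out the store directories
itself; add_object() is the existing object store API, not part of this diff):

    import shutil
    import tempfile

    from dulwich.object_store import DiskObjectStore
    from dulwich.objects import Blob

    store_dir = tempfile.mkdtemp()
    try:
        store = DiskObjectStore.init(store_dir)  # creates e.g. the pack dir
        blob = Blob()
        blob.data = 'yummy data'
        store.add_object(blob)        # written out as a loose object
        print blob.id in store        # True
    finally:
        shutil.rmtree(store_dir)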

+ 349 - 64
dulwich/tests/test_objects.py

@@ -20,11 +20,18 @@
 
 """Tests for git base objects."""
 
+# TODO: Round-trip parse-serialize-parse and serialize-parse-serialize tests.
 
+
+import datetime
 import os
 import stat
 import unittest
 
+from dulwich.errors import (
+    ChecksumMismatch,
+    ObjectFormatException,
+    )
 from dulwich.objects import (
     Blob,
     Tree,
@@ -32,7 +39,15 @@ from dulwich.objects import (
     Tag,
     format_timezone,
     hex_to_sha,
+    hex_to_filename,
+    check_hexsha,
+    check_identity,
     parse_timezone,
+    parse_tree,
+    _parse_tree_py,
+    )
+from dulwich.tests import (
+    TestSkipped,
     )
 
 a_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
@@ -41,13 +56,46 @@ c_sha = '954a536f7819d40e6f637f849ee187dd10066349'
 tree_sha = '70c190eb48fa8bbb50ddc692a17b44cb781af7f6'
 tag_sha = '71033db03a03c6a36721efcf1968dd8f8e0cf023'
 
+
+try:
+    from itertools import permutations
+except ImportError:
+    # Implementation of permutations from Python 2.6 documentation:
+    # http://docs.python.org/2.6/library/itertools.html#itertools.permutations
+    # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+    def permutations(iterable, r=None):
+        # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
+        # permutations(range(3)) --> 012 021 102 120 201 210
+        pool = tuple(iterable)
+        n = len(pool)
+        r = n if r is None else r
+        if r > n:
+            return
+        indices = range(n)
+        cycles = range(n, n-r, -1)
+        yield tuple(pool[i] for i in indices[:r])
+        while n:
+            for i in reversed(range(r)):
+                cycles[i] -= 1
+                if cycles[i] == 0:
+                    indices[i:] = indices[i+1:] + indices[i:i+1]
+                    cycles[i] = n - i
+                else:
+                    j = cycles[i]
+                    indices[i], indices[-j] = indices[-j], indices[i]
+                    yield tuple(pool[i] for i in indices[:r])
+                    break
+            else:
+                return
+
+
 class BlobReadTests(unittest.TestCase):
     """Test decompression of blobs"""
-  
-    def get_sha_file(self, obj, base, sha):
-        return obj.from_file(os.path.join(os.path.dirname(__file__),
-                                          'data', base, sha))
-  
+
+    def get_sha_file(self, cls, base, sha):
+        dir = os.path.join(os.path.dirname(__file__), 'data', base)
+        return cls.from_file(hex_to_filename(dir, sha))
+
     def get_blob(self, sha):
         """Return the blob named sha from the test data dir"""
         return self.get_sha_file(Blob, 'blobs', sha)
@@ -162,7 +210,28 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(c.commit_timezone, 0)
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
-  
+
+    def test_check_id(self):
+        wrong_sha = '1' * 40
+        b = self.get_blob(wrong_sha)
+        self.assertEqual(wrong_sha, b.id)
+        self.assertRaises(ChecksumMismatch, b.check)
+        self.assertEqual('742b386350576589175e374a5706505cbd17680c', b.id)
+
+
+class ShaFileCheckTests(unittest.TestCase):
+
+    def assertCheckFails(self, cls, data):
+        obj = cls()
+        def do_check():
+            obj.set_raw_string(data)
+            obj.check()
+        self.assertRaises(ObjectFormatException, do_check)
+
+    def assertCheckSucceeds(self, cls, data):
+        obj = cls()
+        obj.set_raw_string(data)
+        self.assertEqual(None, obj.check())
 
 
 class CommitSerializationTests(unittest.TestCase):
@@ -219,42 +288,115 @@ class CommitSerializationTests(unittest.TestCase):
         self.assertTrue(" -0100\n" in c.as_raw_string())
 
 
-class CommitDeserializationTests(unittest.TestCase):
+default_committer = 'James Westby <jw+debian@jameswestby.net> 1174773719 +0000'
+
+class CommitParseTests(ShaFileCheckTests):
+
+    def make_commit_lines(self,
+                          tree='d80c186a03f423a81b39df39dc87fd269736ca86',
+                          parents=['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                                   '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          author=default_committer,
+                          committer=default_committer,
+                          encoding=None,
+                          message='Merge ../b\n',
+                          extra=None):
+        lines = []
+        if tree is not None:
+            lines.append('tree %s' % tree)
+        if parents is not None:
+            lines.extend('parent %s' % p for p in parents)
+        if author is not None:
+            lines.append('author %s' % author)
+        if committer is not None:
+            lines.append('committer %s' % committer)
+        if encoding is not None:
+            lines.append('encoding %s' % encoding)
+        if extra is not None:
+            for name, value in sorted(extra.iteritems()):
+                lines.append('%s %s' % (name, value))
+        lines.append('')
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_commit_text(self, **kwargs):
+        return '\n'.join(self.make_commit_lines(**kwargs))
 
     def test_simple(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text())
         self.assertEquals('Merge ../b\n', c.message)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>', c.author)
         self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.author)
-        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.committer)
-        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86',
-            c.tree)
+                          c.committer)
+        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86', c.tree)
         self.assertEquals(['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
-                          '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
-            c.parents)
+                           '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          c.parents)
+        expected_time = datetime.datetime(2007, 3, 24, 22, 1, 59)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.commit_time))
+        self.assertEquals(0, c.commit_timezone)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.author_time))
+        self.assertEquals(0, c.author_timezone)
+        self.assertEquals(None, c.encoding)
 
     def test_custom(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'extra-field data\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text(
+          extra={'extra-field': 'data'}))
         self.assertEquals([('extra-field', 'data')], c.extra)
 
-
-class TreeSerializationTests(unittest.TestCase):
+    def test_encoding(self):
+        c = Commit.from_string(self.make_commit_text(encoding='UTF-8'))
+        self.assertEquals('UTF-8', c.encoding)
+
+    def test_check(self):
+        self.assertCheckSucceeds(Commit, self.make_commit_text())
+        self.assertCheckSucceeds(Commit, self.make_commit_text(parents=None))
+        self.assertCheckSucceeds(Commit,
+                                 self.make_commit_text(encoding='UTF-8'))
+
+        self.assertCheckFails(Commit, self.make_commit_text(tree='xxx'))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          parents=[a_sha, 'xxx']))
+        bad_committer = "some guy without an email address 1174773719 +0000"
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(committer=bad_committer))
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(author=bad_committer))
+        self.assertCheckFails(Commit, self.make_commit_text(author=None))
+        self.assertCheckFails(Commit, self.make_commit_text(committer=None))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          author=None, committer=None))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(5):
+            lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+            lines.insert(i, lines[i])
+            text = '\n'.join(lines)
+            if lines[i].startswith('parent'):
+                # duplicate parents are ok for now
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+    def test_check_order(self):
+        lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+        headers = lines[:5]
+        rest = lines[5:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+
+class TreeTests(ShaFileCheckTests):
 
     def test_simple(self):
         myhexsha = "d80c186a03f423a81b39df39dc87fd269736ca86"
@@ -270,6 +412,57 @@ class TreeSerializationTests(unittest.TestCase):
         x["a/c"] = (stat.S_IFDIR, "d80c186a03f423a81b39df39dc87fd269736ca86")
         self.assertEquals(["a.c", "a", "a/c"], [p[0] for p in x.iteritems()])
 
+    def _do_test_parse_tree(self, parse_tree):
+        dir = os.path.join(os.path.dirname(__file__), 'data', 'trees')
+        o = Tree.from_file(hex_to_filename(dir, tree_sha))
+        o._parse_file()
+        self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
+                          list(parse_tree(o.as_raw_string())))
+
+    def test_parse_tree(self):
+        self._do_test_parse_tree(_parse_tree_py)
+
+    def test_parse_tree_extension(self):
+        if parse_tree is _parse_tree_py:
+            raise TestSkipped('parse_tree extension not found')
+        self._do_test_parse_tree(parse_tree)
+
+    def test_check(self):
+        t = Tree
+        sha = hex_to_sha(a_sha)
+
+        # filenames
+        self.assertCheckSucceeds(t, '100644 .a\0%s' % sha)
+        self.assertCheckFails(t, '100644 \0%s' % sha)
+        self.assertCheckFails(t, '100644 .\0%s' % sha)
+        self.assertCheckFails(t, '100644 a/a\0%s' % sha)
+        self.assertCheckFails(t, '100644 ..\0%s' % sha)
+
+        # modes
+        self.assertCheckSucceeds(t, '100644 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '100755 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '160000 a\0%s' % sha)
+        # TODO more whitelisted modes
+        self.assertCheckFails(t, '123456 a\0%s' % sha)
+        self.assertCheckFails(t, '123abc a\0%s' % sha)
+
+        # shas
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 5))
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 18 + '\0'))
+        self.assertCheckFails(t, '100644 a\0%s\n100644 b\0%s' % ('x' * 21, sha))
+
+        # ordering
+        sha2 = hex_to_sha(b_sha)
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha))
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 a\0%s\n100755 a\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 b\0%s\n100644 a\0%s' % (sha2, sha))
+
+    def test_iter(self):
+        t = Tree()
+        t["foo"] = (0100644, a_sha)
+        self.assertEquals(set(["foo"]), set(t))
+
 
 class TagSerializeTests(unittest.TestCase):
 
@@ -278,7 +471,7 @@ class TagSerializeTests(unittest.TestCase):
         x.tagger = "Jelmer Vernooij <jelmer@samba.org>"
         x.name = "0.1"
         x.message = "Tag 0.1"
-        x.object = (3, "d80c186a03f423a81b39df39dc87fd269736ca86")
+        x.object = (Blob, "d80c186a03f423a81b39df39dc87fd269736ca86")
         x.tag_time = 423423423
         x.tag_timezone = 0
         self.assertEquals("""object d80c186a03f423a81b39df39dc87fd269736ca86
@@ -289,16 +482,9 @@ tagger Jelmer Vernooij <jelmer@samba.org> 423423423 +0000
Tag 0.1""", x.as_raw_string())
 
 
-class TagParseTests(unittest.TestCase):
-
-    def test_parse_ctime(self):
-        x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-tagger Linus Torvalds <torvalds@woody.linux-foundation.org> Sun Jul 1 12:54:34 2007 -0700
-
-Linux 2.6.22-rc7
+default_tagger = ('Linus Torvalds <torvalds@woody.linux-foundation.org> '
+                  '1183319674 -0700')
+default_message = """Linux 2.6.22-rc7
 -----BEGIN PGP SIGNATURE-----
 Version: GnuPG v1.4.7 (GNU/Linux)
 
@@ -306,39 +492,136 @@ iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
 OK2XeQOiEeXtT76rV4t2WR4=
 =ivrA
 -----END PGP SIGNATURE-----
-""")
-        self.assertEquals("Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
+"""
+
+
+class TagParseTests(ShaFileCheckTests):
+    def make_tag_lines(self,
+                       object_sha="a38d6181ff27824c79fc7df825164a212eff6a3f",
+                       object_type_name="commit",
+                       name="v2.6.22-rc7",
+                       tagger=default_tagger,
+                       message=default_message):
+        lines = []
+        if object_sha is not None:
+            lines.append("object %s" % object_sha)
+        if object_type_name is not None:
+            lines.append("type %s" % object_type_name)
+        if name is not None:
+            lines.append("tag %s" % name)
+        if tagger is not None:
+            lines.append("tagger %s" % tagger)
+        lines.append("")
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_tag_text(self, **kwargs):
+        return "\n".join(self.make_tag_lines(**kwargs))
+
+    def test_parse(self):
+        x = Tag()
+        x.set_raw_string(self.make_tag_text())
+        self.assertEquals(
+            "Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
+        object_type, object_sha = x.object
+        self.assertEquals("a38d6181ff27824c79fc7df825164a212eff6a3f",
+                          object_sha)
+        self.assertEquals(Commit, object_type)
+        self.assertEquals(datetime.datetime.utcfromtimestamp(x.tag_time),
+                          datetime.datetime(2007, 7, 1, 19, 54, 34))
+        self.assertEquals(-25200, x.tag_timezone)
 
     def test_parse_no_tagger(self):
         x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-
-Linux 2.6.22-rc7
------BEGIN PGP SIGNATURE-----
-Version: GnuPG v1.4.7 (GNU/Linux)
-
-iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
-OK2XeQOiEeXtT76rV4t2WR4=
-=ivrA
------END PGP SIGNATURE-----
-""")
+        x.set_raw_string(self.make_tag_text(tagger=None))
         self.assertEquals(None, x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
 
+    def test_check(self):
+        self.assertCheckSucceeds(Tag, self.make_tag_text())
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha=None))
+        self.assertCheckFails(Tag, self.make_tag_text(object_type_name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=''))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          object_type_name="foobar"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger="some guy without an email address 1183319674 -0700"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger=("Linus Torvalds <torvalds@woody.linux-foundation.org> "
+                  "Sun 7 Jul 2007 12:54:34 +0700")))
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha="xxx"))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(4):
+            lines = self.make_tag_lines()
+            lines.insert(i, lines[i])
+            self.assertCheckFails(Tag, '\n'.join(lines))
+
+    def test_check_order(self):
+        lines = self.make_tag_lines()
+        headers = lines[:4]
+        rest = lines[4:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Tag, text)
+            else:
+                self.assertCheckFails(Tag, text)
+
+
+class CheckTests(unittest.TestCase):
+
+    def test_check_hexsha(self):
+        check_hexsha(a_sha, "failed to check good sha")
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 39,
+                          'sha too short')
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 41,
+                          'sha too long')
+        self.assertRaises(ObjectFormatException, check_hexsha, 'x' * 40,
+                          'invalid characters')
+
+    def test_check_identity(self):
+        check_identity("Dave Borowitz <dborowitz@google.com>",
+                       "failed to check good identity")
+        check_identity("<dborowitz@google.com>",
+                       "failed to check good identity")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz", "no email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "dborowitz@google.com>", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <<dborowitz@google.com>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>xxx",
+                          "trailing characters")
+
 
 class TimezoneTests(unittest.TestCase):
 
     def test_parse_timezone_utc(self):
-        self.assertEquals(0, parse_timezone("+0000"))
+        self.assertEquals((0, False), parse_timezone("+0000"))
+
+    def test_parse_timezone_utc_negative(self):
+        self.assertEquals((0, True), parse_timezone("-0000"))
 
     def test_generate_timezone_utc(self):
         self.assertEquals("+0000", format_timezone(0))
 
+    def test_generate_timezone_utc_negative(self):
+        self.assertEquals("-0000", format_timezone(0, True))
+
     def test_parse_timezone_cet(self):
-        self.assertEquals(60 * 60, parse_timezone("+0100"))
+        self.assertEquals((60 * 60, False), parse_timezone("+0100"))
 
     def test_format_timezone_cet(self):
         self.assertEquals("+0100", format_timezone(60 * 60))
@@ -347,10 +630,12 @@ class TimezoneTests(unittest.TestCase):
         self.assertEquals("-0400", format_timezone(-4 * 60 * 60))
 
     def test_parse_timezone_pdt(self):
-        self.assertEquals(-4 * 60 * 60, parse_timezone("-0400"))
+        self.assertEquals((-4 * 60 * 60, False), parse_timezone("-0400"))
 
     def test_format_timezone_pdt_half(self):
-        self.assertEquals("-0440", format_timezone(int(((-4 * 60) - 40) * 60)))
+        self.assertEquals("-0440",
+            format_timezone(int(((-4 * 60) - 40) * 60)))
 
     def test_parse_timezone_pdt_half(self):
-        self.assertEquals(((-4 * 60) - 40) * 60, parse_timezone("-0440"))
+        self.assertEquals((((-4 * 60) - 40) * 60, False),
+            parse_timezone("-0440"))
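
The parse_timezone() change exercised above is an API break worth noting: it
now returns a (seconds-offset, negative-UTC flag) pair instead of a bare
offset, so that a "-0000" timezone can survive a round-trip through
format_timezone(). A sketch of the new contract, using values from the tests:

    from dulwich.objects import format_timezone, parse_timezone

    offset, negative_utc = parse_timezone('-0000')
    assert (offset, negative_utc) == (0, True)
    assert format_timezone(offset, negative_utc) == '-0000'  # not '+0000'
    assert parse_timezone('-0440') == (((-4 * 60) - 40) * 60, False)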

+ 3 - 4
dulwich/tests/test_pack.py

@@ -183,13 +183,13 @@ class TestPack(PackTests):
         """Tests random access for non-delta objects"""
         p = self.get_pack(pack1_sha)
         obj = p[a_sha]
-        self.assertEqual(obj._type, 'blob')
+        self.assertEqual(obj.type_name, 'blob')
         self.assertEqual(obj.sha().hexdigest(), a_sha)
         obj = p[tree_sha]
-        self.assertEqual(obj._type, 'tree')
+        self.assertEqual(obj.type_name, 'tree')
         self.assertEqual(obj.sha().hexdigest(), tree_sha)
         obj = p[commit_sha]
-        self.assertEqual(obj._type, 'commit')
+        self.assertEqual(obj.type_name, 'commit')
         self.assertEqual(obj.sha().hexdigest(), commit_sha)
 
     def test_copy(self):
@@ -285,4 +285,3 @@ class ZlibTests(unittest.TestCase):
     def test_simple_decompress(self):
         self.assertEquals((["tree 4ada885c9196b6b6fa08744b5862bf92896fc002\nparent None\nauthor Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\ncommitter Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n\nProvide replacement for mmap()'s offset argument."], 158, 'Z'),
         read_zlib_chunks(StringIO(TEST_COMP1).read, 229))
-
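
The private _type attribute is renamed to type_name here and in the
repository, server and web tests below; it holds the object type as a string
('blob', 'tree', 'commit' or 'tag'). For example:

    from dulwich.objects import Blob

    blob = Blob.from_string('spam')
    print blob.type_name    # 'blob'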

+ 22 - 14
dulwich/tests/test_repository.py

@@ -92,16 +92,16 @@ class RepositoryTests(unittest.TestCase):
     def test_head(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
-  
+
     def test_get_object(self):
         r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
-        self.assertEqual(obj._type, 'commit')
-  
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_get_object_non_existant(self):
         r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
-  
+
     def test_commit(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -109,8 +109,8 @@ class RepositoryTests(unittest.TestCase):
             obj = r.commit(r.head())
         finally:
             warnings.resetwarnings()
-        self.assertEqual(obj._type, 'commit')
-  
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_commit_not_commit(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -119,7 +119,7 @@ class RepositoryTests(unittest.TestCase):
                 r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
         finally:
             warnings.resetwarnings()
-  
+
     def test_tree(self):
         r = self._repo = open_repo('a.git')
         commit = r[r.head()]
@@ -128,9 +128,9 @@ class RepositoryTests(unittest.TestCase):
             tree = r.tree(commit.tree)
         finally:
             warnings.resetwarnings()
-        self.assertEqual(tree._type, 'tree')
+        self.assertEqual(tree.type_name, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
-  
+
     def test_tree_not_tree(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -147,10 +147,10 @@ class RepositoryTests(unittest.TestCase):
             tag = r.tag(tag_sha)
         finally:
             warnings.resetwarnings()
-        self.assertEqual(tag._type, 'tag')
+        self.assertEqual(tag.type_name, 'tag')
         self.assertEqual(tag.sha().hexdigest(), tag_sha)
-        obj_type, obj_sha = tag.object
-        self.assertEqual(obj_type, objects.Commit)
+        obj_class, obj_sha = tag.object
+        self.assertEqual(obj_class, objects.Commit)
         self.assertEqual(obj_sha, r.head())
 
     def test_tag_not_tag(self):
@@ -190,9 +190,9 @@ class RepositoryTests(unittest.TestCase):
             blob = r.get_blob(blob_sha)
         finally:
             warnings.resetwarnings()
-        self.assertEqual(blob._type, 'blob')
+        self.assertEqual(blob.type_name, 'blob')
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
-  
+
     def test_get_blob_notblob(self):
         r = self._repo = open_repo('a.git')
         warnings.simplefilter("ignore", DeprecationWarning)
@@ -556,3 +556,11 @@ class RefsContainerTests(unittest.TestCase):
             self._refs.remove_if_equals('refs/tags/refs-0.1',
             'df6800012397fb85c56e7418dd4eb9405dee075c'))
         self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])
+
+    def test_read_ref(self):
+        self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+            self._refs.read_ref("refs/heads/packed"))
+        self.assertEqual(None,
+            self._refs.read_ref("nonexistant"))
+
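
The new RefsContainer.read_ref() returns whatever is stored for a ref without
following symbolic refs, and None for a missing ref, complementing the
dictionary-style lookup that raises KeyError. A usage sketch based on the test
above ('myrepo' is a hypothetical on-disk repository):

    from dulwich.repo import Repo

    repo = Repo('myrepo')
    print repo.refs.read_ref('HEAD')          # e.g. 'ref: refs/heads/master'
    print repo.refs.read_ref('nonexistant')   # None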

+ 16 - 12
dulwich/tests/test_server.py

@@ -26,12 +26,15 @@ from dulwich.errors import (
     GitProtocolError,
     )
 from dulwich.server import (
-    UploadPackHandler,
+    Backend,
+    DictBackend,
+    BackendRepo,
     Handler,
-    ProtocolGraphWalker,
-    SingleAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    ProtocolGraphWalker,
+    SingleAckGraphWalkerImpl,
+    UploadPackHandler,
     )
 
 
@@ -76,7 +79,7 @@ class TestProto(object):
 class HandlerTestCase(TestCase):
 
     def setUp(self):
-        self._handler = Handler(None, None, None)
+        self._handler = Handler(Backend(), None, None)
         self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
         self._handler.required_capabilities = lambda: ('cap2',)
 
@@ -119,7 +122,9 @@ class HandlerTestCase(TestCase):
 class UploadPackHandlerTestCase(TestCase):
 
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._backend = DictBackend({"/": BackendRepo()})
+        self._handler = UploadPackHandler(self._backend,
+                ["/", "host=lolcathost"], None, None)
         self._handler.proto = TestProto()
 
     def test_progress(self):
@@ -170,11 +175,9 @@ class TestCommit(object):
 
     def __init__(self, sha, parents, commit_time):
         self.id = sha
-        self._parents = parents
+        self.parents = parents
         self.commit_time = commit_time
-
-    def get_parents(self):
-        return self._parents
+        self.type_name = "commit"
 
     def __repr__(self):
         return '%s(%s)' % (self.__class__.__name__, self._sha)
@@ -223,7 +226,8 @@ class ProtocolGraphWalkerTestCase(TestCase):
             }
 
         self._walker = ProtocolGraphWalker(
-            TestUploadPackHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()),
+            self._objects, None)
 
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
@@ -275,7 +279,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
             'want %s' % TWO,
             ])
         heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
-        self._walker.handler.backend.repo.peeled = heads
+        self._walker.get_peeled = heads.get
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
 
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
@@ -295,7 +299,7 @@ class ProtocolGraphWalkerTestCase(TestCase):
         # advertise branch tips plus tag
         heads = {'ref4': FOUR, 'ref5': FIVE, 'tag6': SIX}
         peeled = {'ref4': FOUR, 'ref5': FIVE, 'tag6': FIVE}
-        self._walker.handler.backend.repo.peeled = peeled
+        self._walker.get_peeled = peeled.get
         self._walker.determine_wants(heads)
         lines = []
         while True:
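
Handlers are now constructed with the backend plus the argument list from the
git command line (the path, then host=... pairs), as the updated setUp()
shows. A condensed sketch of advertising refs through the new constructor, in
the same shape web.py uses below ('myrepo' is a placeholder path):

    from cStringIO import StringIO

    from dulwich.repo import Repo
    from dulwich.server import DictBackend, UploadPackHandler

    backend = DictBackend({'/': Repo('myrepo')})
    inp, out = StringIO(), StringIO()
    handler = UploadPackHandler(backend, ['/', 'host=lolcathost'],
                                inp.read, out.write,
                                stateless_rpc=True, advertise_refs=True)
    handler.handle()   # the ref advertisement ends up in 'out'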

+ 21 - 18
dulwich/tests/test_web.py

@@ -23,7 +23,6 @@ import re
 from unittest import TestCase
 
 from dulwich.objects import (
-    Tag,
     Blob,
     )
 from dulwich.web import (
@@ -96,15 +95,11 @@ class DumbHandlersTestCase(WebTestCase):
         self._environ['QUERY_STRING'] = ''
 
         class TestTag(object):
-            type = Tag().type
-
-            def __init__(self, sha, obj_type, obj_sha):
+            def __init__(self, sha, obj_class, obj_sha):
                 self.sha = lambda: sha
-                self.object = (obj_type, obj_sha)
+                self.object = (obj_class, obj_sha)
 
         class TestBlob(object):
-            type = Blob().type
-
             def __init__(self, sha):
                 self.sha = lambda: sha
 
@@ -112,9 +107,10 @@ class DumbHandlersTestCase(WebTestCase):
         blob2 = TestBlob('222')
         blob3 = TestBlob('333')
 
-        tag1 = TestTag('aaa', TestBlob.type, '222')
+        tag1 = TestTag('aaa', Blob, '222')
 
         class TestRepo(object):
+
             def __init__(self, objects, peeled):
                 self._objects = dict((o.sha(), o) for o in objects)
                 self._peeled = peeled
@@ -125,6 +121,14 @@ class DumbHandlersTestCase(WebTestCase):
             def __getitem__(self, sha):
                 return self._objects[sha]
 
+            def get_refs(self):
+                return {
+                    'HEAD': '000',
+                    'refs/heads/master': blob1.sha(),
+                    'refs/tags/tag-tag': tag1.sha(),
+                    'refs/tags/blob-tag': blob3.sha(),
+                    }
+
         class TestBackend(object):
             def __init__(self):
                 objects = [blob1, blob2, blob3, tag1]
@@ -135,19 +139,16 @@ class DumbHandlersTestCase(WebTestCase):
                     'refs/tags/blob-tag': blob3.sha(),
                     })
 
-            def get_refs(self):
-                return {
-                    'HEAD': '000',
-                    'refs/heads/master': blob1.sha(),
-                    'refs/tags/tag-tag': tag1.sha(),
-                    'refs/tags/blob-tag': blob3.sha(),
-                    }
+            def open_repository(self, path):
+                assert path == '/'
+                return self.repo
 
+        mat = re.search('.*', '//info/refs')
         self.assertEquals(['111\trefs/heads/master\n',
                            '333\trefs/tags/blob-tag\n',
                            'aaa\trefs/tags/tag-tag\n',
                            '222\trefs/tags/tag-tag^{}\n'],
-                          list(get_info_refs(self._req, TestBackend(), None)))
+                          list(get_info_refs(self._req, TestBackend(), mat)))
 
 
 class SmartHandlersTestCase(WebTestCase):
@@ -163,8 +164,9 @@ class SmartHandlersTestCase(WebTestCase):
                 self._handler.write('pkt-line: %s' % line)
 
     class _TestUploadPackHandler(object):
-        def __init__(self, backend, read, write, stateless_rpc=False,
+        def __init__(self, backend, args, read, write, stateless_rpc=False,
                      advertise_refs=False):
+            self.args = args
            self.read = read
            self.write = write
            self.proto = SmartHandlersTestCase.TestProtocol(self)
@@ -217,7 +219,8 @@
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'
 
-        output = ''.join(get_info_refs(self._req, 'backend', None,
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(get_info_refs(self._req, 'backend', mat,
                                       services=self.services()))
         self.assertEquals(('pkt-line: # service=git-upload-pack\n'
                           'flush-pkt\n'

+ 35 - 13
dulwich/web.py

@@ -21,8 +21,11 @@
 from cStringIO import StringIO
 import re
 import time
-import urlparse
 
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from dulwich.misc import parse_qs
 from dulwich.server import (
     ReceivePackHandler,
     UploadPackHandler,
@@ -33,7 +36,7 @@ HTTP_NOT_FOUND = '404 Not Found'
 HTTP_FORBIDDEN = '403 Forbidden'
 
 
-def date_time_string(self, timestamp=None):
+def date_time_string(timestamp=None):
     # Based on BaseHTTPServer.py in python2.5
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     months = [None,
@@ -46,6 +49,22 @@ def date_time_string(self, timestamp=None):
             weekdays[wd], day, months[month], year, hh, mm, ss)
 
 
+def url_prefix(mat):
+    """Extract the URL prefix from a regex match.
+
+    :param mat: A regex match object.
+    :returns: The URL prefix, defined as the text before the match in the
+        original string. Normalized to start with one leading slash and end with
+        zero.
+    """
+    return '/' + mat.string[:mat.start()].strip('/')
+
+
+def get_repo(backend, mat):
+    """Get a Repo instance for the given backend and URL regex match."""
+    return backend.open_repository(url_prefix(mat))
+
+
 def send_file(req, f, content_type):
     """Send a file-like object to the request output.
 
@@ -73,13 +92,13 @@
 
 def get_text_file(req, backend, mat):
     req.nocache()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                     'text/plain')
 
 
 def get_loose_object(req, backend, mat):
     sha = mat.group(1) + mat.group(2)
-    object_store = backend.object_store
+    object_store = get_repo(backend, mat).object_store
     if not object_store.contains_loose(sha):
         yield req.not_found('Object not found')
         return
@@ -94,13 +113,13 @@ def get_loose_object(req, backend, mat):
 
 def get_pack_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                     'application/x-git-packed-objects')
 
 
 def get_idx_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                     'application/x-git-packed-objects-toc')
 
 
@@ -109,7 +128,7 @@ default_services = {'git-upload-pack': UploadPackHandler,
 def get_info_refs(req, backend, mat, services=None):
     if services is None:
         services = default_services
-    params = urlparse.parse_qs(req.environ['QUERY_STRING'])
+    params = parse_qs(req.environ['QUERY_STRING'])
     service = params.get('service', [None])[0]
     if service and not req.dumb:
         handler_cls = services.get(service, None)
@@ -120,7 +139,8 @@ def get_info_refs(req, backend, mat, services=None):
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         output = StringIO()
         dummy_input = StringIO()  # GET request, handler doesn't need to read
-        handler = handler_cls(backend, dummy_input.read, output.write,
+        handler = handler_cls(backend, [url_prefix(mat)],
+                              dummy_input.read, output.write,
                              stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
@@ -131,18 +151,19 @@ def get_info_refs(req, backend, mat, services=None):
         # TODO: select_getanyfile() (see http-backend.c)
         req.nocache()
         req.respond(HTTP_OK, 'text/plain')
-        refs = backend.get_refs()
+        repo = get_repo(backend, mat)
+        refs = repo.get_refs()
         for name in sorted(refs.iterkeys()):
             # get_refs() includes HEAD as a special case, but we don't want to
             # advertise it
             if name == 'HEAD':
                 continue
             sha = refs[name]
-            o = backend.repo[sha]
+            o = repo[sha]
             if not o:
                 continue
             yield '%s\t%s\n' % (sha, name)
-            peeled_sha = backend.repo.get_peeled(name)
+            peeled_sha = repo.get_peeled(name)
             if peeled_sha != sha:
                 yield '%s\t%s^{}\n' % (peeled_sha, name)
 
@@ -150,7 +171,7 @@ def get_info_refs(req, backend, mat, services=None):
 def get_info_packs(req, backend, mat):
     req.nocache()
     req.respond(HTTP_OK, 'text/plain')
-    for pack in backend.object_store.packs:
+    for pack in get_repo(backend, mat).object_store.packs:
         yield 'P pack-%s.pack\n' % pack.name()
 
 
@@ -195,7 +216,8 @@ def handle_service_request(req, backend, mat, services=None):
     # content-length
     if 'CONTENT_LENGTH' in req.environ:
         input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    handler = handler_cls(backend, input.read, output.write, stateless_rpc=True)
+    handler = handler_cls(backend, [url_prefix(mat)], input.read, output.write,
+                          stateless_rpc=True)
     handler.handle()
     yield output.getvalue()
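
The recurring get_repo(backend, mat) edits above replace every direct
backend.repo access: each request now resolves its repository from the URL
prefix preceding the matched route, which is what makes multi-repository
DictBackend mappings work over HTTP. A sketch of what url_prefix() yields
(the route regex '/info/refs$' is an illustrative example):

    import re

    from dulwich.web import url_prefix

    mat = re.search('/info/refs$', '/myrepo/info/refs')
    print url_prefix(mat)    # '/myrepo', the key a DictBackend would use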