
New upstream release.

Jelmer Vernooij, 15 years ago
commit f517a43131
69 changed files with 5743 additions and 1545 deletions
  1. + 1 - 0      .bzrignore
  2. + 3 - 0      .testr.conf
  3. + 2 - 1      AUTHORS
  4. + 6 - 3      HACKING
  5. + 7 - 2      Makefile
  6. + 93 - 4     NEWS
  7. + 4 - 0      README
  8. + 4 - 3      bin/dul-daemon
  9. + 2 - 2      bin/dul-web
  10. + 27 - 35   bin/dulwich
  11. + 6 - 0     debian/changelog
  12. + 1 - 1     dulwich/__init__.py
  13. + 131 - 15  dulwich/_objects.c
  14. + 50 - 4    dulwich/_pack.c
  15. + 106 - 40  dulwich/client.py
  16. + 56 - 21   dulwich/errors.py
  17. + 82 - 0    dulwich/fastexport.py
  18. + 41 - 2    dulwich/file.py
  19. + 16 - 0    dulwich/index.py
  20. + 13 - 3    dulwich/misc.py
  21. + 128 - 67  dulwich/object_store.py
  22. + 662 - 250 dulwich/objects.py
  23. + 409 - 269 dulwich/pack.py
  24. + 67 - 5    dulwich/patch.py
  25. + 116 - 2   dulwich/protocol.py
  26. + 371 - 79  dulwich/repo.py
  27. + 251 - 143 dulwich/server.py
  28. + 30 - 0    dulwich/tests/__init__.py
  29. + 157 - 0   dulwich/tests/compat/server_utils.py
  30. + 179 - 0   dulwich/tests/compat/test_client.py
  31. + 73 - 0    dulwich/tests/compat/test_pack.py
  32. + 131 - 0   dulwich/tests/compat/test_repository.py
  33. + 78 - 0    dulwich/tests/compat/test_server.py
  34. + 116 - 0   dulwich/tests/compat/test_web.py
  35. + 196 - 0   dulwich/tests/compat/utils.py
  36. BIN         dulwich/tests/data/blobs/11/11111111111111111111111111111111111111
  37. + 0 - 0     dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8
  38. + 0 - 0     dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349
  39. + 0 - 0     dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391
  40. + 0 - 0     dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310
  41. + 0 - 0     dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc
  42. + 0 - 0     dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e
  43. + 2 - 0     dulwich/tests/data/repos/a.git/objects/28/237f4dc30d0d462658d6b937b08a0f0b6ef55a
  44. + 3 - 0     dulwich/tests/data/repos/a.git/objects/b0/931cadc54336e78a1d980420e3268903b57a50
  45. + 3 - 0     dulwich/tests/data/repos/a.git/packed-refs
  46. + 1 - 0     dulwich/tests/data/repos/a.git/refs/tags/mytag
  47. + 3 - 0     dulwich/tests/data/repos/refs.git/objects/3e/c9c43c84ff242e3ef4a9fc5bc111fd780a76a8
  48. + 5 - 0     dulwich/tests/data/repos/refs.git/objects/cd/a609072918d7b70057b6bef9f4c2537843fcfe
  49. + 1 - 0     dulwich/tests/data/repos/refs.git/packed-refs
  50. + 1 - 0     dulwich/tests/data/repos/refs.git/refs/tags/refs-0.2
  51. + 99 - 0    dulwich/tests/data/repos/server_new.export
  52. + 57 - 0    dulwich/tests/data/repos/server_old.export
  53. + 0 - 0     dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023
  54. + 0 - 0     dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6
  55. + 11 - 5    dulwich/tests/test_client.py
  56. + 78 - 0    dulwich/tests/test_fastexport.py
  57. + 68 - 2    dulwich/tests/test_file.py
  58. + 31 - 11   dulwich/tests/test_index.py
  59. + 33 - 29   dulwich/tests/test_object_store.py
  60. + 420 - 104 dulwich/tests/test_objects.py
  61. + 212 - 86  dulwich/tests/test_pack.py
  62. + 88 - 0    dulwich/tests/test_patch.py
  63. + 90 - 7    dulwich/tests/test_protocol.py
  64. + 477 - 141 dulwich/tests/test_repository.py
  65. + 199 - 78  dulwich/tests/test_server.py
  66. + 69 - 55   dulwich/tests/test_web.py
  67. + 86 - 0    dulwich/tests/utils.py
  68. + 90 - 75   dulwich/web.py
  69. + 1 - 1     setup.py

+ 1 - 0
.bzrignore

@@ -4,3 +4,4 @@ MANIFEST
 dist
 apidocs
 *,cover
+.testrepository

+ 3 - 0
.testr.conf

@@ -0,0 +1,3 @@
+[DEFAULT]
+test_command=PYTHONPATH=. python -m subunit.run $IDLIST
+test_id_list_default=dulwich.tests.test_suite

+ 2 - 1
AUTHORS

@@ -1,3 +1,4 @@
+Jelmer Vernooij <jelmer@samba.org>
 James Westby <jw+debian@jameswestby.net>
 John Carr <john.carr@unrouted.co.uk>
-Jelmer Vernooij <jelmer@samba.org>
+Dave Borowitz <dborowitz@google.com>

+ 6 - 3
HACKING

@@ -1,5 +1,8 @@
 Please follow PEP8 with regard to coding style.
 
-All functionality should be available in pure Python. Optional C implementations
-may be written for performance reasons, but should never replace the Python 
-implementation. The C implementations should follow the kernel/git coding style.
+All functionality should be available in pure Python. Optional C
+implementations may be written for performance reasons, but should never
+replace the Python implementation. The C implementations should follow the
+kernel/git coding style.
+
+Where possible please include updates to NEWS along with your improvements.

+ 7 - 2
Makefile

@@ -2,8 +2,9 @@ PYTHON = python
 SETUP = $(PYTHON) setup.py
 PYDOCTOR ?= pydoctor
 TESTRUNNER = $(shell which nosetests)
+TESTFLAGS =
 
-all: build 
+all: build
 
 doc:: pydoctor
 
@@ -19,9 +20,13 @@ install::
 
 check:: build
 	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	which git > /dev/null && PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 check-noextensions:: clean
-	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) dulwich
+
+check-compat:: build
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 clean::
 	$(SETUP) clean --all

+ 93 - 4
NEWS

@@ -1,8 +1,96 @@
+0.6.0	2010-05-22
+
+note: This list is most likely incomplete for 0.6.0.
+
+ BUG FIXES
+ 
+  * Fix ReceivePackHandler to disallow removing refs without delete-refs.
+    (Dave Borowitz)
+
+  * Deal with capabilities required by the client, even if they 
+    can not be disabled in the server. (Dave Borowitz)
+
+  * Fix trailing newlines in generated patch files.
+    (Jelmer Vernooij)
+
+  * Implement RefsContainer.__contains__. (Jelmer Vernooij)
+
+  * Cope with \r in ref files on Windows. (
+	http://github.com/jelmer/dulwich/issues/#issue/13, Jelmer Vernooij)
+
+  * Fix GitFile breakage on Windows. (Anatoly Techtonik, #557585)
+
+  * Support packed ref deletion with no peeled refs. (Augie Fackler)
+
+  * Fix send pack when there is nothing to fetch. (Augie Fackler)
+
+  * Fix fetch if no progress function is specified. (Augie Fackler)
+
+  * Allow double-staging of files that are deleted in the index. 
+    (Dave Borowitz)
+
+  * Fix RefsContainer.add_if_new to support dangling symrefs.
+    (Dave Borowitz)
+
+  * Non-existant index files in non-bare repositories are now treated as 
+    empty. (Dave Borowitz)
+
+  * Always update ShaFile.id when the contents of the object get changed. 
+    (Jelmer Vernooij)
+
+  * Various Python2.4-compatibility fixes. (Dave Borowitz)
+
+  * Fix thin pack handling. (Dave Borowitz)
+ 
+ FEATURES
+
+  * Add include-tag capability to server. (Dave Borowitz)
+
+  * New dulwich.fastexport module that can generate fastexport 
+    streams. (Jelmer Vernooij)
+
+  * Implemented BaseRepo.__contains__. (Jelmer Vernooij)
+
+  * Add __setitem__ to DictRefsContainer. (Dave Borowitz)
+
+  * Overall improvements checking Git objects. (Dave Borowitz)
+
+  * Packs are now verified while they are received. (Dave Borowitz)
+
+ TESTS
+
+  * Add framework for testing compatibility with C Git. (Dave Borowitz)
+
+  * Add various tests for the use of non-bare repositories. (Dave Borowitz)
+
+  * Cope with diffstat not being available on all platforms. 
+    (Tay Ray Chuan, Jelmer Vernooij)
+
+  * Add make_object and make_commit convenience functions to test utils.
+    (Dave Borowitz)
+
+ API BREAKAGES
+
+  * The 'committer' and 'message' arguments to Repo.do_commit() have 
+    been swapped. 'committer' is now optional. (Jelmer Vernooij)
+
+  * Repo.get_blob, Repo.commit, Repo.tag and Repo.tree are now deprecated.
+    (Jelmer Vernooij)
+
+  * RefsContainer.set_ref() was renamed to RefsContainer.set_symbolic_ref(),
+    for clarity. (Jelmer Vernooij)
+
+ API CHANGES
+
+  * The primary serialization APIs in dulwich.objects now work 
+    with chunks of strings rather than with full-text strings. 
+    (Jelmer Vernooij)
+
 0.5.0	2010-03-03
 
  BUG FIXES
 
-  * Support custom fields in commits.
+  * Support custom fields in commits (readonly). (Jelmer Vernooij)
 
   * Improved ref handling. (Dave Borowitz)
 
@@ -31,13 +119,14 @@
 
  FEATURES
 
-  * Add ObjectStore.iter_tree_contents()
+  * Add ObjectStore.iter_tree_contents(). (Jelmer Vernooij)
 
-  * Add Index.changes_from_tree()
+  * Add Index.changes_from_tree(). (Jelmer Vernooij)
 
-  * Add ObjectStore.tree_changes()
+  * Add ObjectStore.tree_changes(). (Jelmer Vernooij)
 
   * Add functionality for writing patches in dulwich.patch.
+    (Jelmer Vernooij)
 
 0.4.0	2009-10-07
 

+ 4 - 0
README

@@ -21,3 +21,7 @@ The project is named after the part of London that Mr. and Mrs. Git live in
 in the particular Monty Python sketch. It is based on the Python-Git module 
 that James Westby <jw+debian@jameswestby.net> released in 2007 and now 
 maintained by Jelmer Vernooij and John Carr.
+
+Please file bugs in the Dulwich project on Launchpad: 
+
+https://bugs.launchpad.net/dulwich/+filebug

+ 4 - 3
bin/dul-daemon

@@ -19,13 +19,14 @@
 
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend, TCPGitServer
+from dulwich.server import DictBackend, TCPGitServer
 
 if __name__ == "__main__":
-    gitdir = None
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
+    else:
+        gitdir = "."
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     server = TCPGitServer(backend, 'localhost')
     server.serve_forever()

+ 2 - 2
bin/dul-web

@@ -20,7 +20,7 @@
 import os
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend
+from dulwich.server import DictBackend
 from dulwich.web import HTTPGitApplication
 from wsgiref.simple_server import make_server
 
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     else:
         gitdir = os.getcwd()
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     app = HTTPGitApplication(backend)
     # TODO: allow serving on other ports via command-line flag
     server = make_server('', 8000, app)
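
Both scripts above switch from the old GitBackend to the new DictBackend, which maps URL paths to Repo instances. Below is a minimal sketch of the same wiring outside the scripts; the repository path is made up, and serving over HTTP works analogously via HTTPGitApplication, as bin/dul-web does.

# Sketch: serving one repository with the new DictBackend.
# "/srv/example.git" is an illustrative path.
from dulwich.repo import Repo
from dulwich.server import DictBackend, TCPGitServer

backend = DictBackend({"/": Repo("/srv/example.git")})
server = TCPGitServer(backend, 'localhost')   # same call as bin/dul-daemon
server.serve_forever()

# For HTTP, bin/dul-web wraps the same backend instead:
#   from dulwich.web import HTTPGitApplication
#   from wsgiref.simple_server import make_server
#   make_server('', 8000, HTTPGitApplication(backend)).serve_forever()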

+ 27 - 35
bin/dulwich

@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-# dul-daemon - Simple git smart server client
+# dulwich - Simple command-line interface to Dulwich
 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
 # 
 # This program is free software; you can redistribute it and/or
@@ -17,21 +17,25 @@
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
+"""Simple command-line interface to Dulwich>
+
+This is a very simple command-line wrapper for Dulwich. It is by 
+no means intended to be a full-blown Git command-line interface but just 
+a way to test Dulwich.
+"""
+
+import os
 import sys
 from getopt import getopt
 
-def get_transport_and_path(uri):
-    from dulwich.client import TCPGitClient, SSHGitClient, SubprocessGitClient
-    for handler, transport in (("git://", TCPGitClient), ("git+ssh://", SSHGitClient)):
-        if uri.startswith(handler):
-            host, path = uri[len(handler):].split("/", 1)
-            return transport(host), "/"+path
-    # if its not git or git+ssh, try a local url..
-    return SubprocessGitClient(), uri
+from dulwich.client import get_transport_and_path
+from dulwich.errors import ApplyDeltaError
+from dulwich.index import Index
+from dulwich.pack import Pack, sha_to_hex
+from dulwich.repo import Repo
 
 
 def cmd_fetch_pack(args):
-    from dulwich.repo import Repo
     opts, args = getopt(args, "", ["all"])
     opts = dict(opts)
     client, path = get_transport_and_path(args.pop(0))
@@ -45,9 +49,12 @@ def cmd_fetch_pack(args):
 
 
 def cmd_log(args):
-    from dulwich.repo import Repo
     opts, args = getopt(args, "", [])
-    r = Repo(".")
+    if len(args) > 0:
+        path = args.pop(0)
+    else:
+        path = "."
+    r = Repo(path)
     todo = [r.head()]
     done = set()
     while todo:
@@ -56,7 +63,7 @@ def cmd_log(args):
         if sha in done:
             continue
         done.add(sha)
-        commit = r.commit(sha)
+        commit = r[sha]
         print "-" * 50
         print "commit: %s" % sha
         if len(commit.parents) > 1:
@@ -70,11 +77,6 @@ def cmd_log(args):
 
 
 def cmd_dump_pack(args):
-    from dulwich.errors import ApplyDeltaError
-    from dulwich.pack import Pack, sha_to_hex
-    import os
-    import sys
-
     opts, args = getopt(args, "", [])
 
     if args == []:
@@ -98,8 +100,6 @@ def cmd_dump_pack(args):
 
 
 def cmd_dump_index(args):
-    from dulwich.index import Index
-
     opts, args = getopt(args, "", [])
 
     if args == []:
@@ -114,8 +114,6 @@ def cmd_dump_index(args):
 
 
 def cmd_init(args):
-    from dulwich.repo import Repo
-    import os
     opts, args = getopt(args, "", ["--bare"])
     opts = dict(opts)
 
@@ -134,16 +132,13 @@ def cmd_init(args):
 
 
 def cmd_clone(args):
-    from dulwich.repo import Repo
-    import os
-    import sys
     opts, args = getopt(args, "", [])
     opts = dict(opts)
 
     if args == []:
         print "usage: dulwich clone host:path [PATH]"
         sys.exit(1)
-        client, host_path = get_transport_and_path(args.pop(0))
+    client, host_path = get_transport_and_path(args.pop(0))
 
     if len(args) > 0:
         path = args.pop(0)
@@ -152,18 +147,14 @@ def cmd_clone(args):
 
     if not os.path.exists(path):
         os.mkdir(path)
-    Repo.init(path)
-    r = Repo(path)
-    graphwalker = r.get_graph_walker()
-    f, commit = r.object_store.add_pack()
-    client.fetch_pack(host_path, r.object_store.determine_wants_all, 
-                      graphwalker, f.write, sys.stdout.write)
-    commit()
+    r = Repo.init(path)
+    remote_refs = client.fetch(host_path, r,
+        determine_wants=r.object_store.determine_wants_all,
+        progress=sys.stdout.write)
+    r["HEAD"] = remote_refs["HEAD"]
 
 
 def cmd_commit(args):
-    from dulwich.repo import Repo
-    import os
     opts, args = getopt(args, "", ["message"])
     opts = dict(opts)
     r = Repo(".")
@@ -173,6 +164,7 @@ def cmd_commit(args):
                           os.getenv("GIT_AUTHOR_EMAIL"))
     r.do_commit(committer=committer, author=author, message=opts["--message"])
 
+
 commands = {
     "commit": cmd_commit,
     "fetch-pack": cmd_fetch_pack,

+ 6 - 0
debian/changelog

@@ -1,3 +1,9 @@
+dulwich (0.6.0-1) unstable; urgency=low
+
+  * New upstream release.
+
+ -- Jelmer Vernooij <jelmer@debian.org>  Sat, 22 May 2010 23:06:20 +0200
+
 dulwich (0.5.0-1) unstable; urgency=low
 
   * New upstream release.

+ 1 - 1
dulwich/__init__.py

@@ -27,4 +27,4 @@ import protocol
 import repo
 import server
 
-__version__ = (0, 5, 0)
+__version__ = (0, 6, 0)

+ 131 - 15
dulwich/_objects.c

@@ -18,6 +18,12 @@
  */
 
 #include <Python.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+
+#if (PY_VERSION_HEX < 0x02050000)
+typedef int Py_ssize_t;
+#endif
 
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
 
@@ -35,33 +41,37 @@ static PyObject *sha_to_pyhex(const unsigned char *sha)
 
 static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 {
-	char *text, *end;
+	char *text, *start, *end;
 	int len, namelen;
 	PyObject *ret, *item, *name;
 
 	if (!PyArg_ParseTuple(args, "s#", &text, &len))
 		return NULL;
 
-	ret = PyDict_New();
+	/* TODO: currently this returns a list; if memory usage is a concern,
+	* consider rewriting as a custom iterator object */
+	ret = PyList_New(0);
+
 	if (ret == NULL) {
 		return NULL;
 	}
 
+	start = text;
 	end = text + len;
 
-    while (text < end) {
-        long mode;
+	while (text < end) {
+		long mode;
 		mode = strtol(text, &text, 8);
 
 		if (*text != ' ') {
-			PyErr_SetString(PyExc_RuntimeError, "Expected space");
+			PyErr_SetString(PyExc_ValueError, "Expected space");
 			Py_DECREF(ret);
 			return NULL;
 		}
 
 		text++;
 
-        namelen = strlen(text);
+		namelen = strnlen(text, len - (text - start));
 
 		name = PyString_FromStringAndSize(text, namelen);
 		if (name == NULL) {
@@ -69,28 +79,134 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 			return NULL;
 		}
 
-        item = Py_BuildValue("(lN)", mode, sha_to_pyhex((unsigned char *)text+namelen+1));
-        if (item == NULL) {
-            Py_DECREF(ret);
+		if (text + namelen + 20 >= end) {
+			PyErr_SetString(PyExc_ValueError, "SHA truncated");
+			Py_DECREF(ret);
 			Py_DECREF(name);
-            return NULL;
-        }
-		if (PyDict_SetItem(ret, name, item) == -1) {
+			return NULL;
+		}
+
+		item = Py_BuildValue("(NlN)", name, mode,
+							 sha_to_pyhex((unsigned char *)text+namelen+1));
+		if (item == NULL) {
+			Py_DECREF(ret);
+			Py_DECREF(name);
+			return NULL;
+		}
+		if (PyList_Append(ret, item) == -1) {
 			Py_DECREF(ret);
 			Py_DECREF(item);
 			return NULL;
 		}
-		Py_DECREF(name);
 		Py_DECREF(item);
 
 		text += namelen+21;
-    }
+	}
+
+	return ret;
+}
+
+struct tree_item {
+	const char *name;
+	int mode;
+	PyObject *tuple;
+};
+
+int cmp_tree_item(const void *_a, const void *_b)
+{
+	const struct tree_item *a = _a, *b = _b;
+	const char *remain_a, *remain_b;
+	int ret, common;
+	if (strlen(a->name) > strlen(b->name)) {
+		common = strlen(b->name);
+		remain_a = a->name + common;
+		remain_b = (S_ISDIR(b->mode)?"/":"");
+	} else if (strlen(b->name) > strlen(a->name)) { 
+		common = strlen(a->name);
+		remain_a = (S_ISDIR(a->mode)?"/":"");
+		remain_b = b->name + common;
+	} else { /* strlen(a->name) == strlen(b->name) */
+		common = 0;
+		remain_a = a->name;
+		remain_b = b->name;
+	}
+	ret = strncmp(a->name, b->name, common);
+	if (ret != 0)
+		return ret;
+	return strcmp(remain_a, remain_b);
+}
+
+static PyObject *py_sorted_tree_items(PyObject *self, PyObject *entries)
+{
+	struct tree_item *qsort_entries;
+	int num, i;
+	PyObject *ret;
+	Py_ssize_t pos = 0; 
+	PyObject *key, *value;
+
+	if (!PyDict_Check(entries)) {
+		PyErr_SetString(PyExc_TypeError, "Argument not a dictionary");
+		return NULL;
+	}
+
+	num = PyDict_Size(entries);
+	qsort_entries = malloc(num * sizeof(struct tree_item));
+	if (qsort_entries == NULL) {
+		PyErr_NoMemory();
+		return NULL;
+	}
+
+	i = 0;
+	while (PyDict_Next(entries, &pos, &key, &value)) {
+		PyObject *py_mode, *py_int_mode, *py_sha;
+		
+		if (PyTuple_Size(value) != 2) {
+			PyErr_SetString(PyExc_ValueError, "Tuple has invalid size");
+			free(qsort_entries);
+			return NULL;
+		}
+
+		py_mode = PyTuple_GET_ITEM(value, 0);
+		py_int_mode = PyNumber_Int(py_mode);
+		if (!py_int_mode) {
+			PyErr_SetString(PyExc_TypeError, "Mode is not an integral type");
+			free(qsort_entries);
+			return NULL;
+		}
+
+		py_sha = PyTuple_GET_ITEM(value, 1);
+		if (!PyString_CheckExact(key)) {
+			PyErr_SetString(PyExc_TypeError, "Name is not a string");
+			free(qsort_entries);
+			return NULL;
+		}
+		qsort_entries[i].name = PyString_AS_STRING(key);
+		qsort_entries[i].mode = PyInt_AS_LONG(py_mode);
+		qsort_entries[i].tuple = PyTuple_Pack(3, key, py_mode, py_sha);
+		i++;
+	}
+
+	qsort(qsort_entries, num, sizeof(struct tree_item), cmp_tree_item);
+
+	ret = PyList_New(num);
+	if (ret == NULL) {
+		free(qsort_entries);
+		PyErr_NoMemory();
+		return NULL;
+	}
+
+	for (i = 0; i < num; i++) {
+		PyList_SET_ITEM(ret, i, qsort_entries[i].tuple);
+	}
+
+	free(qsort_entries);
 
-    return ret;
+	return ret;
 }
 
 static PyMethodDef py_objects_methods[] = {
 	{ "parse_tree", (PyCFunction)py_parse_tree, METH_VARARGS, NULL },
+	{ "sorted_tree_items", (PyCFunction)py_sorted_tree_items, METH_O, NULL },
 	{ NULL, NULL, 0, NULL }
 };
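
The new cmp_tree_item/sorted_tree_items code sorts tree entries the way C git does: a directory compares as if its name ended in "/". A rough pure-Python illustration of that ordering (not the actual Python fallback in dulwich.objects):

import stat

def tree_sort_key(entry):
    # entry is (name, mode, hexsha), as returned by parse_tree above
    name, mode, hexsha = entry
    if stat.S_ISDIR(mode):
        return name + "/"
    return name

entries = [("a.c", 0100644, "0" * 40), ("a", 040000, "1" * 40)]
print [e[0] for e in sorted(entries, key=tree_sort_key)]
# ['a.c', 'a']: the blob "a.c" sorts before the tree "a", since "a.c" < "a/"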
 

+ 50 - 4
dulwich/_pack.c

@@ -47,6 +47,29 @@ static size_t get_delta_header_size(uint8_t *delta, int *index, int length)
 	return size;
 }
 
+static PyObject *py_chunked_as_string(PyObject *py_buf)
+{
+	if (PyList_Check(py_buf)) {
+		PyObject *sep = PyString_FromString("");
+		if (sep == NULL) {
+			PyErr_NoMemory();
+			return NULL;
+		}
+		py_buf = _PyString_Join(sep, py_buf);
+		Py_DECREF(sep);
+		if (py_buf == NULL) {
+			PyErr_NoMemory();
+			return NULL;
+		}
+	} else if (PyString_Check(py_buf)) {
+		Py_INCREF(py_buf);
+	} else {
+		PyErr_SetString(PyExc_TypeError,
+			"src_buf is not a string or a list of chunks");
+		return NULL;
+	}
+    return py_buf;
+}
 
 static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 {
@@ -56,23 +79,42 @@ static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 	size_t outindex = 0;
 	int index;
 	uint8_t *out;
-	PyObject *ret;
+	PyObject *ret, *py_src_buf, *py_delta;
 
-	if (!PyArg_ParseTuple(args, "s#s#", (uint8_t *)&src_buf, &src_buf_len, 
-						  (uint8_t *)&delta, &delta_len))
+	if (!PyArg_ParseTuple(args, "OO", &py_src_buf, &py_delta))
 		return NULL;
 
+    py_src_buf = py_chunked_as_string(py_src_buf);
+    if (py_src_buf == NULL)
+        return NULL;
+
+    py_delta = py_chunked_as_string(py_delta);
+    if (py_delta == NULL) {
+        Py_DECREF(py_src_buf);
+        return NULL;
+    }
+
+	src_buf = (uint8_t *)PyString_AS_STRING(py_src_buf);
+	src_buf_len = PyString_GET_SIZE(py_src_buf);
+
+    delta = (uint8_t *)PyString_AS_STRING(py_delta);
+    delta_len = PyString_GET_SIZE(py_delta);
+
     index = 0;
     src_size = get_delta_header_size(delta, &index, delta_len);
     if (src_size != src_buf_len) {
 		PyErr_Format(PyExc_ValueError, 
 			"Unexpected source buffer size: %lu vs %d", src_size, src_buf_len);
+		Py_DECREF(py_src_buf);
+		Py_DECREF(py_delta);
 		return NULL;
 	}
     dest_size = get_delta_header_size(delta, &index, delta_len);
 	ret = PyString_FromStringAndSize(NULL, dest_size);
 	if (ret == NULL) {
 		PyErr_NoMemory();
+		Py_DECREF(py_src_buf);
+		Py_DECREF(py_delta);
 		return NULL;
 	}
 	out = (uint8_t *)PyString_AsString(ret);
@@ -111,9 +153,13 @@ static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 		} else {
 			PyErr_SetString(PyExc_ValueError, "Invalid opcode 0");
 			Py_DECREF(ret);
+            Py_DECREF(py_delta);
+			Py_DECREF(py_src_buf);
 			return NULL;
 		}
 	}
+	Py_DECREF(py_src_buf);
+    Py_DECREF(py_delta);
     
     if (index != delta_len) {
 		PyErr_SetString(PyExc_ValueError, "delta not empty");
@@ -127,7 +173,7 @@ static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 		return NULL;
 	}
 
-    return ret;
+    return Py_BuildValue("[N]", ret);
 }
 
 static PyObject *py_bisect_find_sha(PyObject *self, PyObject *args)

+ 106 - 40
dulwich/client.py

@@ -28,10 +28,13 @@ import subprocess
 
 from dulwich.errors import (
     ChecksumMismatch,
+    SendPackError,
+    UpdateRefsError,
     )
 from dulwich.protocol import (
     Protocol,
     TCP_GIT_PORT,
+    ZERO_SHA,
     extract_capabilities,
     )
 from dulwich.pack import (
@@ -43,16 +46,19 @@ def _fileno_can_read(fileno):
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
+COMMON_CAPABILITIES = ["ofs-delta"]
+FETCH_CAPABILITIES = ["multi_ack", "side-band-64k"] + COMMON_CAPABILITIES
+SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 
-CAPABILITIES = ["multi_ack", "side-band-64k", "ofs-delta"]
-
-
+# TODO(durin42): this doesn't correctly degrade if the server doesn't
+# support some capabilities. This should work properly with servers
+# that don't support side-band-64k and multi_ack.
 class GitClient(object):
     """Git smart server client.
 
     """
 
-    def __init__(self, can_read, read, write, thin_packs=True, 
+    def __init__(self, can_read, read, write, thin_packs=True,
         report_activity=None):
         """Create a new GitClient instance.
 
@@ -66,12 +72,10 @@ class GitClient(object):
         """
         self.proto = Protocol(read, write, report_activity)
         self._can_read = can_read
-        self._capabilities = list(CAPABILITIES)
+        self._fetch_capabilities = list(FETCH_CAPABILITIES)
+        self._send_capabilities = list(SEND_CAPABILITIES)
         if thin_packs:
-            self._capabilities.append("thin-pack")
-
-    def capabilities(self):
-        return " ".join(self._capabilities)
+            self._fetch_capabilities.append("thin-pack")
 
     def read_refs(self):
         server_capabilities = None
@@ -84,44 +88,90 @@ class GitClient(object):
             refs[ref] = sha
         return refs, server_capabilities
 
+    # TODO(durin42): add side-band-64k capability support here and advertise it
     def send_pack(self, path, determine_wants, generate_pack_contents):
         """Upload a pack to a remote repository.
 
         :param path: Repository path
-        :param generate_pack_contents: Function that can return the shas of the 
+        :param generate_pack_contents: Function that can return the shas of the
             objects to upload.
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
         """
         old_refs, server_capabilities = self.read_refs()
+        if 'report-status' not in server_capabilities:
+            self._send_capabilities.remove('report-status')
         new_refs = determine_wants(old_refs)
         if not new_refs:
             self.proto.write_pkt_line(None)
             return {}
         want = []
-        have = [x for x in old_refs.values() if not x == "0" * 40]
+        have = [x for x in old_refs.values() if not x == ZERO_SHA]
         sent_capabilities = False
         for refname in set(new_refs.keys() + old_refs.keys()):
-            old_sha1 = old_refs.get(refname, "0" * 40)
-            new_sha1 = new_refs.get(refname, "0" * 40)
+            old_sha1 = old_refs.get(refname, ZERO_SHA)
+            new_sha1 = new_refs.get(refname, ZERO_SHA)
             if old_sha1 != new_sha1:
                 if sent_capabilities:
-                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1, refname))
+                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1,
+                                                            refname))
                 else:
-                    self.proto.write_pkt_line("%s %s %s\0%s" % (old_sha1, new_sha1, refname, self.capabilities()))
+                    self.proto.write_pkt_line(
+                      "%s %s %s\0%s" % (old_sha1, new_sha1, refname,
+                                        ' '.join(self._send_capabilities)))
                     sent_capabilities = True
-            if not new_sha1 in (have, "0" * 40):
+            if new_sha1 not in have and new_sha1 != ZERO_SHA:
                 want.append(new_sha1)
         self.proto.write_pkt_line(None)
         if not want:
             return new_refs
         objects = generate_pack_contents(have, want)
-        (entries, sha) = write_pack_data(self.proto.write_file(), objects, 
+        (entries, sha) = write_pack_data(self.proto.write_file(), objects,
                                          len(objects))
-        
-        # read the final confirmation sha
-        client_sha = self.proto.read(20)
-        if not client_sha in (None, "", sha):
-            raise ChecksumMismatch(sha, client_sha)
-            
+
+        if 'report-status' in self._send_capabilities:
+            unpack = self.proto.read_pkt_line().strip()
+            if unpack != 'unpack ok':
+                st = True
+                # flush remaining error data
+                while st is not None:
+                    st = self.proto.read_pkt_line()
+                raise SendPackError(unpack)
+            statuses = []
+            errs = False
+            ref_status = self.proto.read_pkt_line()
+            while ref_status:
+                ref_status = ref_status.strip()
+                statuses.append(ref_status)
+                if not ref_status.startswith('ok '):
+                    errs = True
+                ref_status = self.proto.read_pkt_line()
+
+            if errs:
+                ref_status = {}
+                ok = set()
+                for status in statuses:
+                    if ' ' not in status:
+                        # malformed response, move on to the next one
+                        continue
+                    status, ref = status.split(' ', 1)
+
+                    if status == 'ng':
+                        if ' ' in ref:
+                            ref, status = ref.split(' ', 1)
+                    else:
+                        ok.add(ref)
+                    ref_status[ref] = status
+                raise UpdateRefsError('%s failed to update' %
+                                      ', '.join([ref for ref in ref_status
+                                                 if ref not in ok]),
+                                      ref_status=ref_status)
+        # wait for EOF before returning
+        data = self.proto.read()
+        if data:
+            raise SendPackError('Unexpected response %r' % data)
         return new_refs
 
     def fetch(self, path, target, determine_wants=None, progress=None):
@@ -129,7 +179,7 @@ class GitClient(object):
 
         :param path: Path to fetch from
         :param target: Target repository to fetch into
-        :param determine_wants: Optional function to determine what refs 
+        :param determine_wants: Optional function to determine what refs
             to fetch
         :param progress: Optional progress function
         :return: remote refs
@@ -138,8 +188,8 @@ class GitClient(object):
             determine_wants = target.object_store.determine_wants_all
         f, commit = target.object_store.add_pack()
         try:
-            return self.fetch_pack(path, determine_wants, target.graph_walker, 
-                                   f.write, progress)
+            return self.fetch_pack(path, determine_wants,
+                target.get_graph_walker(), f.write, progress)
         finally:
             commit()
 
@@ -158,7 +208,8 @@ class GitClient(object):
             self.proto.write_pkt_line(None)
             return refs
         assert isinstance(wants, list) and type(wants[0]) == str
-        self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
+        self.proto.write_pkt_line("want %s %s\n" % (
+            wants[0], ' '.join(self._fetch_capabilities)))
         for want in wants[1:]:
             self.proto.write_pkt_line("want %s\n" % want)
         self.proto.write_pkt_line(None)
@@ -181,13 +232,16 @@ class GitClient(object):
             if len(parts) < 3 or parts[2] != "continue":
                 break
             pkt = self.proto.read_pkt_line()
+        # TODO(durin42): this is broken if the server didn't support the
+        # side-band-64k capability.
         for pkt in self.proto.read_pkt_seq():
             channel = ord(pkt[0])
             pkt = pkt[1:]
             if channel == 1:
                 pack_data(pkt)
             elif channel == 2:
-                progress(pkt)
+                if progress is not None:
+                    progress(pkt)
             else:
                 raise AssertionError("Invalid sideband channel %d" % channel)
         return refs
@@ -216,9 +270,9 @@ class TCPGitClient(GitClient):
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
         """Fetch a pack from the remote host.
-        
+
         :param path: Path of the reposiutory on the remote host
-        :param determine_wants: Callback that receives available refs dict and 
+        :param determine_wants: Callback that receives available refs dict and
             should return list of sha's to fetch.
         :param graph_walker: GraphWalker instance used to find missing shas
         :param pack_data: Callback for writing pack data
@@ -254,18 +308,18 @@ class SubprocessGitClient(GitClient):
 
         :param path: Path to the git repository on the server
         :param changed_refs: Dictionary with new values for the refs
-        :param generate_pack_contents: Function that returns an iterator over 
+        :param generate_pack_contents: Function that returns an iterator over
             objects to send
         """
         client = self._connect("git-receive-pack", path)
         return client.send_pack(path, changed_refs, generate_pack_contents)
 
-    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, 
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
         progress):
         """Retrieve a pack from the server
 
         :param path: Path to the git repository on the server
-        :param determine_wants: Function that receives existing refs 
+        :param determine_wants: Function that receives existing refs
             on the server and returns a list of desired shas
         :param graph_walker: GraphWalker instance
         :param pack_data: Function that can write pack data
@@ -281,12 +335,8 @@ class SSHSubprocess(object):
 
     def __init__(self, proc):
         self.proc = proc
-
-    def send(self, data):
-        return os.write(self.proc.stdin.fileno(), data)
-
-    def recv(self, count):
-        return self.proc.stdout.read(count)
+        self.read = self.recv = proc.stdout.read
+        self.write = self.send = proc.stdin.write
 
     def close(self):
         self.proc.stdin.close()
@@ -323,7 +373,9 @@ class SSHGitClient(GitClient):
         self._kwargs = kwargs
 
     def send_pack(self, path, determine_wants, generate_pack_contents):
-        remote = get_ssh_vendor().connect_ssh(self.host, ["git-receive-pack '%s'" % path], port=self.port, username=self.username)
+        remote = get_ssh_vendor().connect_ssh(
+            self.host, ["git-receive-pack '%s'" % path],
+            port=self.port, username=self.username)
         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
         return client.send_pack(path, determine_wants, generate_pack_contents)
 
@@ -334,3 +386,17 @@ class SSHGitClient(GitClient):
         return client.fetch_pack(path, determine_wants, graph_walker, pack_data,
                                  progress)
 
+
+def get_transport_and_path(uri):
+    """Obtain a git client from a URI or path.
+
+    :param uri: URI or path
+    :return: Tuple with client instance and relative path.
+    """
+    from dulwich.client import TCPGitClient, SSHGitClient, SubprocessGitClient
+    for handler, transport in (("git://", TCPGitClient), ("git+ssh://", SSHGitClient)):
+        if uri.startswith(handler):
+            host, path = uri[len(handler):].split("/", 1)
+            return transport(host), "/"+path
+    # if its not git or git+ssh, try a local url..
+    return SubprocessGitClient(), uri
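
With report-status negotiated, send_pack() now raises SendPackError when the pack itself is rejected and UpdateRefsError (carrying a per-ref status dict) when individual ref updates fail. A hedged sketch of pushing one branch; the URL and ref name are illustrative:

from dulwich.client import get_transport_and_path
from dulwich.errors import SendPackError, UpdateRefsError
from dulwich.repo import Repo

r = Repo(".")
client, path = get_transport_and_path("git://example.org/project.git")

def update_refs(old_refs):
    # determine_wants callback: return the full set of refs to set remotely
    new_refs = dict(old_refs)
    new_refs["refs/heads/master"] = r.refs["refs/heads/master"]
    return new_refs

try:
    client.send_pack(path, update_refs, r.object_store.generate_pack_contents)
except UpdateRefsError, e:
    print e.ref_status        # per-ref status strings from the server
except SendPackError, e:
    print "pack rejected:", e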

+ 56 - 21
dulwich/errors.py

@@ -1,17 +1,17 @@
 # errors.py -- errors for dulwich
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -19,71 +19,83 @@
 
 """Dulwich-related exception classes and utility functions."""
 
+import binascii
+
+
 class ChecksumMismatch(Exception):
     """A checksum didn't match the expected contents."""
 
     def __init__(self, expected, got, extra=None):
+        if len(expected) == 20:
+            expected = binascii.hexlify(expected)
+        if len(got) == 20:
+            got = binascii.hexlify(got)
         self.expected = expected
         self.got = got
         self.extra = extra
         if self.extra is None:
-            Exception.__init__(self, 
+            Exception.__init__(self,
                 "Checksum mismatch: Expected %s, got %s" % (expected, got))
         else:
             Exception.__init__(self,
-                "Checksum mismatch: Expected %s, got %s; %s" % 
+                "Checksum mismatch: Expected %s, got %s; %s" %
                 (expected, got, extra))
 
 
 class WrongObjectException(Exception):
     """Baseclass for all the _ is not a _ exceptions on objects.
-  
+
     Do not instantiate directly.
-  
-    Subclasses should define a _type attribute that indicates what
+
+    Subclasses should define a type_name attribute that indicates what
     was expected if they were raised.
     """
-  
+
     def __init__(self, sha, *args, **kwargs):
-        string = "%s is not a %s" % (sha, self._type)
-        Exception.__init__(self, string)
+        Exception.__init__(self, "%s is not a %s" % (sha, self.type_name))
 
 
 class NotCommitError(WrongObjectException):
     """Indicates that the sha requested does not point to a commit."""
-  
-    _type = 'commit'
+
+    type_name = 'commit'
 
 
 class NotTreeError(WrongObjectException):
     """Indicates that the sha requested does not point to a tree."""
-  
-    _type = 'tree'
+
+    type_name = 'tree'
+
+
+class NotTagError(WrongObjectException):
+    """Indicates that the sha requested does not point to a tag."""
+
+    type_name = 'tag'
 
 
 class NotBlobError(WrongObjectException):
     """Indicates that the sha requested does not point to a blob."""
-  
-    _type = 'blob'
+
+    type_name = 'blob'
 
 
 class MissingCommitError(Exception):
     """Indicates that a commit was not found in the repository"""
-  
+
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the revision store" % sha)
 
 
 class ObjectMissing(Exception):
     """Indicates that a requested object is missing."""
-  
+
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the pack" % sha)
 
 
 class ApplyDeltaError(Exception):
     """Indicates that applying a delta failed."""
-    
+
     def __init__(self, *args, **kwargs):
         Exception.__init__(self, *args, **kwargs)
 
@@ -97,11 +109,26 @@ class NotGitRepository(Exception):
 
 class GitProtocolError(Exception):
     """Git protocol exception."""
-    
+
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args, **kwargs)
+
+
+class SendPackError(GitProtocolError):
+    """An error occurred during send_pack."""
+
     def __init__(self, *args, **kwargs):
         Exception.__init__(self, *args, **kwargs)
 
 
+class UpdateRefsError(GitProtocolError):
+    """The server reported errors updating refs."""
+
+    def __init__(self, *args, **kwargs):
+        self.ref_status = kwargs.pop('ref_status')
+        Exception.__init__(self, *args, **kwargs)
+
+
 class HangupException(GitProtocolError):
     """Hangup exception."""
 
@@ -118,5 +145,13 @@ class PackedRefsException(FileFormatException):
     """Indicates an error parsing a packed-refs file."""
 
 
+class ObjectFormatException(FileFormatException):
+    """Indicates an error parsing an object."""
+
+
 class NoIndexPresent(Exception):
     """No index is present."""
+
+
+class CommitError(Exception):
+    """An error occurred while performing a commit."""

+ 82 - 0
dulwich/fastexport.py

@@ -0,0 +1,82 @@
+# __init__.py -- Fast export/import functionality
+# Copyright (C) 2010 Jelmer Vernooij <jelmer@samba.org>
+# 
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of 
+# the License.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+"""Fast export/import functionality."""
+
+from dulwich.objects import (
+    Tree,
+    format_timezone,
+    )
+
+import stat
+
+class FastExporter(object):
+    """Generate a fast-export output stream for Git objects."""
+
+    def __init__(self, outf, store):
+        self.outf = outf
+        self.store = store
+        self.markers = {}
+        self._marker_idx = 0
+
+    def _allocate_marker(self):
+        self._marker_idx+=1
+        return self._marker_idx
+
+    def _dump_blob(self, blob, marker):
+        self.outf.write("blob\nmark :%s\n" % marker)
+        self.outf.write("data %s\n" % blob.raw_length())
+        for chunk in blob.as_raw_chunks():
+            self.outf.write(chunk)
+        self.outf.write("\n")
+
+    def export_blob(self, blob):
+        i = self._allocate_marker()
+        self.markers[i] = blob.id
+        self._dump_blob(blob, i)
+        return i
+
+    def _dump_commit(self, commit, marker, ref, file_changes):
+        self.outf.write("commit %s\n" % ref)
+        self.outf.write("mark :%s\n" % marker)
+        self.outf.write("author %s %s %s\n" % (commit.author,
+            commit.author_time, format_timezone(commit.author_timezone)))
+        self.outf.write("committer %s %s %s\n" % (commit.committer,
+            commit.commit_time, format_timezone(commit.commit_timezone)))
+        self.outf.write("data %s\n" % len(commit.message))
+        self.outf.write(commit.message)
+        self.outf.write("\n")
+        self.outf.write('\n'.join(file_changes))
+        self.outf.write("\n\n")
+
+    def export_commit(self, commit, ref, base_tree=None):
+        file_changes = []
+        for (old_path, new_path), (old_mode, new_mode), (old_hexsha, new_hexsha) in \
+                self.store.tree_changes(base_tree, commit.tree):
+            if new_path is None:
+                file_changes.append("D %s" % old_path)
+                continue
+            if not stat.S_ISDIR(new_mode):
+                marker = self.export_blob(self.store[new_hexsha])
+            file_changes.append("M %o :%s %s" % (new_mode, marker, new_path))
+
+        i = self._allocate_marker()
+        self._dump_commit(commit, i, ref, file_changes)
+        return i
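
A short sketch of driving the new FastExporter: export the commit at HEAD of some repository as a fast-export stream on stdout. The repository path and ref name are made up.

import sys

from dulwich.fastexport import FastExporter
from dulwich.repo import Repo

r = Repo("/path/to/repo")                  # hypothetical repository
exporter = FastExporter(sys.stdout, r.object_store)
head_commit = r[r.head()]                  # commit object at HEAD
exporter.export_commit(head_commit, "refs/heads/master", base_tree=None)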

+ 41 - 2
dulwich/file.py

@@ -22,6 +22,7 @@
 
 import errno
 import os
+import tempfile
 
 def ensure_dir_exists(dirname):
     """Ensure a directory exists, creating if necessary."""
@@ -31,6 +32,36 @@ def ensure_dir_exists(dirname):
         if e.errno != errno.EEXIST:
             raise
 
+def fancy_rename(oldname, newname):
+    """Rename file with temporary backup file to rollback if rename fails"""
+    if not os.path.exists(newname):
+        try:
+            os.rename(oldname, newname)
+        except OSError, e:
+            raise
+        return
+
+    # destination file exists
+    try:
+        (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname+".", dir=".")
+        os.close(fd)
+        os.remove(tmpfile)
+    except OSError, e:
+        # either file could not be created (e.g. permission problem)
+        # or could not be deleted (e.g. rude virus scanner)
+        raise
+    try:
+        os.rename(newname, tmpfile)
+    except OSError, e:
+        raise   # no rename occurred
+    try:
+        os.rename(oldname, newname)
+    except OSError, e:
+        os.rename(tmpfile, newname)
+        raise
+    os.remove(tmpfile)
+
+
 def GitFile(filename, mode='r', bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
 
@@ -89,7 +120,8 @@ class _GitFile(object):
     def __init__(self, filename, mode, bufsize):
         self._filename = filename
         self._lockfilename = '%s.lock' % self._filename
-        fd = os.open(self._lockfilename, os.O_RDWR | os.O_CREAT | os.O_EXCL)
+        fd = os.open(self._lockfilename,
+            os.O_RDWR | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0))
         self._file = os.fdopen(fd, mode, bufsize)
         self._closed = False
 
@@ -111,6 +143,7 @@ class _GitFile(object):
             # The file may have been removed already, which is ok.
             if e.errno != errno.ENOENT:
                 raise
+            self._closed = True
 
     def close(self):
         """Close this file, saving the lockfile over the original.
@@ -127,7 +160,13 @@ class _GitFile(object):
             return
         self._file.close()
         try:
-            os.rename(self._lockfilename, self._filename)
+            try:
+                os.rename(self._lockfilename, self._filename)
+            except OSError, e:
+                # Windows versions prior to Vista don't support atomic renames
+                if e.errno != errno.EEXIST:
+                    raise
+                fancy_rename(self._lockfilename, self._filename)
         finally:
             self.abort()
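
GitFile in write mode implements git's lock-file protocol: writes go to "<name>.lock", close() renames the lock over the original (falling back to fancy_rename() on pre-Vista Windows), and abort() discards it. A small sketch, with an illustrative file name:

from dulwich.file import GitFile

f = GitFile("description", "wb")     # creates and locks "description.lock"
try:
    f.write("an example repository\n")
except:
    f.abort()                        # drop the lock, keep the old file
    raise
f.close()                            # rename the lock file over "description"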
 

+ 16 - 0
dulwich/index.py

@@ -204,6 +204,8 @@ class Index(object):
 
     def read(self):
         """Read current contents of index from disk."""
+        if not os.path.exists(self._filename):
+            return
         f = GitFile(self._filename, 'rb')
         try:
             f = SHA1Reader(f)
@@ -254,6 +256,10 @@ class Index(object):
         # Remove the old entry if any
         self._byname[name] = x
 
+    def __delitem__(self, name):
+        assert isinstance(name, str)
+        del self._byname[name]
+
     def iteritems(self):
         return self._byname.iteritems()
 
@@ -283,6 +289,14 @@ class Index(object):
         for name in mine:
             yield ((None, name), (None, self.get_mode(name)), (None, self.get_sha1(name)))
 
+    def commit(self, object_store):
+        """Create a new tree from an index.
+
+        :param object_store: Object store to save the tree in
+        :return: Root tree SHA
+        """
+        return commit_tree(object_store, self.iterblobs())
+
 
 def commit_tree(object_store, blobs):
     """Commit a new tree.
@@ -327,5 +341,7 @@ def commit_index(object_store, index):
 
     :param object_store: Object store to save the tree in
     :param index: Index file
+    :note: This function is deprecated, use index.commit() instead.
+    :return: Root tree sha.
     """
     return commit_tree(object_store, index.iterblobs())
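
The new Index.commit() replaces the now-deprecated commit_index() helper. A sketch of turning the current index of a non-bare repository into a tree object; the path is illustrative and open_index() is assumed to be available on Repo as in this release:

from dulwich.repo import Repo

r = Repo("/path/to/worktree")        # hypothetical non-bare repository
index = r.open_index()               # a dulwich.index.Index instance
tree_id = index.commit(r.object_store)
print "root tree:", tree_id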

+ 13 - 3
dulwich/misc.py

@@ -15,15 +15,26 @@
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
-"""Misc utilities to work with python2.4.
+"""Misc utilities to work with python <2.6.
 
 These utilities can all be deleted when dulwich decides it wants to stop
-support for python 2.4.
+support for python <2.6.
 """
 try:
     import hashlib
 except ImportError:
     import sha
+
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from cgi import parse_qs
+
+try:
+    from os import SEEK_END
+except ImportError:
+    SEEK_END = 2
+
 import struct
 
 
@@ -87,4 +98,3 @@ def unpack_from(fmt, buf, offset=0):
     except AttributeError:
         b = buf[offset:offset+struct.calcsize(fmt)]
         return struct.unpack(fmt, b)
-

+ 128 - 67
dulwich/object_store.py

@@ -23,6 +23,7 @@
 import errno
 import itertools
 import os
+import posixpath
 import stat
 import tempfile
 import urllib2
@@ -38,11 +39,13 @@ from dulwich.objects import (
     Tree,
     hex_to_sha,
     sha_to_hex,
+    hex_to_filename,
     S_ISGITLINK,
     )
 from dulwich.pack import (
     Pack,
     PackData,
+    ThinPackData,
     iter_sha1,
     load_pack_index,
     write_pack,
@@ -57,7 +60,8 @@ class BaseObjectStore(object):
     """Object store interface."""
 
     def determine_wants_all(self, refs):
-	    return [sha for (ref, sha) in refs.iteritems() if not sha in self and not ref.endswith("^{}")]
+        return [sha for (ref, sha) in refs.iteritems()
+                if not sha in self and not ref.endswith("^{}")]
 
     def iter_shas(self, shas):
         """Iterate over the objects for the specified shas.
@@ -91,14 +95,14 @@ class BaseObjectStore(object):
         """Obtain the raw text for an object.
 
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         raise NotImplementedError(self.get_raw)
 
     def __getitem__(self, sha):
         """Obtain an object by SHA1."""
-        type, uncomp = self.get_raw(sha)
-        return ShaFile.from_raw_string(type, uncomp)
+        type_num, uncomp = self.get_raw(sha)
+        return ShaFile.from_raw_string(type_num, uncomp)
 
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
@@ -137,10 +141,7 @@ class BaseObjectStore(object):
             else:
                 ttree = {}
             for name, oldmode, oldhexsha in stree.iteritems():
-                if path == "":
-                    oldchildpath = name
-                else:
-                    oldchildpath = "%s/%s" % (path, name)
+                oldchildpath = posixpath.join(path, name)
                 try:
                     (newmode, newhexsha) = ttree[name]
                     newchildpath = oldchildpath
@@ -148,7 +149,7 @@ class BaseObjectStore(object):
                     newmode = None
                     newhexsha = None
                     newchildpath = None
-                if (want_unchanged or oldmode != newmode or 
+                if (want_unchanged or oldmode != newmode or
                     oldhexsha != newhexsha):
                     if stat.S_ISDIR(oldmode):
                         if newmode is None or stat.S_ISDIR(newmode):
@@ -166,10 +167,7 @@ class BaseObjectStore(object):
                             yield ((oldchildpath, newchildpath), (oldmode, newmode), (oldhexsha, newhexsha))
 
             for name, newmode, newhexsha in ttree.iteritems():
-                if path == "":
-                    childpath = name
-                else:
-                    childpath = "%s/%s" % (path, name)
+                childpath = posixpath.join(path, name)
                 if not name in stree:
                     if not stat.S_ISDIR(newmode):
                         yield ((None, childpath), (None, newmode), (None, newhexsha))
@@ -185,26 +183,27 @@ class BaseObjectStore(object):
         while todo:
             (tid, tpath) = todo.pop()
             tree = self[tid]
-            for name, mode, hexsha in tree.iteritems(): 
-                if tpath == "":
-                    path = name
-                else:
-                    path = "%s/%s" % (tpath, name)
+            for name, mode, hexsha in tree.iteritems():
+                path = posixpath.join(tpath, name)
                 if stat.S_ISDIR(mode):
                     todo.add((hexsha, path))
                 else:
                     yield path, mode, hexsha
 
-    def find_missing_objects(self, haves, wants, progress=None):
+    def find_missing_objects(self, haves, wants, progress=None,
+                             get_tagged=None):
         """Find the missing objects required for a set of revisions.
 
         :param haves: Iterable over SHAs already in common.
         :param wants: Iterable over SHAs of objects to fetch.
-        :param progress: Simple progress function that will be called with 
+        :param progress: Simple progress function that will be called with
             updated progress strings.
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         :return: Iterator over (sha, path) pairs.
         """
-        return iter(MissingObjectFinder(self, haves, wants, progress).next, None)
+        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
+        return iter(finder.next, None)
 
     def find_common_revisions(self, graphwalker):
         """Find which revisions this store has in common using graphwalker.
@@ -223,19 +222,20 @@ class BaseObjectStore(object):
 
     def get_graph_walker(self, heads):
         """Obtain a graph walker for this object store.
-        
+
         :param heads: Local heads to start search with
         :return: GraphWalker object
         """
         return ObjectStoreGraphWalker(heads, lambda sha: self[sha].parents)
 
-    def generate_pack_contents(self, have, want):
+    def generate_pack_contents(self, have, want, progress=None):
         """Iterate over the contents of a pack file.
 
         :param have: List of SHA1s of objects that should not be sent
         :param want: List of SHA1s of objects that should be sent
+        :param progress: Optional progress reporting method
         """
-        return self.iter_shas(self.find_missing_objects(have, want))
+        return self.iter_shas(self.find_missing_objects(have, want, progress))
 
 
 class PackBasedObjectStore(BaseObjectStore):
@@ -253,6 +253,10 @@ class PackBasedObjectStore(BaseObjectStore):
     def _load_packs(self):
         raise NotImplementedError(self._load_packs)
 
+    def _pack_cache_stale(self):
+        """Check whether the pack cache is stale."""
+        raise NotImplementedError(self._pack_cache_stale)
+
     def _add_known_pack(self, pack):
         """Add a newly appeared pack to the cache by path.
 
@@ -263,7 +267,7 @@ class PackBasedObjectStore(BaseObjectStore):
     @property
     def packs(self):
         """List with pack objects."""
-        if self._pack_cache is None:
+        if self._pack_cache is None or self._pack_cache_stale():
             self._pack_cache = self._load_packs()
         return self._pack_cache
 
@@ -284,9 +288,9 @@ class PackBasedObjectStore(BaseObjectStore):
 
     def get_raw(self, name):
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         if len(name) == 40:
             sha = hex_to_sha(name)
@@ -301,24 +305,25 @@ class PackBasedObjectStore(BaseObjectStore):
                 return pack.get_raw(sha)
             except KeyError:
                 pass
-        if hexsha is None: 
+        if hexsha is None:
             hexsha = sha_to_hex(name)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
-            return ret.type, ret.as_raw_string()
+            return ret.type_num, ret.as_raw_string()
         raise KeyError(hexsha)
 
     def add_objects(self, objects):
         """Add a set of objects to this object store.
 
         :param objects: Iterable over objects, should support __len__.
+        :return: Pack object of the objects written.
         """
         if len(objects) == 0:
             # Don't bother writing an empty pack file
             return
         f, commit = self.add_pack()
         write_pack_data(f, objects, len(objects))
-        commit()
+        return commit()
 
 
 class DiskObjectStore(PackBasedObjectStore):
@@ -332,11 +337,14 @@ class DiskObjectStore(PackBasedObjectStore):
         super(DiskObjectStore, self).__init__()
         self.path = path
         self.pack_dir = os.path.join(self.path, PACKDIR)
+        self._pack_cache_time = 0
 
     def _load_packs(self):
         pack_files = []
         try:
-            for name in os.listdir(self.pack_dir):
+            self._pack_cache_time = os.stat(self.pack_dir).st_mtime
+            pack_dir_contents = os.listdir(self.pack_dir)
+            for name in pack_dir_contents:
                 # TODO: verify that idx exists first
                 if name.startswith("pack-") and name.endswith(".pack"):
                     filename = os.path.join(self.pack_dir, name)
@@ -349,11 +357,17 @@ class DiskObjectStore(PackBasedObjectStore):
         suffix_len = len(".pack")
         return [Pack(f[:-suffix_len]) for _, f in pack_files]
 
+    def _pack_cache_stale(self):
+        try:
+            return os.stat(self.pack_dir).st_mtime > self._pack_cache_time
+        except OSError, e:
+            if e.errno == errno.ENOENT:
+                return True
+            raise
+
     def _get_shafile_path(self, sha):
-        dir = sha[:2]
-        file = sha[2:]
         # Check from object dir
-        return os.path.join(self.path, dir, file)
+        return hex_to_filename(self.path, sha)
 
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
@@ -365,8 +379,8 @@ class DiskObjectStore(PackBasedObjectStore):
     def _get_loose_object(self, sha):
         path = self._get_shafile_path(sha)
         try:
-            return ShaFile.from_file(path)
-        except OSError, e:
+            return ShaFile.from_path(path)
+        except (OSError, IOError), e:
             if e.errno == errno.ENOENT:
                 return None
             raise
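
The staleness check above compares the pack directory's mtime against the time recorded when the cache was last filled, so renaming a new pack into the directory invalidates the cache on the next access. A standalone sketch of the same pattern (the class name, attributes and path handling here are illustrative, not dulwich API):

import errno
import os

class MtimeCachedListing(object):
    """Cache a directory listing, refreshing it when the directory mtime changes."""

    def __init__(self, path):
        self.path = path
        self._entries = None
        self._cache_time = 0

    def _stale(self):
        try:
            return os.stat(self.path).st_mtime > self._cache_time
        except OSError as e:
            if e.errno == errno.ENOENT:
                return True  # directory vanished; treat the cache as stale
            raise

    def entries(self):
        if self._entries is None or self._stale():
            self._cache_time = os.stat(self.path).st_mtime
            self._entries = os.listdir(self.path)
        return self._entries
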
@@ -374,51 +388,54 @@ class DiskObjectStore(PackBasedObjectStore):
     def move_in_thin_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
 
         :param path: Path to the pack file.
         """
-        data = PackData(path)
+        data = ThinPackData(self.get_raw, path)
 
         # Write index for the thin pack (do we really need this?)
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".tempidx")
-        data.create_index_v2(temppath, self.get_raw)
+        data.create_index_v2(temppath)
         p = Pack.from_objects(data, load_pack_index(temppath))
 
         # Write a full pack version
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".temppack")
-        write_pack(temppath, ((o, None) for o in p.iterobjects(self.get_raw)), 
-                len(p))
+        write_pack(temppath, ((o, None) for o in p.iterobjects()), len(p))
         pack_sha = load_pack_index(temppath+".idx").objects_sha1()
         newbasename = os.path.join(self.pack_dir, "pack-%s" % pack_sha)
         os.rename(temppath+".pack", newbasename+".pack")
         os.rename(temppath+".idx", newbasename+".idx")
-        self._add_known_pack(Pack(newbasename))
+        final_pack = Pack(newbasename)
+        self._add_known_pack(final_pack)
+        return final_pack
 
     def move_in_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
 
         :param path: Path to the pack file.
         """
         p = PackData(path)
         entries = p.sorted_entries()
-        basename = os.path.join(self.pack_dir, 
+        basename = os.path.join(self.pack_dir,
             "pack-%s" % iter_sha1(entry[0] for entry in entries))
         write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum())
         p.close()
         os.rename(path, basename + ".pack")
-        self._add_known_pack(Pack(basename))
+        final_pack = Pack(basename)
+        self._add_known_pack(final_pack)
+        return final_pack
 
     def add_thin_pack(self):
         """Add a new thin pack to this object store.
 
-        Thin packs are packs that contain deltas with parents that exist 
+        Thin packs are packs that contain deltas with parents that exist
         in a different pack.
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -427,13 +444,15 @@ class DiskObjectStore(PackBasedObjectStore):
             os.fsync(fd)
             f.close()
             if os.path.getsize(path) > 0:
-                self.move_in_thin_pack(path)
+                return self.move_in_thin_pack(path)
+            else:
+                return None
         return f, commit
 
     def add_pack(self):
-        """Add a new pack to this object store. 
+        """Add a new pack to this object store.
 
-        :return: Fileobject to write to and a commit function to 
+        :return: Fileobject to write to and a commit function to
             call when the pack is finished.
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -442,7 +461,9 @@ class DiskObjectStore(PackBasedObjectStore):
             os.fsync(fd)
             f.close()
             if os.path.getsize(path) > 0:
-                self.move_in_pack(path)
+                return self.move_in_pack(path)
+            else:
+                return None
         return f, commit
 
     def add_object(self, obj):
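
With commit() now returning the finished Pack (or None when nothing was written), callers can hold on to the result; the calling convention is the one add_objects() already uses above. A sketch, assuming a store offering add_pack() and an objects iterable of whatever form write_pack_data() expects:

from dulwich.pack import write_pack_data

def write_objects_as_pack(store, objects):
    # Mirrors PackBasedObjectStore.add_objects(): write into a temporary
    # pack file, then commit() moves it into place and returns the Pack.
    if len(objects) == 0:
        return None  # an empty pack is never written
    f, commit = store.add_pack()
    write_pack_data(f, objects, len(objects))
    return commit()
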
@@ -451,8 +472,11 @@ class DiskObjectStore(PackBasedObjectStore):
         :param obj: Object to add
         """
         dir = os.path.join(self.path, obj.id[:2])
-        if not os.path.isdir(dir):
+        try:
             os.mkdir(dir)
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
         path = os.path.join(dir, obj.id[2:])
         if os.path.exists(path):
             return # Already there, no need to write again
@@ -462,6 +486,17 @@ class DiskObjectStore(PackBasedObjectStore):
         finally:
             f.close()
 
+    @classmethod
+    def init(cls, path):
+        try:
+            os.mkdir(path)
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
+        os.mkdir(os.path.join(path, "info"))
+        os.mkdir(os.path.join(path, PACKDIR))
+        return cls(path)
+
 
 class MemoryObjectStore(BaseObjectStore):
     """Object store that keeps all objects in memory."""
@@ -489,9 +524,9 @@ class MemoryObjectStore(BaseObjectStore):
 
     def get_raw(self, name):
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         return self[name].as_raw_string()
 
@@ -573,7 +608,7 @@ class ObjectStoreIterator(ObjectIterator):
     def __contains__(self, needle):
         """Check if an object is present.
 
-        :note: This checks if the object is present in 
+        :note: This checks if the object is present in
             the underlying object store, not if it would
             be yielded by the iterator.
 
@@ -583,7 +618,7 @@ class ObjectStoreIterator(ObjectIterator):
 
     def __getitem__(self, key):
         """Find an object by SHA1.
-        
+
         :note: This retrieves the object from the underlying
             object store. It will also succeed if the object would
             not be returned by the iterator.
@@ -604,9 +639,10 @@ def tree_lookup_path(lookup_obj, root_sha, path):
     """
     parts = path.split("/")
     sha = root_sha
+    mode = None
     for p in parts:
         obj = lookup_obj(sha)
-        if type(obj) is not Tree:
+        if not isinstance(obj, Tree):
             raise NotTreeError(sha)
         if p == '':
             continue
@@ -617,14 +653,18 @@ def tree_lookup_path(lookup_obj, root_sha, path):
 class MissingObjectFinder(object):
     """Find the objects missing from another object store.
 
-    :param object_store: Object store containing at least all objects to be 
+    :param object_store: Object store containing at least all objects to be
         sent
     :param haves: SHA1s of commits not to send (already present in target)
     :param wants: SHA1s of commits to send
     :param progress: Optional function to report progress to.
+    :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+        sha for including tags.
     """
 
-    def __init__(self, object_store, haves, wants, progress=None):
+    def __init__(self, object_store, haves, wants, progress=None,
+                 get_tagged=None):
         self.sha_done = set(haves)
         self.objects_to_send = set([(w, None, False) for w in wants if w not in haves])
         self.object_store = object_store
@@ -632,6 +672,7 @@ class MissingObjectFinder(object):
             self.progress = lambda x: None
         else:
             self.progress = progress
+        self._tagged = get_tagged and get_tagged() or {}
 
     def add_todo(self, entries):
         self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
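
The get_tagged callable is expected to return a dict mapping each peeled (pointed-to) sha to the sha of the tag object pointing at it, as UploadPackHandler.get_tagged() in dulwich/server.py below does. A sketch of building such a callable from a repository, assuming the get_refs()/get_peeled() API added in dulwich/repo.py in this release:

def get_tagged_from_repo(repo):
    # Map each peeled (pointed-to) sha to the sha of the annotated tag object
    # that points at it. Lightweight tags peel to themselves and are skipped.
    tagged = {}
    for name, sha in repo.get_refs().iteritems():
        peeled = repo.get_peeled(name)
        if peeled != sha:
            tagged[peeled] = sha
    return tagged

# finder = MissingObjectFinder(store, haves, wants,
#                              get_tagged=lambda: get_tagged_from_repo(repo))
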
@@ -658,13 +699,19 @@ class MissingObjectFinder(object):
                 self.parse_tree(o)
             elif isinstance(o, Tag):
                 self.parse_tag(o)
+        if sha in self._tagged:
+            self.add_todo([(self._tagged[sha], None, True)])
         self.sha_done.add(sha)
         self.progress("counting objects: %d\r" % len(self.sha_done))
         return (sha, name)
 
 
 class ObjectStoreGraphWalker(object):
-    """Graph walker that finds out what commits are missing from an object store."""
+    """Graph walker that finds what commits are missing from an object store.
+
+    :ivar heads: Revisions without descendants in the local repo
+    :ivar get_parents: Function to retrieve parents in the local repo
+    """
 
     def __init__(self, local_heads, get_parents):
         """Create a new instance.
@@ -677,12 +724,26 @@ class ObjectStoreGraphWalker(object):
         self.parents = {}
 
     def ack(self, sha):
-        """Ack that a particular revision and its ancestors are present in the source."""
-        if sha in self.heads:
-            self.heads.remove(sha)
-        if sha in self.parents:
-            for p in self.parents[sha]:
-                self.ack(p)
+        """Ack that a revision and its ancestors are present in the source."""
+        ancestors = set([sha])
+
+        # stop if we run out of heads to remove
+        while self.heads:
+            for a in ancestors:
+                if a in self.heads:
+                    self.heads.remove(a)
+
+            # collect all ancestors
+            new_ancestors = set()
+            for a in ancestors:
+                if a in self.parents:
+                    new_ancestors.update(self.parents[a])
+
+            # no more ancestors; stop
+            if not new_ancestors:
+                break
+
+            ancestors = new_ancestors
 
     def next(self):
         """Iterate over ancestors of heads in the target."""

File diff suppressed because it is too large
+ 662 - 250
dulwich/objects.py


File diff suppressed because it is too large
+ 409 - 269
dulwich/pack.py


+ 67 - 5
dulwich/patch.py

@@ -22,10 +22,14 @@ These patches are basically unified diffs with some extra metadata tacked
 on.
 """
 
-import difflib
+from difflib import SequenceMatcher
+import rfc822
 import subprocess
 import time
 
+from dulwich.objects import (
+    Commit,
+    )
 
 def write_commit_patch(f, commit, contents, progress, version=None):
     """Write a individual file patch.
@@ -68,6 +72,36 @@ def get_summary(commit):
     return commit.message.splitlines()[0].replace(" ", "-")
 
 
+def unified_diff(a, b, fromfile='', tofile='', n=3):
+    """difflib.unified_diff that doesn't write any dates or trailing spaces.
+
+    Based on the same function in Python2.6.5-rc2's difflib.py
+    """
+    started = False
+    for group in SequenceMatcher(None, a, b).get_grouped_opcodes(n):
+        if not started:
+            yield '--- %s\n' % fromfile
+            yield '+++ %s\n' % tofile
+            started = True
+        i1, i2, j1, j2 = group[0][1], group[-1][2], group[0][3], group[-1][4]
+        yield "@@ -%d,%d +%d,%d @@\n" % (i1+1, i2-i1, j1+1, j2-j1)
+        for tag, i1, i2, j1, j2 in group:
+            if tag == 'equal':
+                for line in a[i1:i2]:
+                    yield ' ' + line
+                continue
+            if tag == 'replace' or tag == 'delete':
+                for line in a[i1:i2]:
+                    if not line[-1] == '\n':
+                        line += '\n\\ No newline at end of file\n'
+                    yield '-' + line
+            if tag == 'replace' or tag == 'insert':
+                for line in b[j1:j2]:
+                    if not line[-1] == '\n':
+                        line += '\n\\ No newline at end of file\n'
+                    yield '+' + line
+
+
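
Compared with difflib.unified_diff, the replacement above emits no timestamps in the ---/+++ headers and marks input lines that lack a trailing newline. A small illustration; the file names are placeholders:

from dulwich.patch import unified_diff

old = ["same\n", "old line"]   # second line has no trailing newline
new = ["same\n", "new line"]
text = "".join(unified_diff(old, new, fromfile="a/file", tofile="b/file"))
# text contains "--- a/file", "+++ b/file", a single "@@ -1,2 +1,2 @@" hunk,
# the unchanged context line, and -/+ lines each followed by the
# "\ No newline at end of file" marker.
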
 def write_blob_diff(f, (old_path, old_mode, old_blob), 
                        (new_path, new_mode, new_blob)):
     """Write diff file header.
@@ -98,13 +132,41 @@ def write_blob_diff(f, (old_path, old_mode, old_blob),
     if old_mode != new_mode:
         if new_mode is not None:
             if old_mode is not None:
-                f.write("old file mode %o\n" % old_mode)
-            f.write("new file mode %o\n" % new_mode) 
+                f.write("old mode %o\n" % old_mode)
+            f.write("new mode %o\n" % new_mode) 
         else:
-            f.write("deleted file mode %o\n" % old_mode)
+            f.write("deleted mode %o\n" % old_mode)
     f.write("index %s..%s %o\n" % (
         blob_id(old_blob), blob_id(new_blob), new_mode))
     old_contents = lines(old_blob)
     new_contents = lines(new_blob)
-    f.writelines(difflib.unified_diff(old_contents, new_contents, 
+    f.writelines(unified_diff(old_contents, new_contents, 
         old_path, new_path))
+
+
+def git_am_patch_split(f):
+    """Parse a git-am-style patch and split it up into bits.
+
+    :param f: File-like object to parse
+    :return: Tuple with commit object, diff contents and git version
+    """
+    msg = rfc822.Message(f)
+    c = Commit()
+    c.author = msg["from"]
+    c.committer = msg["from"]
+    if msg["subject"].startswith("[PATCH"):
+        subject = msg["subject"].split("]", 1)[1][1:]
+    else:
+        subject = msg["subject"]
+    c.message = subject
+    for l in f:
+        if l == "---\n":
+            break
+        c.message += l
+    diff = ""
+    for l in f:
+        if l == "-- \n":
+            break
+        diff += l
+    version = f.next().rstrip("\n")
+    return c, diff, version
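
A usage sketch for git_am_patch_split(); the message below is a minimal, made-up example of the git-format-patch layout it expects (RFC 2822 headers, commit message, "---" separator, diff, "-- " trailer and git version):

from cStringIO import StringIO
from dulwich.patch import git_am_patch_split

f = StringIO(
    "From: Jane Doe <jane@example.com>\n"
    "Subject: [PATCH 1/1] Fix a thing\n"
    "\n"
    "Longer description of the change.\n"
    "---\n"
    " file | 1 +\n"
    "-- \n"
    "1.7.0.4\n")
commit, diff, version = git_am_patch_split(f)
# commit.author == commit.committer == "Jane Doe <jane@example.com>"
# commit.message starts with "Fix a thing" (the "[PATCH 1/1]" prefix stripped)
# diff holds the text between "---" and "-- "; version == "1.7.0.4"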

+ 116 - 2
dulwich/protocol.py

@@ -19,19 +19,27 @@
 
 """Generic functions for talking the git smart server protocol."""
 
+from cStringIO import StringIO
+import os
 import socket
 
 from dulwich.errors import (
     HangupException,
     GitProtocolError,
     )
+from dulwich.misc import (
+    SEEK_END,
+    )
 
 TCP_GIT_PORT = 9418
 
+ZERO_SHA = "0" * 40
+
 SINGLE_ACK = 0
 MULTI_ACK = 1
 MULTI_ACK_DETAILED = 2
 
+
 class ProtocolFile(object):
     """
     Some network ops are like file ops. The file ops expect to operate on
@@ -160,6 +168,112 @@ class Protocol(object):
         return cmd, args[:-1].split(chr(0))
 
 
+_RBUFSIZE = 8192  # Default read buffer size.
+
+
+class ReceivableProtocol(Protocol):
+    """Variant of Protocol that allows reading up to a size without blocking.
+
+    This class has a recv() method that behaves like socket.recv() in addition
+    to a read() method.
+
+    If you want to read n bytes from the wire and block until exactly n bytes
+    (or EOF) are read, use read(n). If you want to read at most n bytes from the
+    wire but don't care if you get less, use recv(n). Note that recv(n) will
+    still block until at least one byte is read.
+    """
+
+    def __init__(self, recv, write, report_activity=None, rbufsize=_RBUFSIZE):
+        super(ReceivableProtocol, self).__init__(self.read, write,
+                                                 report_activity)
+        self._recv = recv
+        self._rbuf = StringIO()
+        self._rbufsize = rbufsize
+
+    def read(self, size):
+        # From _fileobj.read in socket.py in the Python 2.6.5 standard library,
+        # with the following modifications:
+        #  - omit the size <= 0 branch
+        #  - seek back to start rather than 0 in case some buffer has been
+        #    consumed.
+        #  - use SEEK_END instead of the magic number.
+        # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+        # Licensed under the Python Software Foundation License.
+        # TODO: see if buffer is more efficient than cStringIO.
+        assert size > 0
+
+        # Our use of StringIO rather than lists of string objects returned by
+        # recv() minimizes memory usage and fragmentation that occurs when
+        # rbufsize is large compared to the typical return value of recv().
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, SEEK_END)
+        # buffer may have been partially consumed by recv()
+        buf_len = buf.tell() - start
+        if buf_len >= size:
+            # Already have size bytes in our buffer?  Extract and return.
+            buf.seek(start)
+            rv = buf.read(size)
+            self._rbuf = StringIO()
+            self._rbuf.write(buf.read())
+            self._rbuf.seek(0)
+            return rv
+
+        self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+        while True:
+            left = size - buf_len
+            # recv() will malloc the amount of memory given as its
+            # parameter even though it often returns much less data
+            # than that.  The returned data string is short lived
+            # as we copy it into a StringIO and free it.  This avoids
+            # fragmentation issues on many platforms.
+            data = self._recv(left)
+            if not data:
+                break
+            n = len(data)
+            if n == size and not buf_len:
+                # Shortcut.  Avoid buffer data copies when:
+                # - We have no data in our buffer.
+                # AND
+                # - Our call to recv returned exactly the
+                #   number of bytes we were asked to read.
+                return data
+            if n == left:
+                buf.write(data)
+                del data  # explicit free
+                break
+            assert n <= left, "_recv(%d) returned %d bytes" % (left, n)
+            buf.write(data)
+            buf_len += n
+            del data  # explicit free
+            #assert buf_len == buf.tell()
+        buf.seek(start)
+        return buf.read()
+
+    def recv(self, size):
+        assert size > 0
+
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, SEEK_END)
+        buf_len = buf.tell()
+        buf.seek(start)
+
+        left = buf_len - start
+        if not left:
+            # only read from the wire if our read buffer is exhausted
+            data = self._recv(self._rbufsize)
+            if len(data) == size:
+                # shortcut: skip the buffer if we read exactly size bytes
+                return data
+            buf = StringIO()
+            buf.write(data)
+            buf.seek(0)
+            del data  # explicit free
+            self._rbuf = buf
+        return buf.read(size)
+
+
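
A sketch of the read()/recv() contract described in the docstring above, wired to a socket pair purely for illustration (Unix-only; the eight-byte payload is arbitrary):

import socket
from dulwich.protocol import ReceivableProtocol

a, b = socket.socketpair()
proto = ReceivableProtocol(a.recv, a.sendall)
b.sendall("0008abcd")
header = proto.read(4)    # blocks until exactly 4 bytes (or EOF): "0008"
rest = proto.recv(4096)   # returns whatever is available, up to 4096 bytes;
                          # in this toy setup that is the remaining "abcd"
b.close()
a.close()
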
 def extract_capabilities(text):
     """Extract a capabilities list from a string, if present.
 
@@ -169,7 +283,7 @@ def extract_capabilities(text):
     if not "\0" in text:
         return text, []
     text, capabilities = text.rstrip().split("\0")
-    return (text, capabilities.split(" "))
+    return (text, capabilities.strip().split(" "))
 
 
 def extract_want_line_capabilities(text):
@@ -192,7 +306,7 @@ def extract_want_line_capabilities(text):
 def ack_type(capabilities):
     """Extract the ack type from a capabilities list."""
     if 'multi_ack_detailed' in capabilities:
-      return MULTI_ACK_DETAILED
+        return MULTI_ACK_DETAILED
     elif 'multi_ack' in capabilities:
         return MULTI_ACK
     return SINGLE_ACK

+ 371 - 79
dulwich/repo.py

@@ -1,18 +1,18 @@
 # repo.py -- For dealing wih git repositories.
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License or (at your option) any later version of 
+# of the License or (at your option) any later version of
 # the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -26,13 +26,15 @@ import errno
 import os
 
 from dulwich.errors import (
-    MissingCommitError, 
+    MissingCommitError,
     NoIndexPresent,
-    NotBlobError, 
-    NotCommitError, 
+    NotBlobError,
+    NotCommitError,
     NotGitRepository,
-    NotTreeError, 
+    NotTreeError,
+    NotTagError,
     PackedRefsException,
+    CommitError,
     )
 from dulwich.file import (
     ensure_dir_exists,
@@ -48,7 +50,10 @@ from dulwich.objects import (
     Tag,
     Tree,
     hex_to_sha,
+    object_class,
     )
+import warnings
+
 
 OBJECTDIR = 'objects'
 SYMREF = 'ref: '
@@ -58,9 +63,6 @@ REFSDIR_HEADS = 'heads'
 INDEX_FILENAME = "index"
 
 BASE_DIRECTORIES = [
-    [OBJECTDIR], 
-    [OBJECTDIR, "info"], 
-    [OBJECTDIR, "pack"],
     ["branches"],
     [REFSDIR],
     [REFSDIR, REFSDIR_TAGS],
@@ -73,7 +75,7 @@ BASE_DIRECTORIES = [
 def read_info_refs(f):
     ret = {}
     for l in f.readlines():
-        (sha, name) = l.rstrip("\n").split("\t", 1)
+        (sha, name) = l.rstrip("\r\n").split("\t", 1)
         ret[name] = sha
     return ret
 
@@ -114,12 +116,18 @@ class RefsContainer(object):
     """A container for refs."""
 
     def set_ref(self, name, other):
+        warnings.warn("RefsContainer.set_ref() is deprecated."
+            "Use set_symblic_ref instead.",
+            category=DeprecationWarning, stacklevel=2)
+        return self.set_symbolic_ref(name, other)
+
+    def set_symbolic_ref(self, name, other):
         """Make a ref point at another ref.
 
         :param name: Name of the ref to set
         :param other: Name of the ref to point at
         """
-        self[name] = SYMREF + other + '\n'
+        raise NotImplementedError(self.set_symbolic_ref)
 
     def get_packed_refs(self):
         """Get contents of the packed-refs file.
@@ -131,14 +139,28 @@ class RefsContainer(object):
         """
         raise NotImplementedError(self.get_packed_refs)
 
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
+
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to
+            point to a tag, this will be the SHA the ref refers to. If the ref
+            may point to a tag, but no cached information is available, None
+            is returned.
+        """
+        return None
+
     def import_refs(self, base, other):
         for name, value in other.iteritems():
             self["%s/%s" % (base, name)] = value
 
+    def allkeys(self):
+        """All refs present in this container."""
+        raise NotImplementedError(self.allkeys)
+
     def keys(self, base=None):
         """Refs present in this container.
 
-        :param base: An optional base to return refs under
+        :param base: An optional base to return refs under.
         :return: An unsorted set of valid refs in this container, including
             packed refs.
         """
@@ -148,10 +170,17 @@ class RefsContainer(object):
             return self.allkeys()
 
     def subkeys(self, base):
+        """Refs present in this container under a base.
+
+        :param base: The base to return refs under.
+        :return: A set of valid refs in this container under the base; the base
+            prefix is stripped from the ref names returned.
+        """
         keys = set()
+        base_len = len(base) + 1
         for refname in self.allkeys():
             if refname.startswith(base):
-                keys.add(refname)
+                keys.add(refname[base_len:])
         return keys
 
     def as_dict(self, base=None):
@@ -186,11 +215,23 @@ class RefsContainer(object):
         if not name.startswith('refs/') or not check_ref_format(name[5:]):
             raise KeyError(name)
 
+    def read_ref(self, refname):
+        """Read a reference without following any references.
+
+        :param refname: The name of the reference
+        :return: The contents of the ref file, or None if it does
+            not exist.
+        """
+        contents = self.read_loose_ref(refname)
+        if not contents:
+            contents = self.get_packed_refs().get(refname, None)
+        return contents
+
     def read_loose_ref(self, name):
         """Read a loose reference and return its contents.
 
         :param name: the refname to read
-        :return: The contents of the ref file, or None if it does 
+        :return: The contents of the ref file, or None if it does
             not exist.
         """
         raise NotImplementedError(self.read_loose_ref)
@@ -206,16 +247,19 @@ class RefsContainer(object):
         depth = 0
         while contents.startswith(SYMREF):
             refname = contents[len(SYMREF):]
-            contents = self.read_loose_ref(refname)
+            contents = self.read_ref(refname)
             if not contents:
-                contents = self.get_packed_refs().get(refname, None)
-                if not contents:
-                    break
+                break
             depth += 1
             if depth > 5:
                 raise KeyError(name)
         return refname, contents
 
+    def __contains__(self, refname):
+        if self.read_ref(refname):
+            return True
+        return False
+
     def __getitem__(self, name):
         """Get the SHA1 for a reference name.
 
@@ -226,8 +270,74 @@ class RefsContainer(object):
             raise KeyError(name)
         return sha
 
+    def set_if_equals(self, name, old_ref, new_ref):
+        """Set a refname to new_ref only if it currently equals old_ref.
+
+        This method follows all symbolic references if applicable for the
+        subclass, and can be used to perform an atomic compare-and-swap
+        operation.
+
+        :param name: The refname to set.
+        :param old_ref: The old sha the refname must refer to, or None to set
+            unconditionally.
+        :param new_ref: The new sha the refname will refer to.
+        :return: True if the set was successful, False otherwise.
+        """
+        raise NotImplementedError(self.set_if_equals)
+
+    def add_if_new(self, name, ref):
+        """Add a new reference only if it does not already exist."""
+        raise NotImplementedError(self.add_if_new)
+
+    def __setitem__(self, name, ref):
+        """Set a reference name to point to the given SHA1.
+
+        This method follows all symbolic references if applicable for the
+        subclass.
+
+        :note: This method unconditionally overwrites the contents of a
+            reference. To update atomically only if the reference has not
+            changed, use set_if_equals().
+        :param name: The refname to set.
+        :param ref: The new sha the refname will refer to.
+        """
+        self.set_if_equals(name, None, ref)
+
+    def remove_if_equals(self, name, old_ref):
+        """Remove a refname only if it currently equals old_ref.
+
+        This method does not follow symbolic references, even if applicable for
+        the subclass. It can be used to perform an atomic compare-and-delete
+        operation.
+
+        :param name: The refname to delete.
+        :param old_ref: The old sha the refname must refer to, or None to delete
+            unconditionally.
+        :return: True if the delete was successful, False otherwise.
+        """
+        raise NotImplementedError(self.remove_if_equals)
+
+    def __delitem__(self, name):
+        """Remove a refname.
+
+        This method does not follow symbolic references, even if applicable for
+        the subclass.
+
+        :note: This method unconditionally deletes the contents of a reference.
+            To delete atomically only if the reference has not changed, use
+            remove_if_equals().
+
+        :param name: The refname to delete.
+        """
+        self.remove_if_equals(name, None)
+
 
 class DictRefsContainer(RefsContainer):
+    """RefsContainer backed by a simple dict.
+
+    This container does not support symbolic or packed references and is not
+    threadsafe.
+    """
 
     def __init__(self, refs):
         self._refs = refs
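
set_if_equals() is the building block for race-free ref updates. A sketch, assuming an existing repository at a hypothetical path and a placeholder sha:

from dulwich.repo import Repo

repo = Repo("/tmp/example-repo")   # hypothetical path
old_sha = repo.refs["refs/heads/master"]
new_sha = "1" * 40                 # placeholder for a real object sha
if not repo.refs.set_if_equals("refs/heads/master", old_sha, new_sha):
    # Somebody else moved the branch since old_sha was read; retry or report.
    pass
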
@@ -236,7 +346,32 @@ class DictRefsContainer(RefsContainer):
         return self._refs.keys()
 
     def read_loose_ref(self, name):
-        return self._refs[name]
+        return self._refs.get(name, None)
+
+    def get_packed_refs(self):
+        return {}
+
+    def set_symbolic_ref(self, name, other):
+        self._refs[name] = SYMREF + other
+
+    def set_if_equals(self, name, old_ref, new_ref):
+        if old_ref is not None and self._refs.get(name, None) != old_ref:
+            return False
+        realname, _ = self._follow(name)
+        self._refs[realname] = new_ref
+        return True
+
+    def add_if_new(self, name, ref):
+        if name in self._refs:
+            return False
+        self._refs[name] = ref
+        return True
+
+    def remove_if_equals(self, name, old_ref):
+        if old_ref is not None and self._refs.get(name, None) != old_ref:
+            return False
+        del self._refs[name]
+        return True
 
 
 class DiskRefsContainer(RefsContainer):
@@ -245,7 +380,7 @@ class DiskRefsContainer(RefsContainer):
     def __init__(self, path):
         self.path = path
         self._packed_refs = None
-        self._peeled_refs = {}
+        self._peeled_refs = None
 
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.path)
@@ -298,7 +433,10 @@ class DiskRefsContainer(RefsContainer):
         """
         # TODO: invalidate the cache on repacking
         if self._packed_refs is None:
+            # set both to empty because we want _peeled_refs to be
+            # None if and only if _packed_refs is also None.
             self._packed_refs = {}
+            self._peeled_refs = {}
             path = os.path.join(self.path, 'packed-refs')
             try:
                 f = GitFile(path, 'rb')
@@ -322,6 +460,24 @@ class DiskRefsContainer(RefsContainer):
                 f.close()
         return self._packed_refs
 
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
+
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to
+            point to a tag, this will be the SHA the ref refers to. If the ref
+            may point to a tag, but no cached information is available, None
+            is returned.
+        """
+        self.get_packed_refs()
+        if self._peeled_refs is None or name not in self._packed_refs:
+            # No cache: no peeled refs were read, or this ref is loose
+            return None
+        if name in self._peeled_refs:
+            return self._peeled_refs[name]
+        else:
+            # Known not peelable
+            return self[name]
+
     def read_loose_ref(self, name):
         """Read a reference file and return its contents.
 
@@ -340,7 +496,7 @@ class DiskRefsContainer(RefsContainer):
                 header = f.read(len(SYMREF))
                 if header == SYMREF:
                     # Read only the first line
-                    return header + iter(f).next().rstrip("\n")
+                    return header + iter(f).next().rstrip("\r\n")
                 else:
                     # Read only the first 40 bytes
                     return header + f.read(40-len(SYMREF))
@@ -372,6 +528,25 @@ class DiskRefsContainer(RefsContainer):
         finally:
             f.abort()
 
+    def set_symbolic_ref(self, name, other):
+        """Make a ref point at another ref.
+
+        :param name: Name of the ref to set
+        :param other: Name of the ref to point at
+        """
+        self._check_refname(name)
+        self._check_refname(other)
+        filename = self.refpath(name)
+        try:
+            f = GitFile(filename, 'wb')
+            try:
+                f.write(SYMREF + other + '\n')
+            except (IOError, OSError):
+                f.abort()
+                raise
+        finally:
+            f.close()
+
     def set_if_equals(self, name, old_ref, new_ref):
         """Set a refname to new_ref only if it currently equals old_ref.
 
@@ -414,9 +589,23 @@ class DiskRefsContainer(RefsContainer):
         return True
 
     def add_if_new(self, name, ref):
-        """Add a new reference only if it does not already exist."""
-        self._check_refname(name)
-        filename = self.refpath(name)
+        """Add a new reference only if it does not already exist.
+
+        This method follows symrefs, and only ensures that the last ref in the
+        chain does not exist.
+
+        :param name: The refname to set.
+        :param ref: The new sha the refname will refer to.
+        :return: True if the add was successful, False otherwise.
+        """
+        try:
+            realname, contents = self._follow(name)
+            if contents is not None:
+                return False
+        except KeyError:
+            realname = name
+        self._check_refname(realname)
+        filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
         f = GitFile(filename, 'wb')
         try:
@@ -432,17 +621,6 @@ class DiskRefsContainer(RefsContainer):
             f.close()
         return True
 
-    def __setitem__(self, name, ref):
-        """Set a reference name to point to the given SHA1.
-
-        This method follows all symbolic references.
-
-        :note: This method unconditionally overwrites the contents of a reference
-            on disk. To update atomically only if the reference has not changed
-            on disk, use set_if_equals().
-        """
-        self.set_if_equals(name, None, ref)
-
     def remove_if_equals(self, name, old_ref):
         """Remove a refname only if it currently equals old_ref.
 
@@ -477,16 +655,6 @@ class DiskRefsContainer(RefsContainer):
             f.abort()
         return True
 
-    def __delitem__(self, name):
-        """Remove a refname.
-
-        This method does not follow symbolic references.
-        :note: This method unconditionally deletes the contents of a reference
-            on disk. To delete atomically only if the reference has not changed
-            on disk, use set_if_equals().
-        """
-        self.remove_if_equals(name, None)
-
 
 def _split_ref_line(line):
     """Split a single ref line into a tuple of SHA1 and name."""
@@ -516,7 +684,7 @@ def read_packed_refs(f):
             continue
         if l[0] == "^":
             raise PackedRefsException(
-                "found peeled ref in packed-refs without peeled")
+              "found peeled ref in packed-refs without peeled")
         yield _split_ref_line(l)
 
 
@@ -532,7 +700,7 @@ def read_packed_refs_with_peeled(f):
     for l in f:
         if l[0] == "#":
             continue
-        l = l.rstrip("\n")
+        l = l.rstrip("\r\n")
         if l[0] == "^":
             if not last:
                 raise PackedRefsException("unexpected peeled ref line")
@@ -558,6 +726,7 @@ def write_packed_refs(f, packed_refs, peeled_refs=None):
 
     :param f: empty file-like object to write to
     :param packed_refs: dict of refname to sha of packed refs to write
+    :param peeled_refs: dict of refname to peeled value of sha
     """
     if peeled_refs is None:
         peeled_refs = {}
@@ -595,7 +764,7 @@ class BaseRepo(object):
 
     def open_index(self):
         """Open the index for this repository.
-        
+
         :raises NoIndexPresent: If no index is present
         :return: Index instance
         """
@@ -605,33 +774,39 @@ class BaseRepo(object):
         """Fetch objects into another repository.
 
         :param target: The target repository
-        :param determine_wants: Optional function to determine what refs to 
+        :param determine_wants: Optional function to determine what refs to
             fetch.
         :param progress: Optional progress function
         """
         if determine_wants is None:
             determine_wants = lambda heads: heads.values()
         target.object_store.add_objects(
-            self.fetch_objects(determine_wants, target.get_graph_walker(),
-                progress))
+          self.fetch_objects(determine_wants, target.get_graph_walker(),
+                             progress))
         return self.get_refs()
 
-    def fetch_objects(self, determine_wants, graph_walker, progress):
+    def fetch_objects(self, determine_wants, graph_walker, progress,
+                      get_tagged=None):
         """Fetch the missing objects required for a set of revisions.
 
-        :param determine_wants: Function that takes a dictionary with heads 
+        :param determine_wants: Function that takes a dictionary with heads
             and returns the list of heads to fetch.
-        :param graph_walker: Object that can iterate over the list of revisions 
-            to fetch and has an "ack" method that will be called to acknowledge 
+        :param graph_walker: Object that can iterate over the list of revisions
+            to fetch and has an "ack" method that will be called to acknowledge
             that a revision is present.
-        :param progress: Simple progress function that will be called with 
+        :param progress: Simple progress function that will be called with
             updated progress strings.
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         :return: iterator over objects, with __len__ implemented
         """
         wants = determine_wants(self.get_refs())
+        if not wants:
+            return []
         haves = self.object_store.find_common_revisions(graph_walker)
         return self.object_store.iter_shas(
-            self.object_store.find_missing_objects(haves, wants, progress))
+          self.object_store.find_missing_objects(haves, wants, progress,
+                                                 get_tagged))
 
     def get_graph_walker(self, heads=None):
         if heads is None:
@@ -653,15 +828,18 @@ class BaseRepo(object):
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
         ret = self.get_object(sha)
-        if ret._type != cls._type:
+        if not isinstance(ret, cls):
             if cls is Commit:
                 raise NotCommitError(ret)
             elif cls is Blob:
                 raise NotBlobError(ret)
             elif cls is Tree:
                 raise NotTreeError(ret)
+            elif cls is Tag:
+                raise NotTagError(ret)
             else:
-                raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
+                raise Exception("Type invalid: %r != %r" % (
+                  ret.type_name, cls.type_name))
         return ret
 
     def get_object(self, sha):
@@ -678,17 +856,71 @@ class BaseRepo(object):
                     for section in p.sections())
 
     def commit(self, sha):
+        """Retrieve the commit with a particular SHA.
+
+        :param sha: SHA of the commit to retrieve
+        :raise NotCommitError: If the SHA provided doesn't point at a Commit
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Commit` object
+        """
+        warnings.warn("Repo.commit(sha) is deprecated. Use Repo[sha] instead.",
+            category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Commit)
 
     def tree(self, sha):
+        """Retrieve the tree with a particular SHA.
+
+        :param sha: SHA of the tree to retrieve
+        :raise NotTreeError: If the SHA provided doesn't point at a Tree
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Tree` object
+        """
+        warnings.warn("Repo.tree(sha) is deprecated. Use Repo[sha] instead.",
+            category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Tree)
 
     def tag(self, sha):
+        """Retrieve the tag with a particular SHA.
+
+        :param sha: SHA of the tag to retrieve
+        :raise NotTagError: If the SHA provided doesn't point at a Tag
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Tag` object
+        """
+        warnings.warn("Repo.tag(sha) is deprecated. Use Repo[sha] instead.",
+            category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Tag)
 
     def get_blob(self, sha):
+        """Retrieve the blob with a particular SHA.
+
+        :param sha: SHA of the blob to retrieve
+        :raise NotBlobError: If the SHA provided doesn't point at a Blob
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Blob` object
+        """
+        warnings.warn("Repo.get_blob(sha) is deprecated. Use Repo[sha] "
+            "instead.", category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Blob)
 
+    def get_peeled(self, ref):
+        """Get the peeled value of a ref.
+
+        :param ref: the refname to peel
+        :return: the fully-peeled SHA1 of a tag object, after peeling all
+            intermediate tags; if the original ref does not point to a tag, this
+            will equal the original SHA1.
+        """
+        cached = self.refs.get_peeled(ref)
+        if cached is not None:
+            return cached
+        obj = self[ref]
+        obj_class = object_class(obj.type_name)
+        while obj_class is Tag:
+            obj_class, sha = obj.object
+            obj = self.get_object(sha)
+        return obj.id
+
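
The distinction matters mostly for annotated tags. A sketch with a hypothetical repository and tag ref:

from dulwich.repo import Repo

repo = Repo("/tmp/example-repo")             # hypothetical path
tag_sha = repo.refs["refs/tags/v1.0"]        # sha the ref itself points at
peeled = repo.get_peeled("refs/tags/v1.0")   # sha after following tag objects
# For a lightweight tag the two are equal; for an annotated tag, peeled is the
# commit (or other object) at the end of the tag chain, and the cached value
# from packed-refs is used when available.
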
     def revision_history(self, head):
         """Returns a list of the commits reachable from head.
 
@@ -707,9 +939,11 @@ class BaseRepo(object):
         while pending_commits != []:
             head = pending_commits.pop(0)
             try:
-                commit = self.commit(head)
+                commit = self[head]
             except KeyError:
                 raise MissingCommitError(head)
+            if type(commit) != Commit:
+                raise NotCommitError(commit)
             if commit in history:
                 continue
             i = 0
@@ -718,16 +952,24 @@ class BaseRepo(object):
                     break
                 i += 1
             history.insert(i, commit)
-            parents = commit.parents
-            pending_commits += parents
+            pending_commits += commit.parents
         history.reverse()
         return history
 
     def __getitem__(self, name):
         if len(name) in (20, 40):
-            return self.object_store[name]
+            try:
+                return self.object_store[name]
+            except KeyError:
+                pass
         return self.object_store[self.refs[name]]
 
+    def __contains__(self, name):
+        if len(name) in (20, 40):
+            return name in self.object_store or name in self.refs
+        else:
+            return name in self.refs
+
     def __setitem__(self, name, value):
         if name.startswith("refs/") or name == "HEAD":
             if isinstance(value, ShaFile):
@@ -736,43 +978,47 @@ class BaseRepo(object):
                 self.refs[name] = value
             else:
                 raise TypeError(value)
-        raise ValueError(name)
+        else:
+            raise ValueError(name)
 
     def __delitem__(self, name):
         if name.startswith("refs") or name == "HEAD":
             del self.refs[name]
         raise ValueError(name)
 
-    def do_commit(self, committer, message,
+    def do_commit(self, message, committer=None,
                   author=None, commit_timestamp=None,
-                  commit_timezone=None, author_timestamp=None, 
+                  commit_timezone=None, author_timestamp=None,
                   author_timezone=None, tree=None):
         """Create a new commit.
 
-        :param committer: Committer fullname
         :param message: Commit message
+        :param committer: Committer fullname
         :param author: Author fullname (defaults to committer)
         :param commit_timestamp: Commit timestamp (defaults to now)
         :param commit_timezone: Commit timestamp timezone (defaults to GMT)
         :param author_timestamp: Author timestamp (defaults to commit timestamp)
-        :param author_timezone: Author timestamp timezone 
+        :param author_timezone: Author timestamp timezone
             (defaults to commit timestamp timezone)
         :param tree: SHA1 of the tree root to use (if not specified the current index will be committed).
         :return: New commit SHA1
         """
-        from dulwich.index import commit_index
         import time
         index = self.open_index()
         c = Commit()
         if tree is None:
-            c.tree = commit_index(self.object_store, index)
+            c.tree = index.commit(self.object_store)
         else:
             c.tree = tree
+        # TODO: Allow username to be missing, and get it from .git/config
+        if committer is None:
+            raise ValueError("committer not set")
         c.committer = committer
         if commit_timestamp is None:
             commit_timestamp = time.time()
         c.commit_time = int(commit_timestamp)
         if commit_timezone is None:
+            # FIXME: Use current user timezone rather than UTC
             commit_timezone = 0
         c.commit_timezone = commit_timezone
         if author is None:
@@ -785,8 +1031,20 @@ class BaseRepo(object):
             author_timezone = commit_timezone
         c.author_timezone = author_timezone
         c.message = message
-        self.object_store.add_object(c)
-        self.refs["HEAD"] = c.id
+        try:
+            old_head = self.refs["HEAD"]
+            c.parents = [old_head]
+            self.object_store.add_object(c)
+            ok = self.refs.set_if_equals("HEAD", old_head, c.id)
+        except KeyError:
+            c.parents = []
+            self.object_store.add_object(c)
+            ok = self.refs.add_if_new("HEAD", c.id)
+        if not ok:
+            # Fail if the atomic compare-and-swap failed, leaving the commit and
+            # all its objects as garbage.
+            raise CommitError("HEAD changed during commit")
+
         return c.id
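
Note the changed signature: the message now comes first and the committer is passed as a keyword. A usage sketch with a made-up path, identity and message:

from dulwich.repo import Repo

repo = Repo("/tmp/example-repo")   # hypothetical non-bare repository
commit_id = repo.do_commit(
    "Add initial files",
    committer="Jane Doe <jane@example.com>")
# The commit's parent is the previous HEAD (if any); HEAD itself is updated
# via set_if_equals()/add_if_new(), and CommitError is raised if it moved
# in the meantime.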
 
 
@@ -804,8 +1062,8 @@ class Repo(BaseRepo):
         else:
             raise NotGitRepository(root)
         self.path = root
-        object_store = DiskObjectStore(
-            os.path.join(self.controldir(), OBJECTDIR))
+        object_store = DiskObjectStore(os.path.join(self.controldir(),
+                                                    OBJECTDIR))
         refs = DiskRefsContainer(self.controldir())
         BaseRepo.__init__(self, object_store, refs)
 
@@ -852,7 +1110,40 @@ class Repo(BaseRepo):
 
     def has_index(self):
         """Check if an index is present."""
-        return os.path.exists(self.index_path())
+        # Bare repos must never have index files; non-bare repos may have a
+        # missing index file, which is treated as empty.
+        return not self.bare
+
+    def stage(self, paths):
+        """Stage a set of paths.
+
+        :param paths: List of paths, relative to the repository path
+        """
+        from dulwich.index import cleanup_mode
+        index = self.open_index()
+        for path in paths:
+            full_path = os.path.join(self.path, path)
+            blob = Blob()
+            try:
+                st = os.stat(full_path)
+            except OSError:
+                # File no longer exists
+                try:
+                    del index[path]
+                except KeyError:
+                    pass  # Doesn't exist in the index either
+            else:
+                f = open(full_path, 'rb')
+                try:
+                    blob.data = f.read()
+                finally:
+                    f.close()
+                self.object_store.add_object(blob)
+                # XXX: Cleanup some of the other file properties as well?
+                index[path] = (st.st_ctime, st.st_mtime, st.st_dev, st.st_ino,
+                    cleanup_mode(st.st_mode), st.st_uid, st.st_gid, st.st_size,
+                    blob.id, 0)
+        index.write()
 
     def __repr__(self):
         return "<Repo at %r>" % self.path
@@ -868,8 +1159,9 @@ class Repo(BaseRepo):
     def init_bare(cls, path, mkdir=True):
         for d in BASE_DIRECTORIES:
             os.mkdir(os.path.join(path, *d))
+        DiskObjectStore.init(os.path.join(path, OBJECTDIR))
         ret = cls(path)
-        ret.refs.set_ref("HEAD", "refs/heads/master")
+        ret.refs.set_symbolic_ref("HEAD", "refs/heads/master")
         ret._put_named_file('description', "Unnamed repository")
         ret._put_named_file('config', """[core]
     repositoryformatversion = 0
@@ -877,7 +1169,7 @@ class Repo(BaseRepo):
     bare = false
     logallrefupdates = true
 """)
-        ret._put_named_file(os.path.join('info', 'excludes'), '')
+        ret._put_named_file(os.path.join('info', 'exclude'), '')
         return ret
 
     create = init_bare

+ 251 - 143
dulwich/server.py

@@ -27,36 +27,55 @@ Documentation/technical directory in the cgit distribution, and in particular:
 
 
 import collections
+import socket
+import zlib
 import SocketServer
-import tempfile
 
 from dulwich.errors import (
     ApplyDeltaError,
     ChecksumMismatch,
     GitProtocolError,
+    ObjectFormatException,
     )
 from dulwich.objects import (
     hex_to_sha,
     )
+from dulwich.pack import (
+    PackStreamReader,
+    write_pack_data,
+    )
 from dulwich.protocol import (
-    Protocol,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
     ProtocolFile,
+    ReceivableProtocol,
+    SINGLE_ACK,
     TCP_GIT_PORT,
+    ZERO_SHA,
+    ack_type,
     extract_capabilities,
     extract_want_line_capabilities,
-    SINGLE_ACK,
-    MULTI_ACK,
-    MULTI_ACK_DETAILED,
-    ack_type,
-    )
-from dulwich.repo import (
-    Repo,
-    )
-from dulwich.pack import (
-    write_pack_data,
     )
 
+
+
 class Backend(object):
+    """A backend for the Git smart server implementation."""
+
+    def open_repository(self, path):
+        """Open the repository at a path."""
+        raise NotImplementedError(self.open_repository)
+
+
+class BackendRepo(object):
+    """Repository abstraction used by the Git server.
+
+    Please note that the methods required here are a
+    subset of those provided by dulwich.repo.Repo.
+    """
+
+    object_store = None
+    refs = None
 
     def get_refs(self):
         """
@@ -66,144 +85,177 @@ class Backend(object):
         """
         raise NotImplementedError
 
-    def apply_pack(self, refs, read):
-        """ Import a set of changes into a repository and update the refs
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
 
-        :param refs: list of tuple(name, sha)
-        :param read: callback to read from the incoming pack
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to
+            point to a tag, this will be the SHA the ref refers to. If no
+            cached information about a tag is available, this method may
+            return None, but it should attempt to peel the tag if possible.
         """
-        raise NotImplementedError
+        return None
 
-    def fetch_objects(self, determine_wants, graph_walker, progress):
+    def fetch_objects(self, determine_wants, graph_walker, progress,
+                      get_tagged=None):
         """
         Yield the objects required for a list of commits.
 
         :param progress: is a callback to send progress messages to the client
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         """
         raise NotImplementedError
 
 
-class GitBackend(Backend):
+class PackStreamCopier(PackStreamReader):
+    """Class to verify a pack stream as it is being read.
 
-    def __init__(self, repo=None):
-        if repo is None:
-            repo = Repo(tmpfile.mkdtemp())
-        self.repo = repo
-        self.object_store = self.repo.object_store
-        self.fetch_objects = self.repo.fetch_objects
-        self.get_refs = self.repo.get_refs
+    The pack is read from a ReceivableProtocol using read() or recv() as
+    appropriate and written out to the given file-like object.
+    """
 
-    def apply_pack(self, refs, read):
-        f, commit = self.repo.object_store.add_thin_pack()
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
-        status = []
-        unpack_error = None
-        # TODO: more informative error messages than just the exception string
-        try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
-        except all_exceptions, e:
-            unpack_error = str(e).replace('\n', '')
-        try:
-            commit()
-        except all_exceptions, e:
-            if not unpack_error:
-                unpack_error = str(e).replace('\n', '')
+    def __init__(self, read_all, read_some, outfile):
+        super(PackStreamCopier, self).__init__(read_all, read_some)
+        self.outfile = outfile
 
-        if unpack_error:
-            status.append(('unpack', unpack_error))
-        else:
-            status.append(('unpack', 'ok'))
+    def _read(self, read, size):
+        data = super(PackStreamCopier, self)._read(read, size)
+        self.outfile.write(data)
+        return data
 
-        for oldsha, sha, ref in refs:
-            # TODO: check refname
-            ref_error = None
-            try:
-                if ref == "0" * 40:
-                    try:
-                        del self.repo.refs[ref]
-                    except all_exceptions:
-                        ref_error = 'failed to delete'
-                else:
-                    try:
-                        self.repo.refs[ref] = sha
-                    except all_exceptions:
-                        ref_error = 'failed to write'
-            except KeyError, e:
-                ref_error = 'bad ref'
-            if ref_error:
-                status.append((ref, ref_error))
-            else:
-                status.append((ref, 'ok'))
+    def verify(self):
+        """Verify a pack stream and write it to the output file.
 
+        See PackStreamReader.iterobjects for a list of exceptions this may
+        throw.
+        """
+        for _, _, _ in self.read_objects():
+            pass
 
-        print "pack applied"
-        return status
+
+class DictBackend(Backend):
+    """Trivial backend that looks up Git repositories in a dictionary."""
+
+    def __init__(self, repos):
+        self.repos = repos
+
+    def open_repository(self, path):
+        # FIXME: What to do in case there is no repo ?
+        return self.repos[path]
 
 
 class Handler(object):
     """Smart protocol command handler base class."""
 
-    def __init__(self, backend, read, write):
+    def __init__(self, backend, proto):
         self.backend = backend
-        self.proto = Protocol(read, write)
+        self.proto = proto
+        self._client_capabilities = None
+
+    def capability_line(self):
+        return " ".join(self.capabilities())
 
     def capabilities(self):
-        return " ".join(self.default_capabilities())
+        raise NotImplementedError(self.capabilities)
+
+    def innocuous_capabilities(self):
+        return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
+
+    def required_capabilities(self):
+        """Return a list of capabilities that we require the client to have."""
+        return []
+
+    def set_client_capabilities(self, caps):
+        allowable_caps = set(self.innocuous_capabilities())
+        allowable_caps.update(self.capabilities())
+        for cap in caps:
+            if cap not in allowable_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        for cap in self.required_capabilities():
+            if cap not in caps:
+                raise GitProtocolError('Client does not support required '
+                                       'capability %s.' % cap)
+        self._client_capabilities = set(caps)
+
+    def has_capability(self, cap):
+        if self._client_capabilities is None:
+            raise GitProtocolError('Server attempted to access capability %s '
+                                   'before asking client' % cap)
+        return cap in self._client_capabilities
 
 
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
-        self._client_capabilities = None
+        Handler.__init__(self, backend, proto)
+        self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
-                "ofs-delta")
+                "ofs-delta", "no-progress", "include-tag")
 
-    def set_client_capabilities(self, caps):
-        my_caps = self.default_capabilities()
-        for cap in caps:
-            if '_ack' in cap and cap not in my_caps:
-                raise GitProtocolError('Client asked for capability %s that '
-                                       'was not advertised.' % cap)
-        self._client_capabilities = caps
+    def required_capabilities(self):
+        return ("side-band-64k", "thin-pack", "ofs-delta")
 
-    def get_client_capabilities(self):
-        return self._client_capabilities
+    def progress(self, message):
+        if self.has_capability("no-progress"):
+            return
+        self.proto.write_sideband(2, message)
 
-    client_capabilities = property(get_client_capabilities,
-                                   set_client_capabilities)
+    def get_tagged(self, refs=None, repo=None):
+        """Get a dict of peeled values of tags to their original tag shas.
 
-    def handle(self):
+        :param refs: dict of refname -> sha of possible tags; defaults to all of
+            the backend's refs.
+        :param repo: optional Repo instance for getting peeled refs; defaults to
+            the backend's repo, if available
+        :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
+            tag whose peeled value is peeled_sha.
+        """
+        if not self.has_capability("include-tag"):
+            return {}
+        if refs is None:
+            refs = self.repo.get_refs()
+        if repo is None:
+            repo = getattr(self.repo, "repo", None)
+            if repo is None:
+                # Bail if we don't have a Repo available; this is ok since
+                # clients must be able to handle if the server doesn't include
+                # all relevant tags.
+                # TODO: fix behavior when missing
+                return {}
+        tagged = {}
+        for name, sha in refs.iteritems():
+            peeled_sha = repo.get_peeled(name)
+            if peeled_sha != sha:
+                tagged[peeled_sha] = sha
+        return tagged
 
-        progress = lambda x: self.proto.write_sideband(2, x)
+    def handle(self):
         write = lambda x: self.proto.write_sideband(1, x)
 
-        graph_walker = ProtocolGraphWalker(self)
-        objects_iter = self.backend.fetch_objects(
-          graph_walker.determine_wants, graph_walker, progress)
+        graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
+            self.repo.get_peeled)
+        objects_iter = self.repo.fetch_objects(
+          graph_walker.determine_wants, graph_walker, self.progress,
+          get_tagged=self.get_tagged)
 
         # Do they want any objects?
         if len(objects_iter) == 0:
             return
 
-        progress("dul-daemon says what\n")
-        progress("counting objects: %d, done.\n" % len(objects_iter))
+        self.progress("dul-daemon says what\n")
+        self.progress("counting objects: %d, done.\n" % len(objects_iter))
         write_pack_data(ProtocolFile(None, write), objects_iter, 
                         len(objects_iter))
-        progress("how was that, then?\n")
+        self.progress("how was that, then?\n")
         # we are done
         self.proto.write("0000")
 
@@ -211,9 +263,9 @@ class UploadPackHandler(Handler):
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
 
-    As a graph walker, this class implements ack(), next(), and reset(). It also
-    contains some base methods for interacting with the wire and walking the
-    commit tree.
+    As a graph walker, this class implements ack(), next(), and reset(). It
+    also contains some base methods for interacting with the wire and walking
+    the commit tree.
 
     The work of determining which acks to send is passed on to the
     implementation instance stored in _impl. The reason for this is that we do
@@ -221,9 +273,10 @@ class ProtocolGraphWalker(object):
     call to set_ack_level() is required to set up the implementation, before any
     calls to next() or ack() are made.
     """
-    def __init__(self, handler):
+    def __init__(self, handler, object_store, get_peeled):
         self.handler = handler
-        self.store = handler.backend.object_store
+        self.store = object_store
+        self.get_peeled = get_peeled
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
@@ -251,9 +304,12 @@ class ProtocolGraphWalker(object):
             for i, (ref, sha) in enumerate(heads.iteritems()):
                 line = "%s %s" % (sha, ref)
                 if not i:
-                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                    line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
-                # TODO: include peeled value of any tags
+                peeled_sha = self.get_peeled(ref)
+                if peeled_sha != sha:
+                    self.proto.write_pkt_line('%s %s^{}\n' %
+                                              (peeled_sha, ref))
 
             # i'm done..
             self.proto.write_pkt_line(None)
@@ -266,7 +322,7 @@ class ProtocolGraphWalker(object):
         if not want:
             return []
         line, caps = extract_want_line_capabilities(want)
-        self.handler.client_capabilities = caps
+        self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         command, sha = self._split_proto_line(line)
 
@@ -274,10 +330,10 @@ class ProtocolGraphWalker(object):
         while command != None:
             if command != 'want':
                 raise GitProtocolError(
-                    'Protocol got unexpected command %s' % command)
+                  'Protocol got unexpected command %s' % command)
             if sha not in values:
                 raise GitProtocolError(
-                    'Client wants invalid object %s' % sha)
+                  'Client wants invalid object %s' % sha)
             want_revs.append(sha)
             command, sha = self.read_proto_line()
 
@@ -359,10 +415,10 @@ class ProtocolGraphWalker(object):
             commit = pending.popleft()
             if commit.id in haves:
                 return True
-            if not getattr(commit, 'get_parents', None):
+            if commit.type_name != "commit":
                 # non-commit wants are assumed to be satisfied
                 continue
-            for parent in commit.get_parents():
+            for parent in commit.parents:
                 parent_obj = self.store[parent]
                 # TODO: handle parents with later commit times than children
                 if parent_obj.commit_time >= earliest:
@@ -385,10 +441,10 @@ class ProtocolGraphWalker(object):
 
     def set_ack_type(self, ack_type):
         impl_classes = {
-            MULTI_ACK: MultiAckGraphWalkerImpl,
-            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
-            SINGLE_ACK: SingleAckGraphWalkerImpl,
-            }
+          MULTI_ACK: MultiAckGraphWalkerImpl,
+          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
+          SINGLE_ACK: SingleAckGraphWalkerImpl,
+          }
         self._impl = impl_classes[ack_type](self)
 
 
@@ -497,32 +553,72 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(Handler):
     """Protocol handler for downloading a pack from the client."""
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
+        self.repo = backend.open_repository(args[0])
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
 
-    def __init__(self, backend, read, write,
-                 stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
-        self._stateless_rpc = stateless_rpc
-        self._advertise_refs = advertise_refs
-
-    def default_capabilities(self):
+    def capabilities(self):
         return ("report-status", "delete-refs")
 
+    def _apply_pack(self, refs):
+        f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
+                          AssertionError, socket.error, zlib.error,
+                          ObjectFormatException)
+        status = []
+        # TODO: more informative error messages than just the exception string
+        try:
+            PackStreamCopier(self.proto.read, self.proto.recv, f).verify()
+            p = commit()
+            if not p:
+                raise IOError('Failed to write pack')
+            p.check()
+            status.append(('unpack', 'ok'))
+        except all_exceptions, e:
+            status.append(('unpack', str(e).replace('\n', '')))
+            # The pack may still have been moved in, but it may contain broken
+            # objects. We trust a later GC to clean it up.
+
+        for oldsha, sha, ref in refs:
+            ref_status = 'ok'
+            try:
+                if sha == ZERO_SHA:
+                    if not 'delete-refs' in self.capabilities():
+                        raise GitProtocolError(
+                          'Attempted to delete refs without delete-refs '
+                          'capability.')
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_status = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_status = 'failed to write'
+            except KeyError, e:
+                ref_status = 'bad ref'
+            status.append((ref, ref_status))
+
+        return status
+
     def handle(self):
-        refs = self.backend.get_refs().items()
+        refs = self.repo.get_refs().items()
 
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                self.proto.write_pkt_line(
+                  "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
+                                     self.capability_line()))
                 for i in range(1, len(refs)):
                     ref = refs[i]
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
             else:
-                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+                self.proto.write_pkt_line("%s capabilities^{} %s" % (
+                  ZERO_SHA, self.capability_line()))
 
             self.proto.write("0000")
             if self.advertise_refs:
@@ -535,7 +631,8 @@ class ReceivePackHandler(Handler):
         if ref is None:
             return
 
-        ref, client_capabilities = extract_capabilities(ref)
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
 
         # client will now send us a list of (oldsha, newsha, ref)
         while ref:
@@ -543,11 +640,11 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
 
         # backend can now deal with this refs and read a pack using self.read
-        status = self.backend.apply_pack(client_refs, self.proto.read)
+        status = self._apply_pack(client_refs)
 
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
-        if 'report-status' in client_capabilities:
+        if self.has_capability('report-status'):
             for name, msg in status:
                 if name == 'unpack':
                     self.proto.write_pkt_line('unpack %s\n' % msg)
@@ -558,21 +655,27 @@ class ReceivePackHandler(Handler):
             self.proto.write_pkt_line(None)
 
 
+# Default handler classes for git services.
+DEFAULT_HANDLERS = {
+  'git-upload-pack': UploadPackHandler,
+  'git-receive-pack': ReceivePackHandler,
+  }
+
+
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 
+    def __init__(self, handlers, *args, **kwargs):
+        self.handlers = handlers or DEFAULT_HANDLERS
+        SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
+
     def handle(self):
-        proto = Protocol(self.rfile.read, self.wfile.write)
+        proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         command, args = proto.read_cmd()
 
-        # switch case to handle the specific git command
-        if command == 'git-upload-pack':
-            cls = UploadPackHandler
-        elif command == 'git-receive-pack':
-            cls = ReceivePackHandler
-        else:
-            return
-
-        h = cls(self.server.backend, self.rfile.read, self.wfile.write)
+        cls = self.handlers.get(command, None)
+        if not callable(cls):
+            raise GitProtocolError('Invalid service %s' % command)
+        h = cls(self.server.backend, args, proto)
         h.handle()
 
 
@@ -581,6 +684,11 @@ class TCPGitServer(SocketServer.TCPServer):
     allow_reuse_address = True
     serve = SocketServer.TCPServer.serve_forever
 
-    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
+    def _make_handler(self, *args, **kwargs):
+        return TCPGitRequestHandler(self.handlers, *args, **kwargs)
+
+    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
         self.backend = backend
-        SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)
+        self.handlers = handlers
+        SocketServer.TCPServer.__init__(self, (listen_addr, port),
+                                        self._make_handler)
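
The handler, backend, and server classes added above compose directly. A minimal sketch of serving one repository over git://, using only names introduced in this hunk (DictBackend, TCPGitServer, and the implicit DEFAULT_HANDLERS); the repository path is a placeholder and 9418 is simply the default git port:

    # Serve a single bare repository over the git protocol. Requests for
    # the path '/' are resolved by DictBackend; the request handlers fall
    # back to DEFAULT_HANDLERS (upload-pack and receive-pack).
    from dulwich.repo import Repo
    from dulwich.server import DictBackend, TCPGitServer

    backend = DictBackend({'/': Repo('/tmp/repo.git')})
    server = TCPGitServer(backend, 'localhost', 9418)
    server.serve()  # alias for SocketServer's serve_forever()

Clients would then talk to git://localhost:9418/; the compat tests below start the same server on an ephemeral port in a background thread instead of blocking.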

+ 30 - 0
dulwich/tests/__init__.py

@@ -18,3 +18,33 @@
 # MA  02110-1301, USA.
 
 """Tests for Dulwich."""
+
+import unittest
+
+# XXX: Ideally we should allow other test runners as well, 
+# but unfortunately unittest doesn't have a SkipTest/TestSkipped
+# exception.
+from nose import SkipTest as TestSkipped
+
+def test_suite():
+    names = [
+        'client',
+        'fastexport',
+        'file',
+        'index',
+        'lru_cache',
+        'objects',
+        'object_store',
+        'pack',
+        'patch',
+        'protocol',
+        'repository',
+        'server',
+        'web',
+        ]
+    module_names = ['dulwich.tests.test_' + name for name in names]
+    result = unittest.TestSuite()
+    loader = unittest.TestLoader()
+    suite = loader.loadTestsFromNames(module_names)
+    result.addTests(suite)
+    return result
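
The test_suite() helper above returns a plain unittest.TestSuite, so any standard runner can drive it. A small sketch (the verbosity setting is just an example):

    # Run the aggregated Dulwich test suite with the stock text runner.
    import unittest

    from dulwich.tests import test_suite

    runner = unittest.TextTestRunner(verbosity=2)
    runner.run(test_suite())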

+ 157 - 0
dulwich/tests/compat/server_utils.py

@@ -0,0 +1,157 @@
+# server_utils.py -- Git server compatibility utilities
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Utilities for testing git server compatibility."""
+
+
+import select
+import socket
+import threading
+
+from dulwich.tests.utils import (
+    tear_down_repo,
+    )
+from utils import (
+    import_repo,
+    run_git,
+    )
+
+
+class ServerTests(object):
+    """Base tests for testing servers.
+
+    Does not inherit from TestCase so tests are not automatically run.
+    """
+
+    def setUp(self):
+        self._old_repo = import_repo('server_old.export')
+        self._new_repo = import_repo('server_new.export')
+        self._server = None
+
+    def tearDown(self):
+        if self._server is not None:
+            self._server.shutdown()
+            self._server = None
+        tear_down_repo(self._old_repo)
+        tear_down_repo(self._new_repo)
+
+    def test_push_to_dulwich(self):
+        self.assertReposNotEqual(self._old_repo, self._new_repo)
+        port = self._start_server(self._old_repo)
+
+        all_branches = ['master', 'branch']
+        branch_args = ['%s:%s' % (b, b) for b in all_branches]
+        url = '%s://localhost:%s/' % (self.protocol, port)
+        returncode, _ = run_git(['push', url] + branch_args,
+                                cwd=self._new_repo.path)
+        self.assertEqual(0, returncode)
+        self.assertReposEqual(self._old_repo, self._new_repo)
+
+    def test_fetch_from_dulwich(self):
+        self.assertReposNotEqual(self._old_repo, self._new_repo)
+        port = self._start_server(self._new_repo)
+
+        all_branches = ['master', 'branch']
+        branch_args = ['%s:%s' % (b, b) for b in all_branches]
+        url = '%s://localhost:%s/' % (self.protocol, port)
+        returncode, _ = run_git(['fetch', url] + branch_args,
+                                cwd=self._old_repo.path)
+        # flush the pack cache so any new packs are picked up
+        self._old_repo.object_store._pack_cache = None
+        self.assertEqual(0, returncode)
+        self.assertReposEqual(self._old_repo, self._new_repo)
+
+
+class ShutdownServerMixIn:
+    """Mixin that allows serve_forever to be shut down.
+
+    The methods in this mixin are backported from SocketServer.py in the Python
+    2.6.4 standard library. The mixin is unnecessary in 2.6 and later, when
+    BaseServer supports the shutdown method directly.
+    """
+
+    def __init__(self):
+        self.__is_shut_down = threading.Event()
+        self.__serving = False
+
+    def serve_forever(self, poll_interval=0.5):
+        """Handle one request at a time until shutdown.
+
+        Polls for shutdown every poll_interval seconds. Ignores
+        self.timeout. If you need to do periodic tasks, do them in
+        another thread.
+        """
+        self.__serving = True
+        self.__is_shut_down.clear()
+        while self.__serving:
+            # XXX: Consider using another file descriptor or
+            # connecting to the socket to wake this up instead of
+            # polling. Polling reduces our responsiveness to a
+            # shutdown request and wastes cpu at all other times.
+            r, w, e = select.select([self], [], [], poll_interval)
+            if r:
+                self._handle_request_noblock()
+        self.__is_shut_down.set()
+
+    serve = serve_forever  # override alias from TCPGitServer
+
+    def shutdown(self):
+        """Stops the serve_forever loop.
+
+        Blocks until the loop has finished. This must be called while
+        serve_forever() is running in another thread, or it will deadlock.
+        """
+        self.__serving = False
+        self.__is_shut_down.wait()
+
+    def handle_request(self):
+        """Handle one request, possibly blocking.
+
+        Respects self.timeout.
+        """
+        # Support people who used socket.settimeout() to escape
+        # handle_request before self.timeout was available.
+        timeout = self.socket.gettimeout()
+        if timeout is None:
+            timeout = self.timeout
+        elif self.timeout is not None:
+            timeout = min(timeout, self.timeout)
+        fd_sets = select.select([self], [], [], timeout)
+        if not fd_sets[0]:
+            self.handle_timeout()
+            return
+        self._handle_request_noblock()
+
+    def _handle_request_noblock(self):
+        """Handle one request, without blocking.
+
+        I assume that select.select has returned that the socket is
+        readable before this function was called, so there should be
+        no risk of blocking in get_request().
+        """
+        try:
+            request, client_address = self.get_request()
+        except socket.error:
+            return
+        if self.verify_request(request, client_address):
+            try:
+                self.process_request(request, client_address)
+            except:
+                self.handle_error(request, client_address)
+                self.close_request(request)

+ 179 - 0
dulwich/tests/compat/test_client.py

@@ -0,0 +1,179 @@
+# test_client.py -- Compatibility tests for git client.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibility tests between the Dulwich client and the cgit server."""
+
+import os
+import shutil
+import signal
+import tempfile
+
+from dulwich import client
+from dulwich import errors
+from dulwich import file
+from dulwich import index
+from dulwich import protocol
+from dulwich import object_store
+from dulwich import objects
+from dulwich import repo
+from dulwich.tests import (
+    TestSkipped,
+    )
+
+from utils import (
+    CompatTestCase,
+    check_for_daemon,
+    import_repo_to_dir,
+    run_git,
+    )
+
+class DulwichClientTest(CompatTestCase):
+    """Tests for client/server compatibility."""
+
+    def setUp(self):
+        if check_for_daemon(limit=1):
+            raise TestSkipped('git-daemon was already running on port %s' %
+                              protocol.TCP_GIT_PORT)
+        CompatTestCase.setUp(self)
+        fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
+                                            suffix=".pid")
+        os.fdopen(fd).close()
+        self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
+        dest = os.path.join(self.gitroot, 'dest')
+        file.ensure_dir_exists(dest)
+        run_git(['init', '--bare'], cwd=dest)
+        run_git(
+            ['daemon', '--verbose', '--export-all',
+             '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
+             '--detach', '--reuseaddr', '--enable=receive-pack',
+             '--listen=localhost', self.gitroot], cwd=self.gitroot)
+        if not check_for_daemon():
+            raise TestSkipped('git-daemon failed to start')
+
+    def tearDown(self):
+        CompatTestCase.tearDown(self)
+        try:
+            os.kill(int(open(self.pidfile).read().strip()), signal.SIGKILL)
+            os.unlink(self.pidfile)
+        except (OSError, IOError):
+            pass
+        shutil.rmtree(self.gitroot)
+
+    def assertDestEqualsSrc(self):
+        src = repo.Repo(os.path.join(self.gitroot, 'server_new.export'))
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def _do_send_pack(self):
+        c = client.TCPGitClient('localhost')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+
+    def test_send_pack(self):
+        self._do_send_pack()
+        self.assertDestEqualsSrc()
+
+    def test_send_pack_nothing_to_send(self):
+        self._do_send_pack()
+        self.assertDestEqualsSrc()
+        # nothing to send, but shouldn't raise either.
+        self._do_send_pack()
+
+    def test_send_without_report_status(self):
+        c = client.TCPGitClient('localhost')
+        c._send_capabilities.remove('report-status')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        self.assertDestEqualsSrc()
+
+    def disable_ff_and_make_dummy_commit(self):
+        # disable non-fast-forward pushes to the server
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        run_git(['config', 'receive.denyNonFastForwards', 'true'], cwd=dest.path)
+        b = objects.Blob.from_string('hi')
+        dest.object_store.add_object(b)
+        t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
+        c = objects.Commit()
+        c.author = c.committer = 'Foo Bar <foo@example.com>'
+        c.author_time = c.commit_time = 0
+        c.author_timezone = c.commit_timezone = 0
+        c.message = 'hi'
+        c.tree = t
+        dest.object_store.add_object(c)
+        return dest, c.id
+
+    def compute_send(self):
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        return sendrefs, src.object_store.generate_pack_contents
+
+    def test_send_pack_one_error(self):
+        dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
+        dest.refs['refs/heads/master'] = dummy_commit
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/master failed to update', str(e))
+            self.assertEqual({'refs/heads/branch': 'ok',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)
+
+    def test_send_pack_multiple_errors(self):
+        dest, dummy = self.disable_ff_and_make_dummy_commit()
+        # set up for two non-ff errors
+        dest.refs['refs/heads/branch'] = dest.refs['refs/heads/master'] = dummy
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/branch, refs/heads/master failed to '
+                             'update', str(e))
+            self.assertEqual({'refs/heads/branch': 'non-fast-forward',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)
+
+    def test_fetch_pack(self):
+        c = client.TCPGitClient('localhost')
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        refs = c.fetch('/server_new.export', dest)
+        map(lambda r: dest.refs.set_if_equals(r[0], None, r[1]), refs.items())
+        self.assertDestEqualsSrc()
+
+    def test_incremental_fetch_pack(self):
+        self.test_fetch_pack()
+        dest, dummy = self.disable_ff_and_make_dummy_commit()
+        dest.refs['refs/heads/master'] = dummy
+        c = client.TCPGitClient('localhost')
+        dest = repo.Repo(os.path.join(self.gitroot, 'server_new.export'))
+        refs = c.fetch('/dest', dest)
+        map(lambda r: dest.refs.set_if_equals(r[0], None, r[1]), refs.items())
+        self.assertDestEqualsSrc()
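
Outside the test harness, the client calls exercised above look like this. A sketch only, assuming a git daemon on localhost that exports /repo.git with receive-pack enabled; the local repository paths are placeholders:

    # Fetch from and push to a remote repository with TCPGitClient,
    # mirroring test_fetch_pack and _do_send_pack above.
    from dulwich.client import TCPGitClient
    from dulwich.repo import Repo

    c = TCPGitClient('localhost')

    # Fetch: returns the remote refs, which the caller records itself.
    dest = Repo('/tmp/clone.git')
    remote_refs = c.fetch('/repo.git', dest)
    for ref, sha in remote_refs.iteritems():
        dest.refs.set_if_equals(ref, None, sha)

    # Push: send everything except HEAD.
    src = Repo('/tmp/src.git')
    sendrefs = dict(src.get_refs())
    sendrefs.pop('HEAD', None)
    c.send_pack('/repo.git', lambda refs: sendrefs,
                src.object_store.generate_pack_contents)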

+ 73 - 0
dulwich/tests/compat/test_pack.py

@@ -0,0 +1,73 @@
+# test_pack.py -- Compatibility tests for git packs.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibility tests for git packs."""
+
+
+import binascii
+import os
+import shutil
+import tempfile
+
+from dulwich.pack import (
+    write_pack,
+    )
+from dulwich.tests.test_pack import (
+    pack1_sha,
+    PackTests,
+    )
+from utils import (
+    require_git_version,
+    run_git,
+    )
+
+
+class TestPack(PackTests):
+    """Compatibility tests for reading and writing pack files."""
+
+    def setUp(self):
+        require_git_version((1, 5, 0))
+        PackTests.setUp(self)
+        self._tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self._tempdir)
+        PackTests.tearDown(self)
+
+    def test_copy(self):
+        origpack = self.get_pack(pack1_sha)
+        self.assertSucceeds(origpack.index.check)
+        pack_path = os.path.join(self._tempdir, "Elch")
+        write_pack(pack_path, [(x, "") for x in origpack.iterobjects()],
+                   len(origpack))
+
+        returncode, output = run_git(['verify-pack', '-v', pack_path],
+                                     capture_stdout=True)
+        self.assertEquals(0, returncode)
+
+        pack_shas = set()
+        for line in output.splitlines():
+            sha = line[:40]
+            try:
+                binascii.unhexlify(sha)
+            except TypeError:
+                continue  # non-sha line
+            pack_shas.add(sha)
+        orig_shas = set(o.id for o in origpack.iterobjects())
+        self.assertEquals(orig_shas, pack_shas)

+ 131 - 0
dulwich/tests/compat/test_repository.py

@@ -0,0 +1,131 @@
+# test_repo.py -- Git repo compatibility tests
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibility tests for dulwich repositories."""
+
+
+from cStringIO import StringIO
+import itertools
+import os
+
+from dulwich.objects import (
+    hex_to_sha,
+    )
+from dulwich.repo import (
+    check_ref_format,
+    )
+from dulwich.tests.utils import (
+    tear_down_repo,
+    )
+
+from utils import (
+    run_git,
+    import_repo,
+    CompatTestCase,
+    )
+
+
+class ObjectStoreTestCase(CompatTestCase):
+    """Tests for git repository compatibility."""
+
+    def setUp(self):
+        CompatTestCase.setUp(self)
+        self._repo = import_repo('server_new.export')
+
+    def tearDown(self):
+        CompatTestCase.tearDown(self)
+        tear_down_repo(self._repo)
+
+    def _run_git(self, args):
+        returncode, output = run_git(args, capture_stdout=True,
+                                     cwd=self._repo.path)
+        self.assertEqual(0, returncode)
+        return output
+
+    def _parse_refs(self, output):
+        refs = {}
+        for line in StringIO(output):
+            fields = line.rstrip('\n').split(' ')
+            self.assertEqual(3, len(fields))
+            refname, type_name, sha = fields
+            check_ref_format(refname[5:])
+            hex_to_sha(sha)
+            refs[refname] = (type_name, sha)
+        return refs
+
+    def _parse_objects(self, output):
+        return set(s.rstrip('\n').split(' ')[0] for s in StringIO(output))
+
+    def test_bare(self):
+        self.assertTrue(self._repo.bare)
+        self.assertFalse(os.path.exists(os.path.join(self._repo.path, '.git')))
+
+    def test_head(self):
+        output = self._run_git(['rev-parse', 'HEAD'])
+        head_sha = output.rstrip('\n')
+        hex_to_sha(head_sha)
+        self.assertEqual(head_sha, self._repo.refs['HEAD'])
+
+    def test_refs(self):
+        output = self._run_git(
+          ['for-each-ref', '--format=%(refname) %(objecttype) %(objectname)'])
+        expected_refs = self._parse_refs(output)
+
+        actual_refs = {}
+        for refname, sha in self._repo.refs.as_dict().iteritems():
+            if refname == 'HEAD':
+                continue  # handled in test_head
+            obj = self._repo[sha]
+            self.assertEqual(sha, obj.id)
+            actual_refs[refname] = (obj.type_name, obj.id)
+        self.assertEqual(expected_refs, actual_refs)
+
+    # TODO(dborowitz): peeled ref tests
+
+    def _get_loose_shas(self):
+        output = self._run_git(['rev-list', '--all', '--objects', '--unpacked'])
+        return self._parse_objects(output)
+
+    def _get_all_shas(self):
+        output = self._run_git(['rev-list', '--all', '--objects'])
+        return self._parse_objects(output)
+
+    def assertShasMatch(self, expected_shas, actual_shas_iter):
+        actual_shas = set()
+        for sha in actual_shas_iter:
+            obj = self._repo[sha]
+            self.assertEqual(sha, obj.id)
+            actual_shas.add(sha)
+        self.assertEqual(expected_shas, actual_shas)
+
+    def test_loose_objects(self):
+        # TODO(dborowitz): This is currently not very useful since fast-imported
+        # repos only contain packed objects.
+        expected_shas = self._get_loose_shas()
+        self.assertShasMatch(expected_shas,
+                             self._repo.object_store._iter_loose_objects())
+
+    def test_packed_objects(self):
+        expected_shas = self._get_all_shas() - self._get_loose_shas()
+        self.assertShasMatch(expected_shas,
+                             itertools.chain(*self._repo.object_store.packs))
+
+    def test_all_objects(self):
+        expected_shas = self._get_all_shas()
+        self.assertShasMatch(expected_shas, iter(self._repo.object_store))

+ 78 - 0
dulwich/tests/compat/test_server.py

@@ -0,0 +1,78 @@
+# test_server.py -- Compatibility tests for git server.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibility tests between Dulwich and the cgit server.
+
+Warning: these tests should be fairly stable, but when writing/debugging new
+tests, deadlocks may freeze the test process such that it cannot be Ctrl-C'ed.
+On *nix, you can kill the tests with Ctrl-Z, "kill %".
+"""
+
+import threading
+
+from dulwich.server import (
+    DictBackend,
+    TCPGitServer,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
+from server_utils import (
+    ServerTests,
+    ShutdownServerMixIn,
+    )
+from utils import (
+    CompatTestCase,
+    )
+
+
+if not getattr(TCPGitServer, 'shutdown', None):
+    _TCPGitServer = TCPGitServer
+
+    class TCPGitServer(ShutdownServerMixIn, TCPGitServer):
+        """Subclass of TCPGitServer that can be shut down."""
+
+        def __init__(self, *args, **kwargs):
+            # BaseServer is old-style so we have to call both __init__s
+            ShutdownServerMixIn.__init__(self)
+            _TCPGitServer.__init__(self, *args, **kwargs)
+
+        serve = ShutdownServerMixIn.serve_forever
+
+
+class GitServerTestCase(ServerTests, CompatTestCase):
+    """Tests for client/server compatibility."""
+
+    protocol = 'git'
+
+    def setUp(self):
+        ServerTests.setUp(self)
+        CompatTestCase.setUp(self)
+
+    def tearDown(self):
+        ServerTests.tearDown(self)
+        CompatTestCase.tearDown(self)
+
+    def _start_server(self, repo):
+        backend = DictBackend({'/': repo})
+        dul_server = TCPGitServer(backend, 'localhost', 0)
+        threading.Thread(target=dul_server.serve).start()
+        self._server = dul_server
+        _, port = self._server.socket.getsockname()
+        return port

+ 116 - 0
dulwich/tests/compat/test_web.py

@@ -0,0 +1,116 @@
+# test_web.py -- Compatibility tests for the git web server.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibility tests between Dulwich and the cgit HTTP server.
+
+Warning: these tests should be fairly stable, but when writing/debugging new
+tests, deadlocks may freeze the test process such that it cannot be Ctrl-C'ed.
+On *nix, you can kill the tests with Ctrl-Z, "kill %".
+"""
+
+import threading
+from wsgiref import simple_server
+
+from dulwich.server import (
+    DictBackend,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
+from dulwich.web import (
+    HTTPGitApplication,
+    )
+
+from server_utils import (
+    ServerTests,
+    ShutdownServerMixIn,
+    )
+from utils import (
+    CompatTestCase,
+    )
+
+
+if getattr(simple_server.WSGIServer, 'shutdown', None):
+    WSGIServer = simple_server.WSGIServer
+else:
+    class WSGIServer(ShutdownServerMixIn, simple_server.WSGIServer):
+        """Subclass of WSGIServer that can be shut down."""
+
+        def __init__(self, *args, **kwargs):
+            # BaseServer is old-style so we have to call both __init__s
+            ShutdownServerMixIn.__init__(self)
+            simple_server.WSGIServer.__init__(self, *args, **kwargs)
+
+        serve = ShutdownServerMixIn.serve_forever
+
+
+class WebTests(ServerTests):
+    """Base tests for web server tests.
+
+    Contains utility and setUp/tearDown methods, but does not inherit from
+    TestCase so tests are not automatically run.
+    """
+
+    protocol = 'http'
+
+    def _start_server(self, repo):
+        backend = DictBackend({'/': repo})
+        app = self._make_app(backend)
+        dul_server = simple_server.make_server('localhost', 0, app,
+                                               server_class=WSGIServer)
+        threading.Thread(target=dul_server.serve_forever).start()
+        self._server = dul_server
+        _, port = dul_server.socket.getsockname()
+        return port
+
+
+class SmartWebTestCase(WebTests, CompatTestCase):
+    """Test cases for smart HTTP server."""
+
+    min_git_version = (1, 6, 6)
+
+    def setUp(self):
+        WebTests.setUp(self)
+        CompatTestCase.setUp(self)
+
+    def tearDown(self):
+        WebTests.tearDown(self)
+        CompatTestCase.tearDown(self)
+
+    def _make_app(self, backend):
+        return HTTPGitApplication(backend)
+
+
+class DumbWebTestCase(WebTests, CompatTestCase):
+    """Test cases for dumb HTTP server."""
+
+    def setUp(self):
+        WebTests.setUp(self)
+        CompatTestCase.setUp(self)
+
+    def tearDown(self):
+        WebTests.tearDown(self)
+        CompatTestCase.tearDown(self)
+
+    def _make_app(self, backend):
+        return HTTPGitApplication(backend, dumb=True)
+
+    def test_push_to_dulwich(self):
+        # Note: remove this if dumb pushing is supported
+        raise TestSkipped('Dumb web pushing not supported.')
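
Outside the tests, the same WSGI application can be served with wsgiref. A sketch using only HTTPGitApplication and DictBackend as imported above; the repository path and port are placeholders:

    # Expose one repository over HTTP, mirroring _start_server() above.
    from wsgiref import simple_server

    from dulwich.repo import Repo
    from dulwich.server import DictBackend
    from dulwich.web import HTTPGitApplication

    backend = DictBackend({'/': Repo('/tmp/repo.git')})
    app = HTTPGitApplication(backend)  # pass dumb=True for the dumb protocol
    httpd = simple_server.make_server('localhost', 8000, app)
    httpd.serve_forever()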

+ 196 - 0
dulwich/tests/compat/utils.py

@@ -0,0 +1,196 @@
+# utils.py -- Git compatibility utilities
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Utilities for interacting with cgit."""
+
+import errno
+import os
+import socket
+import subprocess
+import tempfile
+import time
+import unittest
+
+from dulwich.repo import Repo
+from dulwich.protocol import TCP_GIT_PORT
+
+from dulwich.tests import (
+    TestSkipped,
+    )
+
+_DEFAULT_GIT = 'git'
+
+
+def git_version(git_path=_DEFAULT_GIT):
+    """Attempt to determine the version of git currently installed.
+
+    :param git_path: Path to the git executable; defaults to the version in
+        the system path.
+    :return: A tuple of ints of the form (major, minor, point), or None if no
+        git installation was found.
+    """
+    try:
+        _, output = run_git(['--version'], git_path=git_path,
+                            capture_stdout=True)
+    except OSError:
+        return None
+    version_prefix = 'git version '
+    if not output.startswith(version_prefix):
+        return None
+    output = output[len(version_prefix):]
+    nums = output.split('.')
+    if len(nums) == 2:
+        nums.append('0')
+    else:
+        nums = nums[:3]
+    try:
+        return tuple(int(x) for x in nums)
+    except ValueError:
+        return None
+
+
+def require_git_version(required_version, git_path=_DEFAULT_GIT):
+    """Require git version >= version, or skip the calling test."""
+    found_version = git_version(git_path=git_path)
+    if found_version < required_version:
+        required_version = '.'.join(map(str, required_version))
+        found_version = '.'.join(map(str, found_version))
+        raise TestSkipped('Test requires git >= %s, found %s' %
+                         (required_version, found_version))
+
+
+def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
+            **popen_kwargs):
+    """Run a git command.
+
+    Input is piped from the input parameter and output is sent to the standard
+    streams, unless capture_stdout is set.
+
+    :param args: A list of args to the git command.
+    :param git_path: Path to to the git executable.
+    :param input: Input data to be sent to stdin.
+    :param capture_stdout: Whether to capture and return stdout.
+    :param popen_kwargs: Additional kwargs for subprocess.Popen;
+        stdin/stdout args are ignored.
+    :return: A tuple of (returncode, stdout contents). If capture_stdout is
+        False, None will be returned as stdout contents.
+    :raise OSError: if the git executable was not found.
+    """
+    args = [git_path] + args
+    popen_kwargs['stdin'] = subprocess.PIPE
+    if capture_stdout:
+        popen_kwargs['stdout'] = subprocess.PIPE
+    else:
+        popen_kwargs.pop('stdout', None)
+    p = subprocess.Popen(args, **popen_kwargs)
+    stdout, stderr = p.communicate(input=input)
+    return (p.returncode, stdout)
+
+
+def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
+    """Run a git command, capture stdout/stderr, and fail if git fails."""
+    popen_kwargs['stderr'] = subprocess.STDOUT
+    returncode, stdout = run_git(args, git_path=git_path, input=input,
+                                 capture_stdout=True, **popen_kwargs)
+    assert returncode == 0
+    return stdout
+
+
+def import_repo_to_dir(name):
+    """Import a repo from a fast-export file in a temporary directory.
+
+    These are used rather than binary repos for compat tests because they are
+    more compact and human-editable, and we already depend on git.
+
+    :param name: The name of the repository export file, relative to
+        dulwich/tests/data/repos.
+    :returns: The path to the imported repository.
+    """
+    temp_dir = tempfile.mkdtemp()
+    export_path = os.path.join(os.path.dirname(__file__), os.pardir, 'data',
+                               'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    export_file = open(export_path, 'rb')
+    run_git_or_fail(['init', '--bare', temp_repo_dir])
+    run_git_or_fail(['fast-import'], input=export_file.read(),
+                    cwd=temp_repo_dir)
+    export_file.close()
+    return temp_repo_dir
+
+def import_repo(name):
+    """Import a repo from a fast-export file in a temporary directory.
+
+    :param name: The name of the repository export file, relative to
+        dulwich/tests/data/repos.
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    return Repo(import_repo_to_dir(name))
+
+
+def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
+    """Check for a running TCP daemon.
+
+    Defaults to checking 10 times with a delay of 0.1 sec between tries.
+
+    :param limit: Number of attempts before deciding no daemon is running.
+    :param delay: Delay between connection attempts.
+    :param timeout: Socket timeout for connection attempts.
+    :param port: Port on which we expect the daemon to appear.
+    :returns: A boolean, true if a daemon is running on the specified port,
+        false if not.
+    """
+    for _ in xrange(limit):
+        time.sleep(delay)
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.settimeout(delay)
+        try:
+            s.connect(('localhost', port))
+            s.close()
+            return True
+        except socket.error, e:
+            if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
+                raise
+            elif e.args[0] != errno.ECONNREFUSED:
+                raise
+    return False
+
+
+class CompatTestCase(unittest.TestCase):
+    """Test case that requires git for compatibility checks.
+
+    Subclasses can change the git version required by overriding
+    min_git_version.
+    """
+
+    min_git_version = (1, 5, 0)
+
+    def setUp(self):
+        require_git_version(self.min_git_version)
+
+    def assertReposEqual(self, repo1, repo2):
+        self.assertEqual(repo1.get_refs(), repo2.get_refs())
+        self.assertEqual(sorted(set(repo1.object_store)),
+                         sorted(set(repo2.object_store)))
+
+    def assertReposNotEqual(self, repo1, repo2):
+        refs1 = repo1.get_refs()
+        objs1 = set(repo1.object_store)
+        refs2 = repo2.get_refs()
+        objs2 = set(repo2.object_store)
+        self.assertFalse(refs1 == refs2 and objs1 == objs2)
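
A new compat test would typically combine these helpers as follows; a sketch only, with the export file, version requirement, and git subcommand chosen as examples:

    # Skeleton of a compat test built on CompatTestCase, import_repo and
    # run_git_or_fail.
    from dulwich.tests.utils import tear_down_repo
    from utils import CompatTestCase, import_repo, run_git_or_fail

    class ExampleCompatTest(CompatTestCase):

        min_git_version = (1, 6, 0)  # example override of the default

        def setUp(self):
            CompatTestCase.setUp(self)
            self._repo = import_repo('server_new.export')

        def tearDown(self):
            CompatTestCase.tearDown(self)
            tear_down_repo(self._repo)

        def test_log_runs(self):
            # run_git_or_fail asserts a zero exit code and returns stdout.
            output = run_git_or_fail(['log'], cwd=self._repo.path)
            self.assertTrue(output)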

BIN
dulwich/tests/data/blobs/11/11111111111111111111111111111111111111


+ 0 - 0
dulwich/tests/data/blobs/6f670c0fb53f9463760b7295fbb814e965fb20c8 → dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8


+ 0 - 0
dulwich/tests/data/blobs/954a536f7819d40e6f637f849ee187dd10066349 → dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349


+ 0 - 0
dulwich/tests/data/blobs/e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 → dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391


+ 0 - 0
dulwich/tests/data/commits/0d89f20333fbb1d2f3a94da77f4981373d8f4310 → dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310


+ 0 - 0
dulwich/tests/data/commits/5dac377bdded4c9aeb8dff595f0faeebcc8498cc → dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc


+ 0 - 0
dulwich/tests/data/commits/60dacdc733de308bb77bb76ce0fb0f9b44c9769e → dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e


BIN
dulwich/tests/data/repos/a.git/objects/28/237f4dc30d0d462658d6b937b08a0f0b6ef55a

BIN
dulwich/tests/data/repos/a.git/objects/b0/931cadc54336e78a1d980420e3268903b57a50

+ 3 - 0
dulwich/tests/data/repos/a.git/packed-refs

@@ -0,0 +1,3 @@
+# pack-refs with: peeled 
+b0931cadc54336e78a1d980420e3268903b57a50 refs/tags/mytag-packed
+^2a72d929692c41d8554c07f6301757ba18a65d91

+ 1 - 0
dulwich/tests/data/repos/a.git/refs/tags/mytag

@@ -0,0 +1 @@
+28237f4dc30d0d462658d6b937b08a0f0b6ef55a

BIN
dulwich/tests/data/repos/refs.git/objects/3e/c9c43c84ff242e3ef4a9fc5bc111fd780a76a8

BIN
dulwich/tests/data/repos/refs.git/objects/cd/a609072918d7b70057b6bef9f4c2537843fcfe

+ 1 - 0
dulwich/tests/data/repos/refs.git/packed-refs

@@ -1,3 +1,4 @@
 # pack-refs with: peeled 
 df6800012397fb85c56e7418dd4eb9405dee075c refs/tags/refs-0.1
 ^42d06bd4b77fed026b154d16493e5deab78f02ec
+42d06bd4b77fed026b154d16493e5deab78f02ec refs/heads/packed

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/tags/refs-0.2

@@ -0,0 +1 @@
+3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8

+ 99 - 0
dulwich/tests/data/repos/server_new.export

@@ -0,0 +1,99 @@
+blob
+mark :1
+data 13
+foo contents
+
+reset refs/heads/master
+commit refs/heads/master
+mark :2
+author Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+data 16
+initial checkin
+M 100644 :1 foo
+
+blob
+mark :3
+data 13
+baz contents
+
+blob
+mark :4
+data 21
+updated foo contents
+
+commit refs/heads/master
+mark :5
+author Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+data 15
+master checkin
+from :2
+M 100644 :3 baz
+M 100644 :4 foo
+
+blob
+mark :6
+data 24
+updated foo contents v2
+
+commit refs/heads/master
+mark :7
+author Dave Borowitz <dborowitz@google.com> 1265755287 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755287 -0800
+data 17
+master checkin 2
+from :5
+M 100644 :6 foo
+
+blob
+mark :8
+data 24
+updated foo contents v3
+
+commit refs/heads/master
+mark :9
+author Dave Borowitz <dborowitz@google.com> 1265755295 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755295 -0800
+data 17
+master checkin 3
+from :7
+M 100644 :8 foo
+
+blob
+mark :10
+data 22
+branched bar contents
+
+blob
+mark :11
+data 22
+branched foo contents
+
+commit refs/heads/branch
+mark :12
+author Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+data 15
+branch checkin
+from :2
+M 100644 :10 bar
+M 100644 :11 foo
+
+blob
+mark :13
+data 25
+branched bar contents v2
+
+commit refs/heads/branch
+mark :14
+author Dave Borowitz <dborowitz@google.com> 1265755319 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755319 -0800
+data 17
+branch checkin 2
+from :12
+M 100644 :13 bar
+
+reset refs/heads/master
+from :9
+

+ 57 - 0
dulwich/tests/data/repos/server_old.export

@@ -0,0 +1,57 @@
+blob
+mark :1
+data 13
+foo contents
+
+reset refs/heads/master
+commit refs/heads/master
+mark :2
+author Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+data 16
+initial checkin
+M 100644 :1 foo
+
+blob
+mark :3
+data 22
+branched bar contents
+
+blob
+mark :4
+data 22
+branched foo contents
+
+commit refs/heads/branch
+mark :5
+author Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+data 15
+branch checkin
+from :2
+M 100644 :3 bar
+M 100644 :4 foo
+
+blob
+mark :6
+data 13
+baz contents
+
+blob
+mark :7
+data 21
+updated foo contents
+
+commit refs/heads/master
+mark :8
+author Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+data 15
+master checkin
+from :2
+M 100644 :6 baz
+M 100644 :7 foo
+
+reset refs/heads/master
+from :8
+

+ 0 - 0
dulwich/tests/data/tags/71033db03a03c6a36721efcf1968dd8f8e0cf023 → dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023


+ 0 - 0
dulwich/tests/data/trees/70c190eb48fa8bbb50ddc692a17b44cb781af7f6 → dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6


+ 11 - 5
dulwich/tests/test_client.py

@@ -1,16 +1,16 @@
 # test_client.py -- Tests for the git protocol, client side
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -23,16 +23,22 @@ from dulwich.client import (
     GitClient,
     )
 
+
+# TODO(durin42): add unit-level tests of GitClient
 class GitClientTests(TestCase):
 
     def setUp(self):
         self.rout = StringIO()
         self.rin = StringIO()
-        self.client = GitClient(lambda x: True, self.rin.read, 
+        self.client = GitClient(lambda x: True, self.rin.read,
             self.rout.write)
 
     def test_caps(self):
-        self.assertEquals(['multi_ack', 'side-band-64k', 'ofs-delta', 'thin-pack'], self.client._capabilities)
+        self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
+                               'thin-pack']),
+                          set(self.client._fetch_capabilities))
+        self.assertEquals(set(['ofs-delta', 'report-status']),
+                          set(self.client._send_capabilities))
 
     def test_fetch_pack_none(self):
         self.rin.write(

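The updated test_caps shows that the client now keeps separate capability sets for fetching and sending. A minimal sketch of inspecting them on an in-memory client, mirroring the test setup above (note the underscore-prefixed attributes are internal implementation details):

from cStringIO import StringIO
from dulwich.client import GitClient

rin, rout = StringIO(), StringIO()
client = GitClient(lambda x: True, rin.read, rout.write)
# Capability sets as verified by test_caps above.
assert set(client._fetch_capabilities) == set(
    ['multi_ack', 'side-band-64k', 'ofs-delta', 'thin-pack'])
assert set(client._send_capabilities) == set(['ofs-delta', 'report-status'])
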
+ 78 - 0
dulwich/tests/test_fastexport.py

@@ -0,0 +1,78 @@
+# test_fastexport.py -- Fast export/import functionality
+# Copyright (C) 2010 Jelmer Vernooij <jelmer@samba.org>
+# 
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of 
+# the License.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+from cStringIO import StringIO
+import stat
+from unittest import TestCase
+
+from dulwich.fastexport import (
+    FastExporter,
+    )
+from dulwich.object_store import (
+    MemoryObjectStore,
+    )
+from dulwich.objects import (
+    Blob,
+    Commit,
+    Tree,
+    )
+
+
+class FastExporterTests(TestCase):
+
+    def setUp(self):
+        super(FastExporterTests, self).setUp()
+        self.store = MemoryObjectStore()
+        self.stream = StringIO()
+        self.fastexporter = FastExporter(self.stream, self.store)
+
+    def test_export_blob(self):
+        b = Blob()
+        b.data = "fooBAR"
+        self.assertEquals(1, self.fastexporter.export_blob(b))
+        self.assertEquals('blob\nmark :1\ndata 6\nfooBAR\n',
+            self.stream.getvalue())
+
+    def test_export_commit(self):
+        b = Blob()
+        b.data = "FOO"
+        t = Tree()
+        t.add(stat.S_IFREG | 0644, "foo", b.id)
+        c = Commit()
+        c.committer = c.author = "Jelmer <jelmer@host>"
+        c.author_time = c.commit_time = 1271345553.47
+        c.author_timezone = c.commit_timezone = 0
+        c.message = "msg"
+        c.tree = t.id
+        self.store.add_objects([(b, None), (t, None), (c, None)])
+        self.assertEquals(2,
+                self.fastexporter.export_commit(c, "refs/heads/master"))
+        self.assertEquals("""blob
+mark :1
+data 3
+FOO
+commit refs/heads/master
+mark :2
+author Jelmer <jelmer@host> 1271345553.47 +0000
+committer Jelmer <jelmer@host> 1271345553.47 +0000
+data 3
+msg
+M 100644 :1 foo
+
+""", self.stream.getvalue())

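A short usage sketch of the new fastexport module, following test_export_blob above (it assumes the same FastExporter(stream, store) constructor and that export_blob returns the assigned mark):

from cStringIO import StringIO
from dulwich.fastexport import FastExporter
from dulwich.object_store import MemoryObjectStore
from dulwich.objects import Blob

store = MemoryObjectStore()
stream = StringIO()
exporter = FastExporter(stream, store)

blob = Blob()
blob.data = "fooBAR"
mark = exporter.export_blob(blob)            # marks are numbered from 1
assert mark == 1
assert stream.getvalue() == 'blob\nmark :1\ndata 6\nfooBAR\n'
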
+ 68 - 2
dulwich/tests/test_file.py

@@ -20,12 +20,71 @@
 import errno
 import os
 import shutil
+import sys
 import tempfile
 import unittest
 
-from dulwich.file import GitFile
+from dulwich.file import GitFile, fancy_rename
+from dulwich.tests import TestSkipped
+
+
+class FancyRenameTests(unittest.TestCase):
+
+    def setUp(self):
+        self._tempdir = tempfile.mkdtemp()
+        self.foo = self.path('foo')
+        self.bar = self.path('bar')
+        self.create(self.foo, 'foo contents')
+
+    def tearDown(self):
+        shutil.rmtree(self._tempdir)
+
+    def path(self, filename):
+        return os.path.join(self._tempdir, filename)
+
+    def create(self, path, contents):
+        f = open(path, 'wb')
+        f.write(contents)
+        f.close()
+
+    def test_no_dest_exists(self):
+        self.assertFalse(os.path.exists(self.bar))
+        fancy_rename(self.foo, self.bar)
+        self.assertFalse(os.path.exists(self.foo))
+
+        new_f = open(self.bar, 'rb')
+        self.assertEquals('foo contents', new_f.read())
+        new_f.close()
+         
+    def test_dest_exists(self):
+        self.create(self.bar, 'bar contents')
+        fancy_rename(self.foo, self.bar)
+        self.assertFalse(os.path.exists(self.foo))
+
+        new_f = open(self.bar, 'rb')
+        self.assertEquals('foo contents', new_f.read())
+        new_f.close()
+
+    def test_dest_opened(self):
+        if sys.platform != "win32":
+            raise TestSkipped("platform allows overwriting open files")
+        self.create(self.bar, 'bar contents')
+        dest_f = open(self.bar, 'rb')
+        self.assertRaises(OSError, fancy_rename, self.foo, self.bar)
+        dest_f.close()
+        self.assertTrue(os.path.exists(self.path('foo')))
+
+        new_f = open(self.foo, 'rb')
+        self.assertEquals('foo contents', new_f.read())
+        new_f.close()
+
+        new_f = open(self.bar, 'rb')
+        self.assertEquals('bar contents', new_f.read())
+        new_f.close()
+
 
 class GitFileTests(unittest.TestCase):
+
     def setUp(self):
         self._tempdir = tempfile.mkdtemp()
         f = open(self.path('foo'), 'wb')
@@ -85,7 +144,7 @@ class GitFileTests(unittest.TestCase):
         f1.write('new')
         try:
             f2 = GitFile(foo, 'wb')
-            fail()
+            self.fail()
         except OSError, e:
             self.assertEquals(errno.EEXIST, e.errno)
         f1.write(' contents')
@@ -129,3 +188,10 @@ class GitFileTests(unittest.TestCase):
             f.abort()
         except (IOError, OSError):
             self.fail()
+
+    def test_abort_close_removed(self):
+        foo = self.path('foo')
+        f = GitFile(foo, 'wb')
+        os.remove(foo+".lock")
+        f.abort()
+        self.assertTrue(f._closed)

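The GitFile tests above exercise dulwich's lock-file based writes: opening a path with 'wb' takes "<path>.lock", a second writer gets OSError(EEXIST), abort() discards the lock, and close() publishes the new contents. A rough sketch under those assumptions (the file name is illustrative):

import errno
from dulwich.file import GitFile

try:
    f = GitFile('myconfig', 'wb')      # takes myconfig.lock
except OSError, e:
    if e.errno != errno.EEXIST:
        raise
    # Another writer already holds myconfig.lock; give up for now.
else:
    try:
        f.write('new contents\n')
    except:
        f.abort()                      # drop the lock, leave myconfig untouched
        raise
    else:
        f.close()                      # publish the new contents
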
+ 31 - 11
dulwich/tests/test_index.py

@@ -1,16 +1,16 @@
 # test_index.py -- Tests for the git index
 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -24,8 +24,10 @@ from cStringIO import (
     StringIO,
     )
 import os
+import shutil
 import stat
 import struct
+import tempfile
 from unittest import TestCase
 
 from dulwich.index import (
@@ -43,6 +45,7 @@ from dulwich.objects import (
     Blob,
     )
 
+
 class IndexTestCase(TestCase):
 
     datadir = os.path.join(os.path.dirname(__file__), 'data/indexes')
@@ -51,7 +54,7 @@ class IndexTestCase(TestCase):
         return Index(os.path.join(self.datadir, name))
 
 
-class SimpleIndexTestcase(IndexTestCase):
+class SimpleIndexTestCase(IndexTestCase):
 
     def test_len(self):
         self.assertEquals(1, len(self.get_simple_index("index")))
@@ -60,21 +63,38 @@ class SimpleIndexTestcase(IndexTestCase):
         self.assertEquals(['bla'], list(self.get_simple_index("index")))
 
     def test_getitem(self):
-        self.assertEquals( ((1230680220, 0), (1230680220, 0), 2050, 3761020, 33188, 1000, 1000, 0, 'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)
-            , 
-                self.get_simple_index("index")["bla"])
+        self.assertEquals(((1230680220, 0), (1230680220, 0), 2050, 3761020,
+                           33188, 1000, 1000, 0,
+                           'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0),
+                          self.get_simple_index("index")["bla"])
+
+    def test_empty(self):
+        i = self.get_simple_index("notanindex")
+        self.assertEquals(0, len(i))
+        self.assertFalse(os.path.exists(i._filename))
 
 
 class SimpleIndexWriterTestCase(IndexTestCase):
 
+    def setUp(self):
+        IndexTestCase.setUp(self)
+        self.tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        IndexTestCase.tearDown(self)
+        shutil.rmtree(self.tempdir)
+
     def test_simple_write(self):
-        entries = [('barbla', (1230680220, 0), (1230680220, 0), 2050, 3761020, 33188, 1000, 1000, 0, 'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)]
-        x = open('test-simple-write-index', 'w+')
+        entries = [('barbla', (1230680220, 0), (1230680220, 0), 2050, 3761020,
+                    33188, 1000, 1000, 0,
+                    'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)]
+        filename = os.path.join(self.tempdir, 'test-simple-write-index')
+        x = open(filename, 'w+')
         try:
             write_index(x, entries)
         finally:
             x.close()
-        x = open('test-simple-write-index', 'r')
+        x = open(filename, 'r')
         try:
             self.assertEquals(entries, list(read_index(x)))
         finally:
@@ -108,7 +128,7 @@ class CommitTreeTests(TestCase):
         self.assertEquals(dirid, "c1a1deb9788150829579a8b4efa6311e7b638650")
         self.assertEquals((stat.S_IFDIR, dirid), self.store[rootid]["bla"])
         self.assertEquals((stat.S_IFREG, blob.id), self.store[dirid]["bar"])
-        self.assertEquals(set([rootid, dirid, blob.id]), 
+        self.assertEquals(set([rootid, dirid, blob.id]),
                           set(self.store._data.keys()))
 
 

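A sketch of the write_index/read_index round trip that SimpleIndexWriterTestCase relies on; the entry tuple layout is taken verbatim from the test above:

import os
import tempfile
from dulwich.index import read_index, write_index

entry = ('barbla', (1230680220, 0), (1230680220, 0), 2050, 3761020,
         33188, 1000, 1000, 0,
         'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)

path = os.path.join(tempfile.mkdtemp(), 'index')
f = open(path, 'w+')
try:
    write_index(f, [entry])
finally:
    f.close()

f = open(path, 'r')
try:
    assert [entry] == list(read_index(f))
finally:
    f.close()
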
+ 33 - 29
dulwich/tests/test_object_store.py

@@ -1,16 +1,16 @@
 # test_object_store.py -- tests for object_store.py
 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -20,6 +20,9 @@
 """Tests for the object store interface."""
 
 
+import os
+import shutil
+import tempfile
 from unittest import TestCase
 
 from dulwich.objects import (
@@ -29,24 +32,12 @@ from dulwich.object_store import (
     DiskObjectStore,
     MemoryObjectStore,
     )
-import os
-import shutil
-
-
-testobject = Blob()
-testobject.data = "yummy data"
-
+from utils import (
+    make_object,
+    )
 
-class SpecificDiskObjectStoreTests(TestCase):
-
-    def test_pack_dir(self):
-        o = DiskObjectStore("foo")
-        self.assertEquals(os.path.join("foo", "pack"), o.pack_dir)
-
-    def test_empty_packs(self):
-        o = DiskObjectStore("foo")
-        self.assertEquals([], o.packs)
 
+testobject = make_object(Blob, data="yummy data")
 
 
 class ObjectStoreTests(object):
@@ -55,10 +46,10 @@ class ObjectStoreTests(object):
         self.assertEquals([], list(self.store))
 
     def test_get_nonexistant(self):
-        self.assertRaises(KeyError, self.store.__getitem__, "a" * 40)
+        self.assertRaises(KeyError, lambda: self.store["a" * 40])
 
     def test_contains_nonexistant(self):
-        self.assertFalse(self.store.__contains__("a" * 40))
+        self.assertFalse(("a" * 40) in self.store)
 
     def test_add_objects_empty(self):
         self.store.add_objects([])
@@ -71,7 +62,7 @@ class ObjectStoreTests(object):
     def test_add_object(self):
         self.store.add_object(testobject)
         self.assertEquals(set([testobject.id]), set(self.store))
-        self.assertTrue(self.store.__contains__(testobject.id))
+        self.assertTrue(testobject.id in self.store)
         r = self.store[testobject.id]
         self.assertEquals(r, testobject)
 
@@ -79,23 +70,36 @@ class ObjectStoreTests(object):
         data = [(testobject, "mypath")]
         self.store.add_objects(data)
         self.assertEquals(set([testobject.id]), set(self.store))
-        self.assertTrue(self.store.__contains__(testobject.id))
+        self.assertTrue(testobject.id in self.store)
         r = self.store[testobject.id]
         self.assertEquals(r, testobject)
 
 
-class MemoryObjectStoreTests(ObjectStoreTests,TestCase):
+class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
     def setUp(self):
         TestCase.setUp(self)
         self.store = MemoryObjectStore()
 
 
-class DiskObjectStoreTests(ObjectStoreTests,TestCase):
+class DiskObjectStoreTests(ObjectStoreTests, TestCase):
 
     def setUp(self):
         TestCase.setUp(self)
-        if os.path.exists("foo"):
-            shutil.rmtree("foo")
-        os.makedirs(os.path.join("foo", "pack"))
-        self.store = DiskObjectStore("foo")
+        self.store_dir = tempfile.mkdtemp()
+        self.store = DiskObjectStore.init(self.store_dir)
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+        shutil.rmtree(self.store_dir)
+
+    def test_pack_dir(self):
+        o = DiskObjectStore(self.store_dir)
+        self.assertEquals(os.path.join(self.store_dir, "pack"), o.pack_dir)
+
+    def test_empty_packs(self):
+        o = DiskObjectStore(self.store_dir)
+        self.assertEquals([], o.packs)
+
+
+# TODO: MissingObjectFinderTests

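The disk store tests now build their store with the new DiskObjectStore.init() class method instead of hand-creating a "pack" directory. A small sketch along the lines of the tests above (the temporary directory is illustrative):

import shutil
import tempfile
from dulwich.object_store import DiskObjectStore
from dulwich.objects import Blob

store_dir = tempfile.mkdtemp()
try:
    store = DiskObjectStore.init(store_dir)   # sets up the store layout
    blob = Blob()
    blob.data = "yummy data"
    store.add_object(blob)
    assert blob.id in store                   # membership by hex sha
    assert store[blob.id] == blob
finally:
    shutil.rmtree(store_dir)
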
+ 420 - 104
dulwich/tests/test_objects.py

@@ -20,11 +20,18 @@
 
 """Tests for git base objects."""
 
+# TODO: Round-trip parse-serialize-parse and serialize-parse-serialize tests.
 
+
+import datetime
 import os
 import stat
 import unittest
 
+from dulwich.errors import (
+    ChecksumMismatch,
+    ObjectFormatException,
+    )
 from dulwich.objects import (
     Blob,
     Tree,
@@ -32,7 +39,20 @@ from dulwich.objects import (
     Tag,
     format_timezone,
     hex_to_sha,
+    sha_to_hex,
+    hex_to_filename,
+    check_hexsha,
+    check_identity,
     parse_timezone,
+    parse_tree,
+    _parse_tree_py,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
+from utils import (
+    make_commit,
+    make_object,
     )
 
 a_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
@@ -41,13 +61,57 @@ c_sha = '954a536f7819d40e6f637f849ee187dd10066349'
 tree_sha = '70c190eb48fa8bbb50ddc692a17b44cb781af7f6'
 tag_sha = '71033db03a03c6a36721efcf1968dd8f8e0cf023'
 
+
+try:
+    from itertools import permutations
+except ImportError:
+    # Implementation of permutations from Python 2.6 documentation:
+    # http://docs.python.org/2.6/library/itertools.html#itertools.permutations
+    # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+    # Modified syntax slightly to run under Python 2.4.
+    def permutations(iterable, r=None):
+        # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
+        # permutations(range(3)) --> 012 021 102 120 201 210
+        pool = tuple(iterable)
+        n = len(pool)
+        if r is None:
+            r = n
+        if r > n:
+            return
+        indices = range(n)
+        cycles = range(n, n-r, -1)
+        yield tuple(pool[i] for i in indices[:r])
+        while n:
+            for i in reversed(range(r)):
+                cycles[i] -= 1
+                if cycles[i] == 0:
+                    indices[i:] = indices[i+1:] + indices[i:i+1]
+                    cycles[i] = n - i
+                else:
+                    j = cycles[i]
+                    indices[i], indices[-j] = indices[-j], indices[i]
+                    yield tuple(pool[i] for i in indices[:r])
+                    break
+            else:
+                return
+
+
+class TestHexToSha(unittest.TestCase):
+
+    def test_simple(self):
+        self.assertEquals("\xab\xcd" * 10, hex_to_sha("abcd" * 10))
+
+    def test_reverse(self):
+        self.assertEquals("abcd" * 10, sha_to_hex("\xab\xcd" * 10))
+
+
 class BlobReadTests(unittest.TestCase):
     """Test decompression of blobs"""
-  
-    def get_sha_file(self, obj, base, sha):
-        return obj.from_file(os.path.join(os.path.dirname(__file__),
-                                          'data', base, sha))
-  
+
+    def get_sha_file(self, cls, base, sha):
+        dir = os.path.join(os.path.dirname(__file__), 'data', base)
+        return cls.from_path(hex_to_filename(dir, sha))
+
     def get_blob(self, sha):
         """Return the blob named sha from the test data dir"""
         return self.get_sha_file(Blob, 'blobs', sha)
@@ -82,6 +146,18 @@ class BlobReadTests(unittest.TestCase):
         b = Blob.from_string(string)
         self.assertEqual(b.data, string)
         self.assertEqual(b.sha().hexdigest(), b_sha)
+
+    def test_chunks(self):
+        string = 'test 5\n'
+        b = Blob.from_string(string)
+        self.assertEqual([string], b.chunked)
+
+    def test_set_chunks(self):
+        b = Blob()
+        b.chunked = ['te', 'st', ' 5\n']
+        self.assertEqual('test 5\n', b.data)
+        b.chunked = ['te', 'st', ' 6\n']
+        self.assertEqual('test 6\n', b.as_raw_string())
   
     def test_parse_legacy_blob(self):
         string = 'test 3\n'
@@ -107,12 +183,12 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(t.tag_time, 1231203091)
         self.assertEqual(t.message, 'This is a signed tag\n-----BEGIN PGP SIGNATURE-----\nVersion: GnuPG v1.4.9 (GNU/Linux)\n\niEYEABECAAYFAkliqx8ACgkQqSMmLy9u/kcx5ACfakZ9NnPl02tOyYP6pkBoEkU1\n5EcAn0UFgokaSvS371Ym/4W9iJj6vh3h\n=ql7y\n-----END PGP SIGNATURE-----\n')
   
-  
     def test_read_commit_from_file(self):
         sha = '60dacdc733de308bb77bb76ce0fb0f9b44c9769e'
         c = self.commit(sha)
         self.assertEqual(c.tree, tree_sha)
-        self.assertEqual(c.parents, ['0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
+        self.assertEqual(c.parents,
+            ['0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
         self.assertEqual(c.author,
             'James Westby <jw+debian@jameswestby.net>')
         self.assertEqual(c.committer,
@@ -150,95 +226,185 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(c.commit_timezone, 0)
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
-  
+
+
+class ShaFileCheckTests(unittest.TestCase):
+
+    def assertCheckFails(self, cls, data):
+        obj = cls()
+        def do_check():
+            obj.set_raw_string(data)
+            obj.check()
+        self.assertRaises(ObjectFormatException, do_check)
+
+    def assertCheckSucceeds(self, cls, data):
+        obj = cls()
+        obj.set_raw_string(data)
+        self.assertEqual(None, obj.check())
 
 
 class CommitSerializationTests(unittest.TestCase):
 
-    def make_base(self):
-        c = Commit()
-        c.tree = 'd80c186a03f423a81b39df39dc87fd269736ca86'
-        c.parents = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6']
-        c.author = 'James Westby <jw+debian@jameswestby.net>'
-        c.committer = 'James Westby <jw+debian@jameswestby.net>'
-        c.commit_time = 1174773719
-        c.author_time = 1174773719
-        c.commit_timezone = 0
-        c.author_timezone = 0
-        c.message =  'Merge ../b\n'
-        return c
+    def make_commit(self, **kwargs):
+        attrs = {'tree': 'd80c186a03f423a81b39df39dc87fd269736ca86',
+                 'parents': ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                             '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                 'author': 'James Westby <jw+debian@jameswestby.net>',
+                 'committer': 'James Westby <jw+debian@jameswestby.net>',
+                 'commit_time': 1174773719,
+                 'author_time': 1174773719,
+                 'commit_timezone': 0,
+                 'author_timezone': 0,
+                 'message':  'Merge ../b\n'}
+        attrs.update(kwargs)
+        return make_commit(**attrs)
 
     def test_encoding(self):
-        c = self.make_base()
-        c.encoding = "iso8859-1"
-        self.assertTrue("encoding iso8859-1\n" in c.as_raw_string())        
+        c = self.make_commit(encoding='iso8859-1')
+        self.assertTrue('encoding iso8859-1\n' in c.as_raw_string())
 
     def test_short_timestamp(self):
-        c = self.make_base()
-        c.commit_time = 30
+        c = self.make_commit(commit_time=30)
         c1 = Commit()
         c1.set_raw_string(c.as_raw_string())
         self.assertEquals(30, c1.commit_time)
 
+    def test_raw_length(self):
+        c = self.make_commit()
+        self.assertEquals(len(c.as_raw_string()), c.raw_length())
+
     def test_simple(self):
-        c = self.make_base()
+        c = self.make_commit()
         self.assertEquals(c.id, '5dac377bdded4c9aeb8dff595f0faeebcc8498cc')
         self.assertEquals(
                 'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
                 'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
                 'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
+                'author James Westby <jw+debian@jameswestby.net> '
+                '1174773719 +0000\n'
+                'committer James Westby <jw+debian@jameswestby.net> '
+                '1174773719 +0000\n'
                 '\n'
                 'Merge ../b\n', c.as_raw_string())
 
     def test_timezone(self):
-        c = self.make_base()
-        c.commit_timezone = 5 * 60
+        c = self.make_commit(commit_timezone=(5 * 60))
         self.assertTrue(" +0005\n" in c.as_raw_string())
 
     def test_neg_timezone(self):
-        c = self.make_base()
-        c.commit_timezone = -1 * 3600
+        c = self.make_commit(commit_timezone=(-1 * 3600))
         self.assertTrue(" -0100\n" in c.as_raw_string())
 
 
-class CommitDeserializationTests(unittest.TestCase):
+default_committer = 'James Westby <jw+debian@jameswestby.net> 1174773719 +0000'
+
+class CommitParseTests(ShaFileCheckTests):
+
+    def make_commit_lines(self,
+                          tree='d80c186a03f423a81b39df39dc87fd269736ca86',
+                          parents=['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                                   '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          author=default_committer,
+                          committer=default_committer,
+                          encoding=None,
+                          message='Merge ../b\n',
+                          extra=None):
+        lines = []
+        if tree is not None:
+            lines.append('tree %s' % tree)
+        if parents is not None:
+            lines.extend('parent %s' % p for p in parents)
+        if author is not None:
+            lines.append('author %s' % author)
+        if committer is not None:
+            lines.append('committer %s' % committer)
+        if encoding is not None:
+            lines.append('encoding %s' % encoding)
+        if extra is not None:
+            for name, value in sorted(extra.iteritems()):
+                lines.append('%s %s' % (name, value))
+        lines.append('')
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_commit_text(self, **kwargs):
+        return '\n'.join(self.make_commit_lines(**kwargs))
 
     def test_simple(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text())
         self.assertEquals('Merge ../b\n', c.message)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>', c.author)
         self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.author)
-        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.committer)
-        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86',
-            c.tree)
+                          c.committer)
+        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86', c.tree)
         self.assertEquals(['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
-                          '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
-            c.parents)
+                           '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          c.parents)
+        expected_time = datetime.datetime(2007, 3, 24, 22, 1, 59)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.commit_time))
+        self.assertEquals(0, c.commit_timezone)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.author_time))
+        self.assertEquals(0, c.author_timezone)
+        self.assertEquals(None, c.encoding)
 
     def test_custom(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'extra-field data\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text(
+          extra={'extra-field': 'data'}))
         self.assertEquals([('extra-field', 'data')], c.extra)
 
-
-class TreeSerializationTests(unittest.TestCase):
+    def test_encoding(self):
+        c = Commit.from_string(self.make_commit_text(encoding='UTF-8'))
+        self.assertEquals('UTF-8', c.encoding)
+
+    def test_check(self):
+        self.assertCheckSucceeds(Commit, self.make_commit_text())
+        self.assertCheckSucceeds(Commit, self.make_commit_text(parents=None))
+        self.assertCheckSucceeds(Commit,
+                                 self.make_commit_text(encoding='UTF-8'))
+
+        self.assertCheckFails(Commit, self.make_commit_text(tree='xxx'))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          parents=[a_sha, 'xxx']))
+        bad_committer = "some guy without an email address 1174773719 +0000"
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(committer=bad_committer))
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(author=bad_committer))
+        self.assertCheckFails(Commit, self.make_commit_text(author=None))
+        self.assertCheckFails(Commit, self.make_commit_text(committer=None))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          author=None, committer=None))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(5):
+            lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+            lines.insert(i, lines[i])
+            text = '\n'.join(lines)
+            if lines[i].startswith('parent'):
+                # duplicate parents are ok for now
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+    def test_check_order(self):
+        lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+        headers = lines[:5]
+        rest = lines[5:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+
+class TreeTests(ShaFileCheckTests):
 
     def test_simple(self):
         myhexsha = "d80c186a03f423a81b39df39dc87fd269736ca86"
@@ -247,6 +413,13 @@ class TreeSerializationTests(unittest.TestCase):
         self.assertEquals('100755 myname\0' + hex_to_sha(myhexsha),
                 x.as_raw_string())
 
+    def test_tree_update_id(self):
+        x = Tree()
+        x["a.c"] = (0100755, "d80c186a03f423a81b39df39dc87fd269736ca86")
+        self.assertEquals("0c5c6bc2c081accfbc250331b19e43b904ab9cdd", x.id)
+        x["a.b"] = (stat.S_IFDIR, "d80c186a03f423a81b39df39dc87fd269736ca86")
+        self.assertEquals("07bfcb5f3ada15bbebdfa3bbb8fd858a363925c8", x.id)
+
     def test_tree_dir_sort(self):
         x = Tree()
         x["a.c"] = (0100755, "d80c186a03f423a81b39df39dc87fd269736ca86")
@@ -254,35 +427,79 @@ class TreeSerializationTests(unittest.TestCase):
         x["a/c"] = (stat.S_IFDIR, "d80c186a03f423a81b39df39dc87fd269736ca86")
         self.assertEquals(["a.c", "a", "a/c"], [p[0] for p in x.iteritems()])
 
+    def _do_test_parse_tree(self, parse_tree):
+        dir = os.path.join(os.path.dirname(__file__), 'data', 'trees')
+        o = Tree.from_path(hex_to_filename(dir, tree_sha))
+        self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
+                          list(parse_tree(o.as_raw_string())))
+
+    def test_parse_tree(self):
+        self._do_test_parse_tree(_parse_tree_py)
+
+    def test_parse_tree_extension(self):
+        if parse_tree is _parse_tree_py:
+            raise TestSkipped('parse_tree extension not found')
+        self._do_test_parse_tree(parse_tree)
+
+    def test_check(self):
+        t = Tree
+        sha = hex_to_sha(a_sha)
+
+        # filenames
+        self.assertCheckSucceeds(t, '100644 .a\0%s' % sha)
+        self.assertCheckFails(t, '100644 \0%s' % sha)
+        self.assertCheckFails(t, '100644 .\0%s' % sha)
+        self.assertCheckFails(t, '100644 a/a\0%s' % sha)
+        self.assertCheckFails(t, '100644 ..\0%s' % sha)
+
+        # modes
+        self.assertCheckSucceeds(t, '100644 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '100755 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '160000 a\0%s' % sha)
+        # TODO more whitelisted modes
+        self.assertCheckFails(t, '123456 a\0%s' % sha)
+        self.assertCheckFails(t, '123abc a\0%s' % sha)
+
+        # shas
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 5))
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 18 + '\0'))
+        self.assertCheckFails(t, '100644 a\0%s\n100644 b\0%s' % ('x' * 21, sha))
+
+        # ordering
+        sha2 = hex_to_sha(b_sha)
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha))
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 a\0%s\n100755 a\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 b\0%s\n100644 a\0%s' % (sha2, sha))
+
+    def test_iter(self):
+        t = Tree()
+        t["foo"] = (0100644, a_sha)
+        self.assertEquals(set(["foo"]), set(t))
+
 
 class TagSerializeTests(unittest.TestCase):
 
     def test_serialize_simple(self):
-        x = Tag()
-        x.tagger = "Jelmer Vernooij <jelmer@samba.org>"
-        x.name = "0.1"
-        x.message = "Tag 0.1"
-        x.object = (3, "d80c186a03f423a81b39df39dc87fd269736ca86")
-        x.tag_time = 423423423
-        x.tag_timezone = 0
-        self.assertEquals("""object d80c186a03f423a81b39df39dc87fd269736ca86
-type blob
-tag 0.1
-tagger Jelmer Vernooij <jelmer@samba.org> 423423423 +0000
-
-Tag 0.1""", x.as_raw_string())
-
-
-class TagParseTests(unittest.TestCase):
-
-    def test_parse_ctime(self):
-        x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-tagger Linus Torvalds <torvalds@woody.linux-foundation.org> Sun Jul 1 12:54:34 2007 -0700
-
-Linux 2.6.22-rc7
+        x = make_object(Tag,
+                        tagger='Jelmer Vernooij <jelmer@samba.org>',
+                        name='0.1',
+                        message='Tag 0.1',
+                        object=(Blob, 'd80c186a03f423a81b39df39dc87fd269736ca86'),
+                        tag_time=423423423,
+                        tag_timezone=0)
+        self.assertEquals(('object d80c186a03f423a81b39df39dc87fd269736ca86\n'
+                           'type blob\n'
+                           'tag 0.1\n'
+                           'tagger Jelmer Vernooij <jelmer@samba.org> '
+                           '423423423 +0000\n'
+                           '\n'
+                           'Tag 0.1'), x.as_raw_string())
+
+
+default_tagger = ('Linus Torvalds <torvalds@woody.linux-foundation.org> '
+                  '1183319674 -0700')
+default_message = """Linux 2.6.22-rc7
 -----BEGIN PGP SIGNATURE-----
 Version: GnuPG v1.4.7 (GNU/Linux)
 
@@ -290,39 +507,136 @@ iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
 OK2XeQOiEeXtT76rV4t2WR4=
 =ivrA
 -----END PGP SIGNATURE-----
-""")
-        self.assertEquals("Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
+"""
+
+
+class TagParseTests(ShaFileCheckTests):
+    def make_tag_lines(self,
+                       object_sha="a38d6181ff27824c79fc7df825164a212eff6a3f",
+                       object_type_name="commit",
+                       name="v2.6.22-rc7",
+                       tagger=default_tagger,
+                       message=default_message):
+        lines = []
+        if object_sha is not None:
+            lines.append("object %s" % object_sha)
+        if object_type_name is not None:
+            lines.append("type %s" % object_type_name)
+        if name is not None:
+            lines.append("tag %s" % name)
+        if tagger is not None:
+            lines.append("tagger %s" % tagger)
+        lines.append("")
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_tag_text(self, **kwargs):
+        return "\n".join(self.make_tag_lines(**kwargs))
+
+    def test_parse(self):
+        x = Tag()
+        x.set_raw_string(self.make_tag_text())
+        self.assertEquals(
+            "Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
+        object_type, object_sha = x.object
+        self.assertEquals("a38d6181ff27824c79fc7df825164a212eff6a3f",
+                          object_sha)
+        self.assertEquals(Commit, object_type)
+        self.assertEquals(datetime.datetime.utcfromtimestamp(x.tag_time),
+                          datetime.datetime(2007, 7, 1, 19, 54, 34))
+        self.assertEquals(-25200, x.tag_timezone)
 
     def test_parse_no_tagger(self):
         x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-
-Linux 2.6.22-rc7
------BEGIN PGP SIGNATURE-----
-Version: GnuPG v1.4.7 (GNU/Linux)
-
-iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
-OK2XeQOiEeXtT76rV4t2WR4=
-=ivrA
------END PGP SIGNATURE-----
-""")
+        x.set_raw_string(self.make_tag_text(tagger=None))
         self.assertEquals(None, x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
 
+    def test_check(self):
+        self.assertCheckSucceeds(Tag, self.make_tag_text())
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha=None))
+        self.assertCheckFails(Tag, self.make_tag_text(object_type_name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=''))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          object_type_name="foobar"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger="some guy without an email address 1183319674 -0700"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger=("Linus Torvalds <torvalds@woody.linux-foundation.org> "
+                  "Sun 7 Jul 2007 12:54:34 +0700")))
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha="xxx"))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(4):
+            lines = self.make_tag_lines()
+            lines.insert(i, lines[i])
+            self.assertCheckFails(Tag, '\n'.join(lines))
+
+    def test_check_order(self):
+        lines = self.make_tag_lines()
+        headers = lines[:4]
+        rest = lines[4:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Tag, text)
+            else:
+                self.assertCheckFails(Tag, text)
+
+
+class CheckTests(unittest.TestCase):
+
+    def test_check_hexsha(self):
+        check_hexsha(a_sha, "failed to check good sha")
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 39,
+                          'sha too short')
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 41,
+                          'sha too long')
+        self.assertRaises(ObjectFormatException, check_hexsha, 'x' * 40,
+                          'invalid characters')
+
+    def test_check_identity(self):
+        check_identity("Dave Borowitz <dborowitz@google.com>",
+                       "failed to check good identity")
+        check_identity("<dborowitz@google.com>",
+                       "failed to check good identity")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz", "no email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "dborowitz@google.com>", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <<dborowitz@google.com>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>xxx",
+                          "trailing characters")
+
 
 class TimezoneTests(unittest.TestCase):
 
     def test_parse_timezone_utc(self):
-        self.assertEquals(0, parse_timezone("+0000"))
+        self.assertEquals((0, False), parse_timezone("+0000"))
+
+    def test_parse_timezone_utc_negative(self):
+        self.assertEquals((0, True), parse_timezone("-0000"))
 
     def test_generate_timezone_utc(self):
         self.assertEquals("+0000", format_timezone(0))
 
+    def test_generate_timezone_utc_negative(self):
+        self.assertEquals("-0000", format_timezone(0, True))
+
     def test_parse_timezone_cet(self):
-        self.assertEquals(60 * 60, parse_timezone("+0100"))
+        self.assertEquals((60 * 60, False), parse_timezone("+0100"))
 
     def test_format_timezone_cet(self):
         self.assertEquals("+0100", format_timezone(60 * 60))
@@ -331,10 +645,12 @@ class TimezoneTests(unittest.TestCase):
         self.assertEquals("-0400", format_timezone(-4 * 60 * 60))
 
     def test_parse_timezone_pdt(self):
-        self.assertEquals(-4 * 60 * 60, parse_timezone("-0400"))
+        self.assertEquals((-4 * 60 * 60, False), parse_timezone("-0400"))
 
     def test_format_timezone_pdt_half(self):
-        self.assertEquals("-0440", format_timezone(int(((-4 * 60) - 40) * 60)))
+        self.assertEquals("-0440",
+            format_timezone(int(((-4 * 60) - 40) * 60)))
 
     def test_parse_timezone_pdt_half(self):
-        self.assertEquals(((-4 * 60) - 40) * 60, parse_timezone("-0440"))
+        self.assertEquals((((-4 * 60) - 40) * 60, False),
+            parse_timezone("-0440"))

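Two API changes are visible in these tests: parse_timezone() now returns an (offset_seconds, negative_utc) pair instead of a bare offset, and objects grow a check() method that raises ObjectFormatException on malformed data. A quick sketch of both, with values taken from the tests above:

from dulwich.errors import ObjectFormatException
from dulwich.objects import Tree, format_timezone, hex_to_sha, parse_timezone

offset, negative_utc = parse_timezone("-0440")
assert offset == ((-4 * 60) - 40) * 60 and negative_utc is False
assert format_timezone(0, True) == "-0000"   # "-0000" now round-trips

t = Tree()
try:
    # A slash in a tree entry name is rejected, as in TreeTests.test_check.
    t.set_raw_string('100644 a/a\0' +
                     hex_to_sha('6f670c0fb53f9463760b7295fbb814e965fb20c8'))
    t.check()
except ObjectFormatException:
    print 'rejected malformed tree'
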
+ 212 - 86
dulwich/tests/test_pack.py

@@ -1,17 +1,17 @@
 # test_pack.py -- Tests for the handling of git packs.
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # of the License, or (at your option) any later version of the license.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -23,9 +23,17 @@
 
 from cStringIO import StringIO
 import os
+import shutil
+import tempfile
 import unittest
+import zlib
 
+from dulwich.errors import (
+    ChecksumMismatch,
+    )
 from dulwich.objects import (
+    hex_to_sha,
+    sha_to_hex,
     Tree,
     )
 from dulwich.pack import (
@@ -35,7 +43,7 @@ from dulwich.pack import (
     create_delta,
     load_pack_index,
     hex_to_sha,
-    read_zlib,
+    read_zlib_chunks,
     sha_to_hex,
     write_pack_index_v1,
     write_pack_index_v2,
@@ -48,26 +56,39 @@ a_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
 tree_sha = 'b2a2766a2879c209ab1176e7e778b81ae422eeaa'
 commit_sha = 'f18faa16531ac570a3fdc8c7ca16682548dafd12'
 
+
 class PackTests(unittest.TestCase):
     """Base class for testing packs"""
-  
+
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
     datadir = os.path.join(os.path.dirname(__file__), 'data/packs')
-  
+
     def get_pack_index(self, sha):
         """Returns a PackIndex from the datadir with the given sha"""
         return load_pack_index(os.path.join(self.datadir, 'pack-%s.idx' % sha))
-  
+
     def get_pack_data(self, sha):
         """Returns a PackData object from the datadir with the given sha"""
         return PackData(os.path.join(self.datadir, 'pack-%s.pack' % sha))
-  
+
     def get_pack(self, sha):
         return Pack(os.path.join(self.datadir, 'pack-%s' % sha))
 
+    def assertSucceeds(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except ChecksumMismatch, e:
+            self.fail(e)
+
 
 class PackIndexTests(PackTests):
     """Class that tests the index of packfiles"""
-  
+
     def test_object_index(self):
         """Tests that the correct object offset is returned from the index."""
         p = self.get_pack_index(pack1_sha)
@@ -75,88 +96,118 @@ class PackIndexTests(PackTests):
         self.assertEqual(p.object_index(a_sha), 178)
         self.assertEqual(p.object_index(tree_sha), 138)
         self.assertEqual(p.object_index(commit_sha), 12)
-  
+
     def test_index_len(self):
         p = self.get_pack_index(pack1_sha)
         self.assertEquals(3, len(p))
-  
+
     def test_get_stored_checksum(self):
         p = self.get_pack_index(pack1_sha)
-        self.assertEquals("\xf2\x84\x8e*\xd1o2\x9a\xe1\xc9.;\x95\xe9\x18\x88\xda\xa5\xbd\x01", str(p.get_stored_checksum()))
-        self.assertEquals( 'r\x19\x80\xe8f\xaf\x9a_\x93\xadgAD\xe1E\x9b\x8b\xa3\xe7\xb7' , str(p.get_pack_checksum()))
-  
+        self.assertEquals('f2848e2ad16f329ae1c92e3b95e91888daa5bd01',
+                          sha_to_hex(p.get_stored_checksum()))
+        self.assertEquals('721980e866af9a5f93ad674144e1459b8ba3e7b7',
+                          sha_to_hex(p.get_pack_checksum()))
+
     def test_index_check(self):
         p = self.get_pack_index(pack1_sha)
-        self.assertEquals(True, p.check())
-  
+        self.assertSucceeds(p.check)
+
     def test_iterentries(self):
         p = self.get_pack_index(pack1_sha)
-        self.assertEquals([('og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 178, None), ('\xb2\xa2vj(y\xc2\t\xab\x11v\xe7\xe7x\xb8\x1a\xe4"\xee\xaa', 138, None), ('\xf1\x8f\xaa\x16S\x1a\xc5p\xa3\xfd\xc8\xc7\xca\x16h%H\xda\xfd\x12', 12, None)], list(p.iterentries()))
-  
+        entries = [(sha_to_hex(s), o, c) for s, o, c in p.iterentries()]
+        self.assertEquals([
+          ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, None),
+          ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, None),
+          ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, None)
+          ], entries)
+
     def test_iter(self):
         p = self.get_pack_index(pack1_sha)
         self.assertEquals(set([tree_sha, commit_sha, a_sha]), set(p))
-  
+
 
 class TestPackDeltas(unittest.TestCase):
-  
-    test_string1 = "The answer was flailing in the wind"
-    test_string2 = "The answer was falling down the pipe"
-    test_string3 = "zzzzz"
-  
-    test_string_empty = ""
-    test_string_big = "Z" * 8192
-  
+
+    test_string1 = 'The answer was flailing in the wind'
+    test_string2 = 'The answer was falling down the pipe'
+    test_string3 = 'zzzzz'
+
+    test_string_empty = ''
+    test_string_big = 'Z' * 8192
+
     def _test_roundtrip(self, base, target):
         self.assertEquals(target,
-            apply_delta(base, create_delta(base, target)))
-  
+                          ''.join(apply_delta(base, create_delta(base, target))))
+
     def test_nochange(self):
         self._test_roundtrip(self.test_string1, self.test_string1)
-  
+
     def test_change(self):
         self._test_roundtrip(self.test_string1, self.test_string2)
-  
+
     def test_rewrite(self):
         self._test_roundtrip(self.test_string1, self.test_string3)
-  
+
     def test_overflow(self):
         self._test_roundtrip(self.test_string_empty, self.test_string_big)
 
 
 class TestPackData(PackTests):
     """Tests getting the data from the packfile."""
-  
+
     def test_create_pack(self):
         p = self.get_pack_data(pack1_sha)
-  
+
     def test_pack_len(self):
         p = self.get_pack_data(pack1_sha)
         self.assertEquals(3, len(p))
-  
+
     def test_index_check(self):
         p = self.get_pack_data(pack1_sha)
-        self.assertEquals(True, p.check())
-  
+        self.assertSucceeds(p.check)
+
     def test_iterobjects(self):
         p = self.get_pack_data(pack1_sha)
-        self.assertEquals([(12, 1, 'tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\nauthor James Westby <jw+debian@jameswestby.net> 1174945067 +0100\ncommitter James Westby <jw+debian@jameswestby.net> 1174945067 +0100\n\nTest commit\n', 3775879613L), (138, 2, '100644 a\x00og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 912998690L), (178, 3, 'test 1\n', 1373561701L)], list(p.iterobjects()))
-  
+        commit_data = ('tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n'
+                       'author James Westby <jw+debian@jameswestby.net> '
+                       '1174945067 +0100\n'
+                       'committer James Westby <jw+debian@jameswestby.net> '
+                       '1174945067 +0100\n'
+                       '\n'
+                       'Test commit\n')
+        blob_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
+        tree_data = '100644 a\0%s' % hex_to_sha(blob_sha)
+        actual = []
+        for offset, type_num, chunks, crc32 in p.iterobjects():
+            actual.append((offset, type_num, ''.join(chunks), crc32))
+        self.assertEquals([
+          (12, 1, commit_data, 3775879613L),
+          (138, 2, tree_data, 912998690L),
+          (178, 3, 'test 1\n', 1373561701L)
+          ], actual)
+
     def test_iterentries(self):
         p = self.get_pack_data(pack1_sha)
-        self.assertEquals(set([('og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 178, 1373561701L), ('\xb2\xa2vj(y\xc2\t\xab\x11v\xe7\xe7x\xb8\x1a\xe4"\xee\xaa', 138, 912998690L), ('\xf1\x8f\xaa\x16S\x1a\xc5p\xa3\xfd\xc8\xc7\xca\x16h%H\xda\xfd\x12', 12, 3775879613L)]), set(p.iterentries()))
-  
+        entries = set((sha_to_hex(s), o, c) for s, o, c in p.iterentries())
+        self.assertEquals(set([
+          ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, 1373561701L),
+          ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, 912998690L),
+          ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, 3775879613L),
+          ]), entries)
+
     def test_create_index_v1(self):
         p = self.get_pack_data(pack1_sha)
-        p.create_index_v1("v1test.idx")
-        idx1 = load_pack_index("v1test.idx")
+        filename = os.path.join(self.tempdir, 'v1test.idx')
+        p.create_index_v1(filename)
+        idx1 = load_pack_index(filename)
         idx2 = self.get_pack_index(pack1_sha)
         self.assertEquals(idx1, idx2)
-  
+
     def test_create_index_v2(self):
         p = self.get_pack_data(pack1_sha)
-        p.create_index_v2("v2test.idx")
-        idx1 = load_pack_index("v2test.idx")
+        filename = os.path.join(self.tempdir, 'v2test.idx')
+        p.create_index_v2(filename)
+        idx1 = load_pack_index(filename)
         idx2 = self.get_pack_index(pack1_sha)
         self.assertEquals(idx1, idx2)
 
@@ -183,36 +234,38 @@ class TestPack(PackTests):
         """Tests random access for non-delta objects"""
         p = self.get_pack(pack1_sha)
         obj = p[a_sha]
-        self.assertEqual(obj._type, 'blob')
+        self.assertEqual(obj.type_name, 'blob')
         self.assertEqual(obj.sha().hexdigest(), a_sha)
         obj = p[tree_sha]
-        self.assertEqual(obj._type, 'tree')
+        self.assertEqual(obj.type_name, 'tree')
         self.assertEqual(obj.sha().hexdigest(), tree_sha)
         obj = p[commit_sha]
-        self.assertEqual(obj._type, 'commit')
+        self.assertEqual(obj.type_name, 'commit')
         self.assertEqual(obj.sha().hexdigest(), commit_sha)
 
     def test_copy(self):
         origpack = self.get_pack(pack1_sha)
-        self.assertEquals(True, origpack.index.check())
-        write_pack("Elch", [(x, "") for x in origpack.iterobjects()], 
-            len(origpack))
-        newpack = Pack("Elch")
+        self.assertSucceeds(origpack.index.check)
+        basename = os.path.join(self.tempdir, 'Elch')
+        write_pack(basename, [(x, '') for x in origpack.iterobjects()],
+                   len(origpack))
+        newpack = Pack(basename)
         self.assertEquals(origpack, newpack)
-        self.assertEquals(True, newpack.index.check())
+        self.assertSucceeds(newpack.index.check)
         self.assertEquals(origpack.name(), newpack.name())
-        self.assertEquals(origpack.index.get_pack_checksum(), 
+        self.assertEquals(origpack.index.get_pack_checksum(),
                           newpack.index.get_pack_checksum())
-        
-        self.assertTrue(
-                (origpack.index.version != newpack.index.version) or
-                (origpack.index.get_stored_checksum() == newpack.index.get_stored_checksum()))
+
+        wrong_version = origpack.index.version != newpack.index.version
+        orig_checksum = origpack.index.get_stored_checksum()
+        new_checksum = newpack.index.get_stored_checksum()
+        self.assertTrue(wrong_version or orig_checksum == new_checksum)
 
     def test_commit_obj(self):
         p = self.get_pack(pack1_sha)
         commit = p[commit_sha]
-        self.assertEquals("James Westby <jw+debian@jameswestby.net>",
-            commit.author)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
+                          commit.author)
         self.assertEquals([], commit.parents)
 
     def test_name(self):
@@ -220,69 +273,142 @@ class TestPack(PackTests):
         self.assertEquals(pack1_sha, p.name())
 
 
-class TestHexToSha(unittest.TestCase):
+pack_checksum = hex_to_sha('721980e866af9a5f93ad674144e1459b8ba3e7b7')
 
-    def test_simple(self):
-        self.assertEquals('\xab\xcd' * 10, hex_to_sha("abcd" * 10))
 
-    def test_reverse(self):
-        self.assertEquals("abcd" * 10, sha_to_hex('\xab\xcd' * 10))
+class BaseTestPackIndexWriting(object):
 
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
 
-class BaseTestPackIndexWriting(object):
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
+    def assertSucceeds(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except ChecksumMismatch, e:
+            self.fail(e)
 
     def test_empty(self):
-        pack_checksum = 'r\x19\x80\xe8f\xaf\x9a_\x93\xadgAD\xe1E\x9b\x8b\xa3\xe7\xb7'
-        self._write_fn("empty.idx", [], pack_checksum)
-        idx = load_pack_index("empty.idx")
-        self.assertTrue(idx.check())
+        filename = os.path.join(self.tempdir, 'empty.idx')
+        self._write_fn(filename, [], pack_checksum)
+        idx = load_pack_index(filename)
+        self.assertSucceeds(idx.check)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(0, len(idx))
 
     def test_single(self):
-        pack_checksum = 'r\x19\x80\xe8f\xaf\x9a_\x93\xadgAD\xe1E\x9b\x8b\xa3\xe7\xb7'
-        my_entries = [('og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 178, 42)]
-        my_entries.sort()
-        self._write_fn("single.idx", my_entries, pack_checksum)
-        idx = load_pack_index("single.idx")
+        entry_sha = hex_to_sha('6f670c0fb53f9463760b7295fbb814e965fb20c8')
+        my_entries = [(entry_sha, 178, 42)]
+        filename = os.path.join(self.tempdir, 'single.idx')
+        self._write_fn(filename, my_entries, pack_checksum)
+        idx = load_pack_index(filename)
         self.assertEquals(idx.version, self._expected_version)
-        self.assertTrue(idx.check())
+        self.assertSucceeds(idx.check)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(1, len(idx))
         actual_entries = list(idx.iterentries())
         self.assertEquals(len(my_entries), len(actual_entries))
-        for a, b in zip(my_entries, actual_entries):
-            self.assertEquals(a[0], b[0])
-            self.assertEquals(a[1], b[1])
+        for mine, actual in zip(my_entries, actual_entries):
+            my_sha, my_offset, my_crc = mine
+            actual_sha, actual_offset, actual_crc = actual
+            self.assertEquals(my_sha, actual_sha)
+            self.assertEquals(my_offset, actual_offset)
             if self._has_crc32_checksum:
-                self.assertEquals(a[2], b[2])
+                self.assertEquals(my_crc, actual_crc)
             else:
-                self.assertTrue(b[2] is None)
+                self.assertTrue(actual_crc is None)
 
 
 class TestPackIndexWritingv1(unittest.TestCase, BaseTestPackIndexWriting):
 
     def setUp(self):
         unittest.TestCase.setUp(self)
+        BaseTestPackIndexWriting.setUp(self)
         self._has_crc32_checksum = False
         self._expected_version = 1
         self._write_fn = write_pack_index_v1
 
+    def tearDown(self):
+        unittest.TestCase.tearDown(self)
+        BaseTestPackIndexWriting.tearDown(self)
+
 
 class TestPackIndexWritingv2(unittest.TestCase, BaseTestPackIndexWriting):
 
     def setUp(self):
         unittest.TestCase.setUp(self)
+        BaseTestPackIndexWriting.setUp(self)
         self._has_crc32_checksum = True
         self._expected_version = 2
         self._write_fn = write_pack_index_v2
 
-TEST_COMP1 = """\x78\x9c\x9d\x8e\xc1\x0a\xc2\x30\x10\x44\xef\xf9\x8a\xbd\xa9\x08\x92\x86\xb4\x26\x20\xe2\xd9\x83\x78\xf2\xbe\x49\x37\xb5\xa5\x69\xca\x36\xf5\xfb\x4d\xfd\x04\x67\x6e\x33\xcc\xf0\x32\x13\x81\xc6\x16\x8d\xa9\xbd\xad\x6c\xe3\x8a\x03\x4a\x73\xd6\xda\xd5\xa6\x51\x2e\x58\x65\x6c\x13\xbc\x94\x4a\xcc\xc8\x34\x65\x78\xa4\x89\x04\xae\xf9\x9d\x18\xee\x34\x46\x62\x78\x11\x4f\x29\xf5\x03\x5c\x86\x5f\x70\x5b\x30\x3a\x3c\x25\xee\xae\x50\xa9\xf2\x60\xa4\xaa\x34\x1c\x65\x91\xf0\x29\xc6\x3e\x67\xfa\x6f\x2d\x9e\x9c\x3e\x7d\x4b\xc0\x34\x8f\xe8\x29\x6e\x48\xa1\xa0\xc4\x88\xf3\xfe\xb0\x5b\x20\x85\xb0\x50\x06\xe4\x6e\xdd\xca\xd3\x17\x26\xfa\x49\x23"""
+    def tearDown(self):
+        unittest.TestCase.tearDown(self)
+        BaseTestPackIndexWriting.tearDown(self)
+
 
+class ReadZlibTests(unittest.TestCase):
 
-class ZlibTests(unittest.TestCase):
+    decomp = (
+      'tree 4ada885c9196b6b6fa08744b5862bf92896fc002\n'
+      'parent None\n'
+      'author Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n'
+      'committer Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n'
+      '\n'
+      "Provide replacement for mmap()'s offset argument.")
+    comp = zlib.compress(decomp)
+    extra = 'nextobject'
+
+    def setUp(self):
+        self.read = StringIO(self.comp + self.extra).read
+
+    def test_decompress_size(self):
+        good_decomp_len = len(self.decomp)
+        self.assertRaises(ValueError, read_zlib_chunks, self.read, -1)
+        self.assertRaises(zlib.error, read_zlib_chunks, self.read,
+                          good_decomp_len - 1)
+        self.assertRaises(zlib.error, read_zlib_chunks, self.read,
+                          good_decomp_len + 1)
+
+    def test_decompress_truncated(self):
+        read = StringIO(self.comp[:10]).read
+        self.assertRaises(zlib.error, read_zlib_chunks, read, len(self.decomp))
+
+        read = StringIO(self.comp).read
+        self.assertRaises(zlib.error, read_zlib_chunks, read, len(self.decomp))
+
+    def test_decompress_empty(self):
+        comp = zlib.compress('')
+        read = StringIO(comp + self.extra).read
+        decomp, comp_len, unused_data = read_zlib_chunks(read, 0)
+        self.assertEqual('', ''.join(decomp))
+        self.assertEqual(len(comp), comp_len)
+        self.assertNotEquals('', unused_data)
+        self.assertEquals(self.extra, unused_data + read())
+
+    def _do_decompress_test(self, buffer_size):
+        decomp, comp_len, unused_data = read_zlib_chunks(
+          self.read, len(self.decomp), buffer_size=buffer_size)
+        self.assertEquals(self.decomp, ''.join(decomp))
+        self.assertEquals(len(self.comp), comp_len)
+        self.assertNotEquals('', unused_data)
+        self.assertEquals(self.extra, unused_data + self.read())
 
     def test_simple_decompress(self):
-        self.assertEquals(("tree 4ada885c9196b6b6fa08744b5862bf92896fc002\nparent None\nauthor Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\ncommitter Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n\nProvide replacement for mmap()'s offset argument.", 158, 'Z'), 
-        read_zlib(StringIO(TEST_COMP1).read, 229))
+        self._do_decompress_test(4096)
+
+    # These tiny buffer sizes are not meant to be realistic; they simulate
+    # buffer boundaries falling at various places within the compressed data.
+    def test_decompress_buffer_size_1(self):
+        self._do_decompress_test(1)
+
+    def test_decompress_buffer_size_2(self):
+        self._do_decompress_test(2)
+
+    def test_decompress_buffer_size_3(self):
+        self._do_decompress_test(3)
 
+    def test_decompress_buffer_size_4(self):
+        self._do_decompress_test(4)
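
The ReadZlibTests above pin down the read_zlib_chunks() interface as these tests use it: a read callable, the expected decompressed length, and an optional buffer_size, returning the decompressed chunks, the number of compressed bytes consumed, and whatever was read past the end of the zlib stream. A minimal sketch of that usage, inferred only from these tests (the payload string and trailing data are illustrative):

    import zlib
    from StringIO import StringIO

    from dulwich.pack import read_zlib_chunks

    payload = 'example object contents'        # caller must already know its length
    stream = StringIO(zlib.compress(payload) + 'trailing pack data')

    chunks, comp_len, unused = read_zlib_chunks(stream.read, len(payload))
    assert ''.join(chunks) == payload          # decompressed data arrives in chunks
    # comp_len counts the compressed bytes; 'unused' holds the over-read tail, so
    # unused + stream.read() reconstructs the trailing data, as the tests check.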

+ 88 - 0
dulwich/tests/test_patch.py

@@ -0,0 +1,88 @@
+# test_patch.py -- tests for patch.py
+# Copyright (C) 2010 Jelmer Vernooij <jelmer@samba.org>
+# 
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) a later version.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Tests for patch.py."""
+
+from cStringIO import StringIO
+from unittest import TestCase
+
+from dulwich.objects import (
+    Commit,
+    Tree,
+    )
+from dulwich.patch import (
+    git_am_patch_split,
+    write_commit_patch,
+    )
+
+
+class WriteCommitPatchTests(TestCase):
+
+    def test_simple(self):
+        f = StringIO()
+        c = Commit()
+        c.committer = c.author = "Jelmer <jelmer@samba.org>"
+        c.commit_time = c.author_time = 1271350201
+        c.commit_timezone = c.author_timezone = 0
+        c.message = "This is the first line\nAnd this is the second line.\n"
+        c.tree = Tree().id
+        write_commit_patch(f, c, "CONTENTS", (1, 1), version="custom")
+        f.seek(0)
+        lines = f.readlines()
+        self.assertTrue(lines[0].startswith("From 0b0d34d1b5b596c928adc9a727a4b9e03d025298"))
+        self.assertEquals(lines[1], "From: Jelmer <jelmer@samba.org>\n")
+        self.assertTrue(lines[2].startswith("Date: "))
+        self.assertEquals([
+            "Subject: [PATCH 1/1] This is the first line\n",
+            "And this is the second line.\n",
+            "\n",
+            "\n",
+            "---\n"], lines[3:8])
+        self.assertEquals([
+            "CONTENTS-- \n",
+            "custom\n"], lines[-2:])
+        if len(lines) >= 12:
+            # diffstat may not be present
+            self.assertEquals(lines[8], " 0 files changed\n")
+
+
+class ReadGitAmPatch(TestCase):
+
+    def test_extract(self):
+        text = """From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
+From: Jelmer Vernooij <jelmer@samba.org>
+Date: Thu, 15 Apr 2010 15:40:28 +0200
+Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a lintian warning).
+
+---
+ pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+ 1 files changed, 0 insertions(+), 0 deletions(-)
+ mode change 100755 => 100644 pixmaps/prey.ico
+
+-- 
+1.7.0.4
+"""
+        c, diff, version = git_am_patch_split(StringIO(text))
+        self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.committer)
+        self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.author)
+        self.assertEquals(""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+ 1 files changed, 0 insertions(+), 0 deletions(-)
+ mode change 100755 => 100644 pixmaps/prey.ico
+
+""", diff)
+        self.assertEquals("1.7.0.4", version)
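
git_am_patch_split(), exercised above, splits a git-format-patch style message into a Commit stub, the raw diff text, and the producing git version string. A short sketch of running it over a patch file on disk (the filename is illustrative; only attributes the test relies on are used):

    from dulwich.patch import git_am_patch_split

    f = open('0001-example.patch', 'rb')
    try:
        commit, diff_text, version = git_am_patch_split(f)
    finally:
        f.close()

    # Both author and committer come from the From: header, per the test above.
    print commit.author, version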

+ 90 - 7
dulwich/tests/test_protocol.py

@@ -20,11 +20,12 @@
 """Tests for the smart protocol utility functions."""
 
 
-from cStringIO import StringIO
+from StringIO import StringIO
 from unittest import TestCase
 
 from dulwich.protocol import (
     Protocol,
+    ReceivableProtocol,
     extract_capabilities,
     extract_want_line_capabilities,
     ack_type,
@@ -33,12 +34,7 @@ from dulwich.protocol import (
     MULTI_ACK_DETAILED,
     )
 
-class ProtocolTests(TestCase):
-
-    def setUp(self):
-        self.rout = StringIO()
-        self.rin = StringIO()
-        self.proto = Protocol(self.rin.read, self.rout.write)
+class BaseProtocolTests(object):
 
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
@@ -82,6 +78,93 @@ class ProtocolTests(TestCase):
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
+class ProtocolTests(BaseProtocolTests, TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.rout = StringIO()
+        self.rin = StringIO()
+        self.proto = Protocol(self.rin.read, self.rout.write)
+
+
+class ReceivableStringIO(StringIO):
+    """StringIO with socket-like recv semantics for testing."""
+
+    def recv(self, size):
+        # fail fast if no bytes are available; in a real socket, this would
+        # block forever
+        if self.tell() == len(self.getvalue()):
+            raise AssertionError("Blocking read past end of socket")
+        if size == 1:
+            return self.read(1)
+        # calls shouldn't return quite as much as asked for
+        return self.read(size - 1)
+
+
+class ReceivableProtocolTests(BaseProtocolTests, TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.rout = StringIO()
+        self.rin = ReceivableStringIO()
+        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
+        self.proto._rbufsize = 8
+
+    def test_recv(self):
+        all_data = "1234567" * 10  # not a multiple of bufsize
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        data = ""
+        # We ask for 8 bytes each time and actually read 7, so it should take
+        # exactly 10 iterations.
+        for _ in xrange(10):
+            data += self.proto.recv(10)
+        # any more reads would block
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertEquals(all_data, data)
+
+    def test_recv_read(self):
+        all_data = "1234567"  # recv exactly in one call
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        self.assertEquals("1234", self.proto.recv(4))
+        self.assertEquals("567", self.proto.read(3))
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+
+    def test_read_recv(self):
+        all_data = "12345678abcdefg"
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        self.assertEquals("1234", self.proto.read(4))
+        self.assertEquals("5678abc", self.proto.recv(8))
+        self.assertEquals("defg", self.proto.read(4))
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+
+    def test_mixed(self):
+        # arbitrary non-repeating string
+        all_data = ",".join(str(i) for i in xrange(100))
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        data = ""
+
+        for i in xrange(1, 100):
+            data += self.proto.recv(i)
+            # if we get to the end, do a non-blocking read instead of blocking
+            if len(data) + i > len(all_data):
+                data += self.proto.recv(i)
+                # ReceivableStringIO leaves off the last byte unless we ask
+                # nicely
+                data += self.proto.recv(1)
+                break
+            else:
+                data += self.proto.read(i)
+        else:
+            # didn't break, something must have gone wrong
+            self.fail()
+
+        self.assertEquals(all_data, data)
+
+
 class CapabilitiesTestCase(TestCase):
 
     def test_plain(self):

+ 477 - 141
dulwich/tests/test_repository.py

@@ -1,17 +1,17 @@
 # test_repository.py -- tests for repository.py
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License or (at your option) any later version of 
+# of the License or (at your option) any later version of
 # the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -25,44 +25,30 @@ import os
 import shutil
 import tempfile
 import unittest
+import warnings
 
 from dulwich import errors
+from dulwich.object_store import (
+    tree_lookup_path,
+    )
+from dulwich import objects
 from dulwich.repo import (
     check_ref_format,
+    DictRefsContainer,
     Repo,
     read_packed_refs,
     read_packed_refs_with_peeled,
     write_packed_refs,
     _split_ref_line,
     )
+from dulwich.tests.utils import (
+    open_repo,
+    tear_down_repo,
+    )
 
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 
 
-def open_repo(name):
-    """Open a copy of a repo in a temporary directory.
-
-    Use this function for accessing repos in dulwich/tests/data/repos to avoid
-    accidentally or intentionally modifying those repos in place. Use
-    tear_down_repo to delete any temp files created.
-
-    :param name: The name of the repository, relative to
-        dulwich/tests/data/repos
-    :returns: An initialized Repo object that lives in a temporary directory.
-    """
-    temp_dir = tempfile.mkdtemp()
-    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
-    temp_repo_dir = os.path.join(temp_dir, name)
-    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
-    return Repo(temp_repo_dir)
-
-def tear_down_repo(repo):
-    """Tear down a test repository."""
-    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
-    shutil.rmtree(temp_dir)
-
-
-
 class CreateRepositoryTests(unittest.TestCase):
 
     def test_create(self):
@@ -86,73 +72,158 @@ class RepositoryTests(unittest.TestCase):
     def test_simple_props(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.controldir(), r.path)
-  
+
     def test_ref(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
-  
+
+    def test_setitem(self):
+        r = self._repo = open_repo('a.git')
+        r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
+        self.assertEquals('a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+                          r["refs/tags/foo"].id)
+
     def test_get_refs(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual({
-            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
-            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
+            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+            'refs/tags/mytag-packed': 'b0931cadc54336e78a1d980420e3268903b57a50',
             }, r.get_refs())
-  
+
     def test_head(self):
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
-  
+
     def test_get_object(self):
         r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
-        self.assertEqual(obj._type, 'commit')
-  
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_get_object_non_existant(self):
         r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
-  
+
+    def test_contains_object(self):
+        r = self._repo = open_repo('a.git')
+        self.assertTrue(r.head() in r)
+
+    def test_contains_ref(self):
+        r = self._repo = open_repo('a.git')
+        self.assertTrue("HEAD" in r)
+
+    def test_contains_missing(self):
+        r = self._repo = open_repo('a.git')
+        self.assertFalse("bar" in r)
+
     def test_commit(self):
         r = self._repo = open_repo('a.git')
-        obj = r.commit(r.head())
-        self.assertEqual(obj._type, 'commit')
-  
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            obj = r.commit(r.head())
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_commit_not_commit(self):
         r = self._repo = open_repo('a.git')
-        self.assertRaises(errors.NotCommitError,
-                          r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
-  
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotCommitError,
+                r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
+        finally:
+            warnings.resetwarnings()
+
     def test_tree(self):
         r = self._repo = open_repo('a.git')
-        commit = r.commit(r.head())
-        tree = r.tree(commit.tree)
-        self.assertEqual(tree._type, 'tree')
+        commit = r[r.head()]
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            tree = r.tree(commit.tree)
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(tree.type_name, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
-  
+
     def test_tree_not_tree(self):
         r = self._repo = open_repo('a.git')
-        self.assertRaises(errors.NotTreeError, r.tree, r.head())
-  
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotTreeError, r.tree, r.head())
+        finally:
+            warnings.resetwarnings()
+
+    def test_tag(self):
+        r = self._repo = open_repo('a.git')
+        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            tag = r.tag(tag_sha)
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(tag.type_name, 'tag')
+        self.assertEqual(tag.sha().hexdigest(), tag_sha)
+        obj_class, obj_sha = tag.object
+        self.assertEqual(obj_class, objects.Commit)
+        self.assertEqual(obj_sha, r.head())
+
+    def test_tag_not_tag(self):
+        r = self._repo = open_repo('a.git')
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotTagError, r.tag, r.head())
+        finally:
+            warnings.resetwarnings()
+
+    def test_get_peeled(self):
+        # unpacked ref
+        r = self._repo = open_repo('a.git')
+        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
+        self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head())
+        self.assertEqual(r.get_peeled('refs/tags/mytag'), r.head())
+
+        # packed ref with cached peeled value
+        packed_tag_sha = 'b0931cadc54336e78a1d980420e3268903b57a50'
+        parent_sha = r[r.head()].parents[0]
+        self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha)
+        self.assertEqual(r.get_peeled('refs/tags/mytag-packed'), parent_sha)
+
+        # TODO: add more corner cases to test repo
+
+    def test_get_peeled_not_tag(self):
+        r = self._repo = open_repo('a.git')
+        self.assertEqual(r.get_peeled('HEAD'), r.head())
+
     def test_get_blob(self):
         r = self._repo = open_repo('a.git')
-        commit = r.commit(r.head())
-        tree = r.tree(commit.tree)
+        commit = r[r.head()]
+        tree = r[commit.tree]
         blob_sha = tree.entries()[0][2]
-        blob = r.get_blob(blob_sha)
-        self.assertEqual(blob._type, 'blob')
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            blob = r.get_blob(blob_sha)
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(blob.type_name, 'blob')
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
-  
+
     def test_get_blob_notblob(self):
         r = self._repo = open_repo('a.git')
-        self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
-    
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
+        finally:
+            warnings.resetwarnings()
+
     def test_linear_history(self):
         r = self._repo = open_repo('a.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
-  
+
     def test_merge_history(self):
         r = self._repo = open_repo('simple_merge.git')
         history = r.revision_history(r.head())
@@ -162,12 +233,12 @@ class RepositoryTests(unittest.TestCase):
                                 '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                 '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
-  
+
     def test_revision_history_missing_commit(self):
         r = self._repo = open_repo('simple_merge.git')
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
-  
+
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
         r = self._repo = open_repo('ooo_merge.git')
@@ -177,7 +248,7 @@ class RepositoryTests(unittest.TestCase):
                                 'f507291b64138b875c28e03469025b1ea20bc614',
                                 'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
-  
+
     def test_get_tags_empty(self):
         r = self._repo = open_repo('ooo_merge.git')
         self.assertEqual({}, r.refs.as_dict('refs/tags'))
@@ -186,6 +257,158 @@ class RepositoryTests(unittest.TestCase):
         r = self._repo = open_repo('ooo_merge.git')
         self.assertEquals({}, r.get_config())
 
+    def test_common_revisions(self):
+        """
+        This test demonstrates that ``find_common_revisions()`` actually returns
+        common heads, not revisions; dulwich already uses
+        ``find_common_revisions()`` in such a manner (see
+        ``Repo.fetch_objects()``).
+        """
+
+        expected_shas = set(['60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])
+
+        # Source for objects.
+        r_base = open_repo('simple_merge.git')
+
+        # Re-create each-side of the merge in simple_merge.git.
+        #
+        # Since the trees and blobs are missing, the repository created is
+        # corrupted, but we're only checking for commits for the purpose of this
+        # test, so it's immaterial.
+        r1_dir = tempfile.mkdtemp()
+        r1_commits = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD
+                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
+                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
+
+        r2_dir = tempfile.mkdtemp()
+        r2_commits = ['4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD
+                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
+                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
+
+        try:
+            r1 = Repo.init_bare(r1_dir)
+            map(lambda c: r1.object_store.add_object(r_base.get_object(c)), \
+                r1_commits)
+            r1.refs['HEAD'] = r1_commits[0]
+
+            r2 = Repo.init_bare(r2_dir)
+            map(lambda c: r2.object_store.add_object(r_base.get_object(c)), \
+                r2_commits)
+            r2.refs['HEAD'] = r2_commits[0]
+
+            # Finally, the 'real' testing!
+            shas = r2.object_store.find_common_revisions(r1.get_graph_walker())
+            self.assertEqual(set(shas), expected_shas)
+
+            shas = r1.object_store.find_common_revisions(r2.get_graph_walker())
+            self.assertEqual(set(shas), expected_shas)
+        finally:
+            shutil.rmtree(r1_dir)
+            shutil.rmtree(r2_dir)
+
+
+class BuildRepoTests(unittest.TestCase):
+    """Tests that build on-disk repos from scratch.
+
+    Repos live in a temp dir and are torn down after each test. They start with
+    a single commit in master containing a single file named 'a'.
+    """
+
+    def setUp(self):
+        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
+        os.makedirs(repo_dir)
+        r = self._repo = Repo.init(repo_dir)
+        self.assertFalse(r.bare)
+        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
+        self.assertRaises(KeyError, lambda: r.refs['refs/heads/master'])
+
+        f = open(os.path.join(r.path, 'a'), 'wb')
+        try:
+            f.write('file contents')
+        finally:
+            f.close()
+        r.stage(['a'])
+        commit_sha = r.do_commit('msg',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12345, commit_timezone=0,
+                                 author_timestamp=12345, author_timezone=0)
+        self.assertEqual([], r[commit_sha].parents)
+        self._root_commit = commit_sha
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_build_repo(self):
+        r = self._repo
+        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
+        self.assertEqual(self._root_commit, r.refs['refs/heads/master'])
+        expected_blob = objects.Blob.from_string('file contents')
+        self.assertEqual(expected_blob.data, r[expected_blob.id].data)
+        actual_commit = r[self._root_commit]
+        self.assertEqual('msg', actual_commit.message)
+
+    def test_commit_modified(self):
+        r = self._repo
+        f = open(os.path.join(r.path, 'a'), 'wb')
+        try:
+            f.write('new contents')
+        finally:
+            f.close()
+        r.stage(['a'])
+        commit_sha = r.do_commit('modified a',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12395, commit_timezone=0,
+                                 author_timestamp=12395, author_timezone=0)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
+        _, blob_id = tree_lookup_path(r.get_object, r[commit_sha].tree, 'a')
+        self.assertEqual('new contents', r[blob_id].data)
+
+    def test_commit_deleted(self):
+        r = self._repo
+        os.remove(os.path.join(r.path, 'a'))
+        r.stage(['a'])
+        commit_sha = r.do_commit('deleted a',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12395, commit_timezone=0,
+                                 author_timestamp=12395, author_timezone=0)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
+        self.assertEqual([], list(r.open_index()))
+        tree = r[r[commit_sha].tree]
+        self.assertEqual([], tree.iteritems())
+
+    def test_commit_fail_ref(self):
+        r = self._repo
+
+        def set_if_equals(name, old_ref, new_ref):
+            return False
+        r.refs.set_if_equals = set_if_equals
+
+        def add_if_new(name, new_ref):
+            self.fail('Unexpected call to add_if_new')
+        r.refs.add_if_new = add_if_new
+
+        old_shas = set(r.object_store)
+        self.assertRaises(errors.CommitError, r.do_commit, 'failed commit',
+                          committer='Test Committer <test@nodomain.com>',
+                          author='Test Author <test@nodomain.com>',
+                          commit_timestamp=12345, commit_timezone=0,
+                          author_timestamp=12345, author_timezone=0)
+        new_shas = set(r.object_store) - old_shas
+        self.assertEqual(1, len(new_shas))
+        # Check that the new commit (now garbage) was added.
+        new_commit = r[new_shas.pop()]
+        self.assertEqual(r[self._root_commit].tree, new_commit.tree)
+        self.assertEqual('failed commit', new_commit.message)
+
+    def test_stage_deleted(self):
+        r = self._repo
+        os.remove(os.path.join(r.path, 'a'))
+        r.stage(['a'])
+        r.stage(['a'])  # double-stage a deleted path
+
 
 class CheckRefFormatTests(unittest.TestCase):
     """Tests for the check_ref_format function.
@@ -239,12 +462,12 @@ class PackedRefsFileTests(unittest.TestCase):
 
     def test_read_with_peeled(self):
         f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
-            ONES, TWOS, THREES, FOURS))
+          ONES, TWOS, THREES, FOURS))
         self.assertEqual([
-            (ONES, 'ref/1', None),
-            (TWOS, 'ref/2', THREES),
-            (FOURS, 'ref/4', None),
-            ], list(read_packed_refs_with_peeled(f)))
+          (ONES, 'ref/1', None),
+          (TWOS, 'ref/2', THREES),
+          (FOURS, 'ref/4', None),
+          ], list(read_packed_refs_with_peeled(f)))
 
     def test_read_with_peeled_errors(self):
         f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
@@ -258,8 +481,8 @@ class PackedRefsFileTests(unittest.TestCase):
         write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
                           {'ref/1': THREES})
         self.assertEqual(
-            "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
-            ONES, THREES, TWOS), f.getvalue())
+          "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
+          ONES, THREES, TWOS), f.getvalue())
 
     def test_write_without_peeled(self):
         f = StringIO()
@@ -267,62 +490,39 @@ class PackedRefsFileTests(unittest.TestCase):
         self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
 
 
-class RefsContainerTests(unittest.TestCase):
+# Dict of refs that we expect all RefsContainerTests subclasses to define.
+_TEST_REFS = {
+  'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+  'refs/tags/refs-0.2': '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
+  }
 
-    def setUp(self):
-        self._repo = open_repo('refs.git')
-        self._refs = self._repo.refs
 
-    def tearDown(self):
-        tear_down_repo(self._repo)
-
-    def test_get_packed_refs(self):
-        self.assertEqual(
-            {'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c'},
-            self._refs.get_packed_refs())
+class RefsContainerTests(object):
 
     def test_keys(self):
-        self.assertEqual([
-            'HEAD',
-            'refs/heads/loop',
-            'refs/heads/master',
-            'refs/tags/refs-0.1',
-            ], sorted(list(self._refs.keys())))
-        self.assertEqual(['loop', 'master'],
-                         sorted(self._refs.keys('refs/heads')))
-        self.assertEqual(['refs-0.1'], list(self._refs.keys('refs/tags')))
+        actual_keys = set(self._refs.keys())
+        self.assertEqual(set(self._refs.allkeys()), actual_keys)
+        # ignore the symref loop if it exists
+        actual_keys.discard('refs/heads/loop')
+        self.assertEqual(set(_TEST_REFS.iterkeys()), actual_keys)
+
+        actual_keys = self._refs.keys('refs/heads')
+        actual_keys.discard('loop')
+        self.assertEqual(['master', 'packed'], sorted(actual_keys))
+        self.assertEqual(['refs-0.1', 'refs-0.2'],
+                         sorted(self._refs.keys('refs/tags')))
 
     def test_as_dict(self):
-        # refs/heads/loop does not show up
-        self.assertEqual({
-            'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
-            'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
-            'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
-            }, self._refs.as_dict())
+        # refs/heads/loop does not show up even if it exists
+        self.assertEqual(_TEST_REFS, self._refs.as_dict())
 
     def test_setitem(self):
         self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                          self._refs['refs/some/ref'])
-        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                          f.read()[:40])
-        f.close()
-
-    def test_setitem_symbolic(self):
-        ones = '1' * 40
-        self._refs['HEAD'] = ones
-        self.assertEqual(ones, self._refs['HEAD'])
-
-        # ensure HEAD was not modified
-        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
-        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
-        f.close()
-
-        # ensure the symbolic link was written through
-        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
-        self.assertEqual(ones, f.read()[:40])
-        f.close()
 
     def test_set_if_equals(self):
         nines = '9' * 40
@@ -331,17 +531,13 @@ class RefsContainerTests(unittest.TestCase):
                          self._refs['HEAD'])
 
         self.assertTrue(self._refs.set_if_equals(
-            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
+          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
         self.assertEqual(nines, self._refs['HEAD'])
 
-        # ensure symref was followed
+        self.assertTrue(self._refs.set_if_equals('refs/heads/master', None,
+                                                 nines))
         self.assertEqual(nines, self._refs['refs/heads/master'])
 
-        self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
-        self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'HEAD.lock')))
-
     def test_add_if_new(self):
         nines = '9' * 40
         self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
@@ -351,10 +547,23 @@ class RefsContainerTests(unittest.TestCase):
         self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
         self.assertEqual(nines, self._refs['refs/some/ref'])
 
-        # don't overwrite packed ref
-        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
-        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
-                         self._refs['refs/tags/refs-0.1'])
+    def test_set_symbolic_ref(self):
+        self._refs.set_symbolic_ref('refs/heads/symbolic', 'refs/heads/master')
+        self.assertEqual('ref: refs/heads/master',
+                         self._refs.read_loose_ref('refs/heads/symbolic'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/symbolic'])
+
+    def test_set_symbolic_ref_overwrite(self):
+        nines = '9' * 40
+        self.assertFalse('refs/heads/symbolic' in self._refs)
+        self._refs['refs/heads/symbolic'] = nines
+        self.assertEqual(nines, self._refs.read_loose_ref('refs/heads/symbolic'))
+        self._refs.set_symbolic_ref('refs/heads/symbolic', 'refs/heads/master')
+        self.assertEqual('ref: refs/heads/master',
+                         self._refs.read_loose_ref('refs/heads/symbolic'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/symbolic'])
 
     def test_check_refname(self):
         try:
@@ -370,21 +579,131 @@ class RefsContainerTests(unittest.TestCase):
         self.assertRaises(KeyError, self._refs._check_refname, 'refs')
         self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
 
+    def test_contains(self):
+        self.assertTrue('refs/heads/master' in self._refs)
+        self.assertFalse('refs/heads/bar' in self._refs)
+
+    def test_delitem(self):
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          self._refs['refs/heads/master'])
+        del self._refs['refs/heads/master']
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+
+    def test_remove_if_equals(self):
+        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+        self.assertTrue(self._refs.remove_if_equals(
+          'refs/tags/refs-0.2', '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
+        self.assertFalse('refs/tags/refs-0.2' in self._refs)
+
+
+class DictRefsContainerTests(RefsContainerTests, unittest.TestCase):
+
+    def setUp(self):
+        self._refs = DictRefsContainer(dict(_TEST_REFS))
+
+
+class DiskRefsContainerTests(RefsContainerTests, unittest.TestCase):
+
+    def setUp(self):
+        self._repo = open_repo('refs.git')
+        self._refs = self._repo.refs
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_get_packed_refs(self):
+        self.assertEqual({
+          'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+          'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+          }, self._refs.get_packed_refs())
+
+    def test_get_peeled_not_packed(self):
+        # not packed
+        self.assertEqual(None, self._refs.get_peeled('refs/tags/refs-0.2'))
+        self.assertEqual('3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
+                         self._refs['refs/tags/refs-0.2'])
+
+        # packed, known not peelable
+        self.assertEqual(self._refs['refs/heads/packed'],
+                         self._refs.get_peeled('refs/heads/packed'))
+
+        # packed, peeled
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs.get_peeled('refs/tags/refs-0.1'))
+
+    def test_setitem(self):
+        RefsContainerTests.test_setitem(self)
+        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          f.read()[:40])
+        f.close()
+
+    def test_setitem_symbolic(self):
+        ones = '1' * 40
+        self._refs['HEAD'] = ones
+        self.assertEqual(ones, self._refs['HEAD'])
+
+        # ensure HEAD was not modified
+        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
+        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
+        f.close()
+
+        # ensure the symbolic link was written through
+        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
+        self.assertEqual(ones, f.read()[:40])
+        f.close()
+
+    def test_set_if_equals(self):
+        RefsContainerTests.test_set_if_equals(self)
+
+        # ensure symref was followed
+        self.assertEqual('9' * 40, self._refs['refs/heads/master'])
+
+        # ensure lockfile was deleted
+        self.assertFalse(os.path.exists(
+          os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+          os.path.join(self._refs.path, 'HEAD.lock')))
+
+    def test_add_if_new_packed(self):
+        # don't overwrite packed ref
+        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', '9' * 40))
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+
+    def test_add_if_new_symbolic(self):
+        # Use an empty repo instead of the default.
+        tear_down_repo(self._repo)
+        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
+        os.makedirs(repo_dir)
+        self._repo = Repo.init(repo_dir)
+        refs = self._repo.refs
+
+        nines = '9' * 40
+        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
+        self.assertFalse('refs/heads/master' in refs)
+        self.assertTrue(refs.add_if_new('HEAD', nines))
+        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
+        self.assertEqual(nines, refs['HEAD'])
+        self.assertEqual(nines, refs['refs/heads/master'])
+        self.assertFalse(refs.add_if_new('HEAD', '1' * 40))
+        self.assertEqual(nines, refs['HEAD'])
+        self.assertEqual(nines, refs['refs/heads/master'])
+
     def test_follow(self):
         self.assertEquals(
-            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
-            self._refs._follow('HEAD'))
+          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+          self._refs._follow('HEAD'))
         self.assertEquals(
-            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
-            self._refs._follow('refs/heads/master'))
+          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+          self._refs._follow('refs/heads/master'))
         self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
         self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
 
     def test_delitem(self):
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                          self._refs['refs/heads/master'])
-        del self._refs['refs/heads/master']
-        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+        RefsContainerTests.test_delitem(self)
         ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
         self.assertFalse(os.path.exists(ref_file))
         self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
@@ -398,17 +717,12 @@ class RefsContainerTests(unittest.TestCase):
                          self._refs['refs/heads/master'])
         self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
 
-    def test_remove_if_equals(self):
-        nines = '9' * 40
-        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                         self._refs['HEAD'])
-
+    def test_remove_if_equals_symref(self):
         # HEAD is a symref, so shouldn't equal its dereferenced value
         self.assertFalse(self._refs.remove_if_equals(
-            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
         self.assertTrue(self._refs.remove_if_equals(
-            'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+          'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
         self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
 
         # HEAD is now a broken symref
@@ -421,10 +735,32 @@ class RefsContainerTests(unittest.TestCase):
         self.assertFalse(os.path.exists(
             os.path.join(self._refs.path, 'HEAD.lock')))
 
+    def test_remove_packed_without_peeled(self):
+        refs_file = os.path.join(self._repo.path, 'packed-refs')
+        f = open(refs_file)
+        refs_data = f.read()
+        f.close()
+        f = open(refs_file, 'w')
+        f.write('\n'.join(l for l in refs_data.split('\n')
+                          if not l or l[0] not in '#^'))
+        f.close()
+        self._repo = Repo(self._repo.path)
+        refs = self._repo.refs
+        self.assertTrue(refs.remove_if_equals(
+          'refs/heads/packed', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+
+    def test_remove_if_equals_packed(self):
         # test removing ref that is only packed
         self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                          self._refs['refs/tags/refs-0.1'])
         self.assertTrue(
-            self._refs.remove_if_equals('refs/tags/refs-0.1',
-            'df6800012397fb85c56e7418dd4eb9405dee075c'))
+          self._refs.remove_if_equals('refs/tags/refs-0.1',
+          'df6800012397fb85c56e7418dd4eb9405dee075c'))
         self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])
+
+    def test_read_ref(self):
+        self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+            self._refs.read_ref("refs/heads/packed"))
+        self.assertEqual(None,
+            self._refs.read_ref("nonexistant"))
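
The BuildRepoTests above double as a walkthrough of the higher-level Repo API exercised in this release: Repo.init(), stage(), do_commit() and dict-style object lookup. A condensed sketch of the same flow outside the test harness (path, identity and timestamps are illustrative):

    import os
    import tempfile

    from dulwich.repo import Repo

    path = os.path.join(tempfile.mkdtemp(), 'demo')
    os.makedirs(path)
    repo = Repo.init(path)

    f = open(os.path.join(path, 'README'), 'wb')
    try:
        f.write('hello\n')
    finally:
        f.close()

    repo.stage(['README'])
    sha = repo.do_commit('initial commit',
                         committer='Example <example@example.com>',
                         author='Example <example@example.com>',
                         commit_timestamp=12345, commit_timezone=0,
                         author_timestamp=12345, author_timezone=0)

    # HEAD's branch now points at the new commit and the object is retrievable.
    assert repo.refs['refs/heads/master'] == sha
    assert repo[sha].message == 'initial commit'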

+ 199 - 78
dulwich/tests/test_server.py

@@ -20,32 +20,34 @@
 """Tests for the smart protocol server."""
 
 
-from cStringIO import StringIO
 from unittest import TestCase
 
 from dulwich.errors import (
     GitProtocolError,
     )
 from dulwich.server import (
-    UploadPackHandler,
-    ProtocolGraphWalker,
-    SingleAckGraphWalkerImpl,
+    Backend,
+    DictBackend,
+    BackendRepo,
+    Handler,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    ProtocolGraphWalker,
+    SingleAckGraphWalkerImpl,
+    UploadPackHandler,
     )
 
-from dulwich.protocol import (
-    SINGLE_ACK,
-    MULTI_ACK,
-    )
 
 ONE = '1' * 40
 TWO = '2' * 40
 THREE = '3' * 40
 FOUR = '4' * 40
 FIVE = '5' * 40
+SIX = '6' * 40
+
 
 class TestProto(object):
+
     def __init__(self):
         self._output = []
         self._received = {0: [], 1: [], 2: [], 3: []}
@@ -75,76 +77,158 @@ class TestProto(object):
             return None
 
 
-class UploadPackHandlerTestCase(TestCase):
+class HandlerTestCase(TestCase):
+
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._handler = Handler(Backend(), None)
+        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
+        self._handler.required_capabilities = lambda: ('cap2',)
 
-    def test_set_client_capabilities(self):
+    def assertSucceeds(self, func, *args, **kwargs):
         try:
-            self._handler.set_client_capabilities([])
-        except GitProtocolError:
-            self.fail()
+            func(*args, **kwargs)
+        except GitProtocolError, e:
+            self.fail(e)
 
-        try:
-            self._handler.set_client_capabilities([
-                'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
-        except GitProtocolError:
-            self.fail()
-
-    def test_set_client_capabilities_error(self):
-        self.assertRaises(GitProtocolError,
-                          self._handler.set_client_capabilities,
-                          ['weird_ack_level', 'ofs-delta'])
-        try:
-            self._handler.set_client_capabilities(['include-tag'])
-        except GitProtocolError:
-            self.fail()
+    def test_capability_line(self):
+        self.assertEquals('cap1 cap2 cap3', self._handler.capability_line())
+
+    def test_set_client_capabilities(self):
+        set_caps = self._handler.set_client_capabilities
+        self.assertSucceeds(set_caps, ['cap2'])
+        self.assertSucceeds(set_caps, ['cap1', 'cap2'])
+
+        # different order
+        self.assertSucceeds(set_caps, ['cap3', 'cap1', 'cap2'])
+
+        # error cases
+        self.assertRaises(GitProtocolError, set_caps, ['capxxx', 'cap2'])
+        self.assertRaises(GitProtocolError, set_caps, ['cap1', 'cap3'])
+
+        # ignore innocuous but unknown capabilities
+        self.assertRaises(GitProtocolError, set_caps, ['cap2', 'ignoreme'])
+        self.assertFalse('ignoreme' in self._handler.capabilities())
+        self._handler.innocuous_capabilities = lambda: ('ignoreme',)
+        self.assertSucceeds(set_caps, ['cap2', 'ignoreme'])
+
+    def test_has_capability(self):
+        self.assertRaises(GitProtocolError, self._handler.has_capability, 'cap')
+        caps = self._handler.capabilities()
+        self._handler.set_client_capabilities(caps)
+        for cap in caps:
+            self.assertTrue(self._handler.has_capability(cap))
+        self.assertFalse(self._handler.has_capability('capxxx'))
+
+
+class UploadPackHandlerTestCase(TestCase):
+
+    def setUp(self):
+        self._backend = DictBackend({"/": BackendRepo()})
+        self._handler = UploadPackHandler(self._backend,
+                ["/", "host=lolcathost"], None, None)
+        self._handler.proto = TestProto()
+
+    def test_progress(self):
+        caps = self._handler.required_capabilities()
+        self._handler.set_client_capabilities(caps)
+        self._handler.progress('first message')
+        self._handler.progress('second message')
+        self.assertEqual('first message',
+                         self._handler.proto.get_received_line(2))
+        self.assertEqual('second message',
+                         self._handler.proto.get_received_line(2))
+        self.assertEqual(None, self._handler.proto.get_received_line(2))
+
+    def test_no_progress(self):
+        caps = list(self._handler.required_capabilities()) + ['no-progress']
+        self._handler.set_client_capabilities(caps)
+        self._handler.progress('first message')
+        self._handler.progress('second message')
+        self.assertEqual(None, self._handler.proto.get_received_line(2))
+
+    def test_get_tagged(self):
+        refs = {
+            'refs/tags/tag1': ONE,
+            'refs/tags/tag2': TWO,
+            'refs/heads/master': FOUR,  # not a tag, no peeled value
+            }
+        peeled = {
+            'refs/tags/tag1': '1234',
+            'refs/tags/tag2': '5678',
+            }
+
+        class TestRepo(object):
+            def get_peeled(self, ref):
+                return peeled.get(ref, refs[ref])
+
+        caps = list(self._handler.required_capabilities()) + ['include-tag']
+        self._handler.set_client_capabilities(caps)
+        self.assertEquals({'1234': ONE, '5678': TWO},
+                          self._handler.get_tagged(refs, repo=TestRepo()))
+
+        # non-include-tag case
+        caps = self._handler.required_capabilities()
+        self._handler.set_client_capabilities(caps)
+        self.assertEquals({}, self._handler.get_tagged(refs, repo=TestRepo()))
 
 
 class TestCommit(object):
+
     def __init__(self, sha, parents, commit_time):
         self.id = sha
-        self._parents = parents
+        self.parents = parents
         self.commit_time = commit_time
-
-    def get_parents(self):
-        return self._parents
+        self.type_name = "commit"
 
     def __repr__(self):
         return '%s(%s)' % (self.__class__.__name__, self._sha)
 
 
+class TestRepo(object):
+    def __init__(self):
+        self.peeled = {}
+
+    def get_peeled(self, name):
+        return self.peeled[name]
+
+
 class TestBackend(object):
-    def __init__(self, objects):
+
+    def __init__(self, repo, objects):
+        self.repo = repo
         self.object_store = objects
 
 
-class TestHandler(object):
+class TestUploadPackHandler(Handler):
+
     def __init__(self, objects, proto):
-        self.backend = TestBackend(objects)
+        self.backend = TestBackend(TestRepo(), objects)
         self.proto = proto
         self.stateless_rpc = False
         self.advertise_refs = False
 
     def capabilities(self):
-        return 'multi_ack'
+        return ('multi_ack',)
 
 
 class ProtocolGraphWalkerTestCase(TestCase):
+
     def setUp(self):
         # Create the following commit tree:
         #   3---5
         #  /
         # 1---2---4
         self._objects = {
-            ONE: TestCommit(ONE, [], 111),
-            TWO: TestCommit(TWO, [ONE], 222),
-            THREE: TestCommit(THREE, [ONE], 333),
-            FOUR: TestCommit(FOUR, [TWO], 444),
-            FIVE: TestCommit(FIVE, [THREE], 555),
-            }
+          ONE: TestCommit(ONE, [], 111),
+          TWO: TestCommit(TWO, [ONE], 222),
+          THREE: TestCommit(THREE, [ONE], 333),
+          FOUR: TestCommit(FOUR, [TWO], 444),
+          FIVE: TestCommit(FIVE, [THREE], 555),
+          }
+
         self._walker = ProtocolGraphWalker(
-            TestHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()),
+            self._objects, None)
 
     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
@@ -173,13 +257,13 @@ class ProtocolGraphWalkerTestCase(TestCase):
 
     def test_read_proto_line(self):
         self._walker.proto.set_output([
-            'want %s' % ONE,
-            'want %s' % TWO,
-            'have %s' % THREE,
-            'foo %s' % FOUR,
-            'bar',
-            'done',
-            ])
+          'want %s' % ONE,
+          'want %s' % TWO,
+          'have %s' % THREE,
+          'foo %s' % FOUR,
+          'bar',
+          'done',
+          ])
         self.assertEquals(('want', ONE), self._walker.read_proto_line())
         self.assertEquals(('want', TWO), self._walker.read_proto_line())
         self.assertEquals(('have', THREE), self._walker.read_proto_line())
@@ -192,10 +276,11 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
 
         self._walker.proto.set_output([
-            'want %s multi_ack' % ONE,
-            'want %s' % TWO,
-            ])
+          'want %s multi_ack' % ONE,
+          'want %s' % TWO,
+          ])
         heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
+        self._walker.get_peeled = heads.get
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
 
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
@@ -210,10 +295,40 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
         self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
 
+    def test_determine_wants_advertisement(self):
+        self._walker.proto.set_output([])
+        # advertise branch tips plus tag
+        heads = {'ref4': FOUR, 'ref5': FIVE, 'tag6': SIX}
+        peeled = {'ref4': FOUR, 'ref5': FIVE, 'tag6': FIVE}
+        self._walker.get_peeled = peeled.get
+        self._walker.determine_wants(heads)
+        lines = []
+        while True:
+            line = self._walker.proto.get_received_line()
+            if line == 'None':
+                break
+            # strip capabilities list if present
+            if '\x00' in line:
+                line = line[:line.index('\x00')]
+            lines.append(line.rstrip())
+
+        self.assertEquals([
+          '%s ref4' % FOUR,
+          '%s ref5' % FIVE,
+          '%s tag6^{}' % FIVE,
+          '%s tag6' % SIX,
+          ], sorted(lines))
+
+        # ensure peeled tag was advertised immediately following tag
+        for i, line in enumerate(lines):
+            if line.endswith(' tag6'):
+                self.assertEquals('%s tag6^{}' % FIVE, lines[i+1])
+
     # TODO: test commit time cutoff
 
 
 class TestProtocolGraphWalker(object):
+
     def __init__(self):
         self.acks = []
         self.lines = []
@@ -241,14 +356,15 @@ class TestProtocolGraphWalker(object):
 
 class AckGraphWalkerImplTestCase(TestCase):
     """Base setup and asserts for AckGraphWalker tests."""
+
     def setUp(self):
         self._walker = TestProtocolGraphWalker()
         self._walker.lines = [
-            ('have', TWO),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self._impl = self.impl_cls(self._walker)
 
     def assertNoAck(self):
@@ -270,6 +386,7 @@ class AckGraphWalkerImplTestCase(TestCase):
 
 
 class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+
     impl_cls = SingleAckGraphWalkerImpl
 
     def test_single_ack(self):
@@ -335,7 +452,9 @@ class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
         self.assertNextEquals(None)
         self.assertNak()
 
+
 class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+
     impl_cls = MultiAckGraphWalkerImpl
 
     def test_multi_ack(self):
@@ -371,17 +490,17 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
 
     def test_multi_ack_flush(self):
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()
 
         self.assertNextEquals(ONE)
-        self.assertNak() # nak the flush-pkt
+        self.assertNak()  # nak the flush-pkt
 
         self._walker.done = True
         self._impl.ack(ONE)
@@ -407,7 +526,9 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
         self.assertNextEquals(None)
         self.assertNak()
 
+
 class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+
     impl_cls = MultiAckDetailedGraphWalkerImpl
 
     def test_multi_ack(self):
@@ -444,17 +565,17 @@ class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
     def test_multi_ack_flush(self):
         # same as ack test but contains a flush-pkt in the middle
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()
 
         self.assertNextEquals(ONE)
-        self.assertNak() # nak the flush-pkt
+        self.assertNak()  # nak the flush-pkt
 
         self._walker.done = True
         self._impl.ack(ONE)
@@ -483,12 +604,12 @@ class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
     def test_multi_ack_nak_flush(self):
         # same as nak test but contains a flush-pkt in the middle
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()
 

+ 69 - 55
dulwich/tests/test_web.py

@@ -23,8 +23,6 @@ import re
 from unittest import TestCase
 
 from dulwich.objects import (
-    type_map,
-    Tag,
     Blob,
     )
 from dulwich.web import (
@@ -42,9 +40,11 @@ from dulwich.web import (
 
 class WebTestCase(TestCase):
     """Base TestCase that sets up some useful instance vars."""
+
     def setUp(self):
         self._environ = {}
-        self._req = HTTPGitRequest(self._environ, self._start_response)
+        self._req = HTTPGitRequest(self._environ, self._start_response,
+                                   handlers=self._handlers())
         self._status = None
         self._headers = []
 
@@ -52,6 +52,9 @@ class WebTestCase(TestCase):
         self._status = status
         self._headers = list(headers)
 
+    def _handlers(self):
+        return None
+
 
 class DumbHandlersTestCase(WebTestCase):
 
@@ -97,15 +100,11 @@ class DumbHandlersTestCase(WebTestCase):
         self._environ['QUERY_STRING'] = ''
 
         class TestTag(object):
-            type = Tag().type
-
-            def __init__(self, sha, obj_type, obj_sha):
+            def __init__(self, sha, obj_class, obj_sha):
                 self.sha = lambda: sha
-                self.object = (obj_type, obj_sha)
+                self.object = (obj_class, obj_sha)
 
         class TestBlob(object):
-            type = Blob().type
-
             def __init__(self, sha):
                 self.sha = lambda: sha
 
@@ -113,13 +112,19 @@ class DumbHandlersTestCase(WebTestCase):
         blob2 = TestBlob('222')
         blob3 = TestBlob('333')
 
-        tag1 = TestTag('aaa', TestTag.type, 'bbb')
-        tag2 = TestTag('bbb', TestBlob.type, '222')
+        tag1 = TestTag('aaa', Blob, '222')
 
-        class TestBackend(object):
-            def __init__(self):
-                objects = [blob1, blob2, blob3, tag1, tag2]
-                self.repo = dict((o.sha(), o) for o in objects)
+        class TestRepo(object):
+
+            def __init__(self, objects, peeled):
+                self._objects = dict((o.sha(), o) for o in objects)
+                self._peeled = peeled
+
+            def get_peeled(self, sha):
+                return self._peeled[sha]
+
+            def __getitem__(self, sha):
+                return self._objects[sha]
 
             def get_refs(self):
                 return {
@@ -129,43 +134,55 @@ class DumbHandlersTestCase(WebTestCase):
                     'refs/tags/blob-tag': blob3.sha(),
                     }
 
+        class TestBackend(object):
+            def __init__(self):
+                objects = [blob1, blob2, blob3, tag1]
+                self.repo = TestRepo(objects, {
+                  'HEAD': '000',
+                  'refs/heads/master': blob1.sha(),
+                  'refs/tags/tag-tag': blob2.sha(),
+                  'refs/tags/blob-tag': blob3.sha(),
+                  })
+
+            def open_repository(self, path):
+                assert path == '/'
+                return self.repo
+
+            def get_refs(self):
+                return {
+                  'HEAD': '000',
+                  'refs/heads/master': blob1.sha(),
+                  'refs/tags/tag-tag': tag1.sha(),
+                  'refs/tags/blob-tag': blob3.sha(),
+                  }
+
+        mat = re.search('.*', '//info/refs')
         self.assertEquals(['111\trefs/heads/master\n',
                            '333\trefs/tags/blob-tag\n',
                            'aaa\trefs/tags/tag-tag\n',
                            '222\trefs/tags/tag-tag^{}\n'],
-                          list(get_info_refs(self._req, TestBackend(), None)))
+                          list(get_info_refs(self._req, TestBackend(), mat)))
 
 
 class SmartHandlersTestCase(WebTestCase):
 
-    class TestProtocol(object):
-        def __init__(self, handler):
-            self._handler = handler
-
-        def write_pkt_line(self, line):
-            if line is None:
-                self._handler.write('flush-pkt\n')
-            else:
-                self._handler.write('pkt-line: %s' % line)
-
     class _TestUploadPackHandler(object):
-        def __init__(self, backend, read, write, stateless_rpc=False,
+        def __init__(self, backend, args, proto, stateless_rpc=False,
                      advertise_refs=False):
-            self.read = read
-            self.write = write
-            self.proto = SmartHandlersTestCase.TestProtocol(self)
+            self.args = args
+            self.proto = proto
             self.stateless_rpc = stateless_rpc
             self.advertise_refs = advertise_refs
 
         def handle(self):
-            self.write('handled input: %s' % self.read())
+            self.proto.write('handled input: %s' % self.proto.recv(1024))
 
-    def _MakeHandler(self, *args, **kwargs):
+    def _make_handler(self, *args, **kwargs):
         self._handler = self._TestUploadPackHandler(*args, **kwargs)
         return self._handler
 
-    def services(self):
-        return {'git-upload-pack': self._MakeHandler}
+    def _handlers(self):
+        return {'git-upload-pack': self._make_handler}
 
     def test_handle_service_request_unknown(self):
         mat = re.search('.*', '/git-evil-handler')
@@ -175,8 +192,7 @@ class SmartHandlersTestCase(WebTestCase):
     def test_handle_service_request(self):
         self._environ['wsgi.input'] = StringIO('foo')
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat,
-                                                services=self.services()))
+        output = ''.join(handle_service_request(self._req, 'backend', mat))
         self.assertEqual('handled input: foo', output)
         response_type = 'application/x-git-upload-pack-response'
         self.assertTrue(('Content-Type', response_type) in self._headers)
@@ -187,26 +203,24 @@ class SmartHandlersTestCase(WebTestCase):
         self._environ['wsgi.input'] = StringIO('foobar')
         self._environ['CONTENT_LENGTH'] = 3
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat,
-                                                services=self.services()))
+        output = ''.join(handle_service_request(self._req, 'backend', mat))
         self.assertEqual('handled input: foo', output)
         response_type = 'application/x-git-upload-pack-response'
         self.assertTrue(('Content-Type', response_type) in self._headers)
 
     def test_get_info_refs_unknown(self):
         self._environ['QUERY_STRING'] = 'service=git-evil-handler'
-        list(get_info_refs(self._req, 'backend', None,
-                           services=self.services()))
+        list(get_info_refs(self._req, 'backend', None))
         self.assertEquals(HTTP_FORBIDDEN, self._status)
 
     def test_get_info_refs(self):
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'
 
-        output = ''.join(get_info_refs(self._req, 'backend', None,
-                                       services=self.services()))
-        self.assertEquals(('pkt-line: # service=git-upload-pack\n'
-                           'flush-pkt\n'
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(get_info_refs(self._req, 'backend', mat))
+        self.assertEquals(('001e# service=git-upload-pack\n'
+                           '0000'
                            # input is ignored by the handler
                            'handled input: '), output)
         self.assertTrue(self._handler.advertise_refs)
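
The '001e...' and '0000' literals expected above are plain pkt-line framing rather than anything handler-specific: each packet starts with a four-hex-digit length that counts the four length bytes themselves plus the payload, and '0000' is the flush-pkt. A quick check of that arithmetic:

    payload = '# service=git-upload-pack\n'        # 26 bytes
    pkt = '%04x%s' % (len(payload) + 4, payload)   # 26 + 4 = 30 = 0x1e
    assert pkt == '001e# service=git-upload-pack\n'
    flush_pkt = '0000'                             # length-only packet, no payload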
@@ -257,13 +271,13 @@ class HTTPGitRequestTestCase(WebTestCase):
         self._req.respond(status=402, content_type='some/type',
                           headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
         self.assertEquals(set([
-            ('X-Foo', 'foo'),
-            ('X-Bar', 'bar'),
-            ('Content-Type', 'some/type'),
-            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
-            ('Pragma', 'no-cache'),
-            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
-            ]), set(self._headers))
+          ('X-Foo', 'foo'),
+          ('X-Bar', 'bar'),
+          ('Content-Type', 'some/type'),
+          ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+          ('Pragma', 'no-cache'),
+          ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+          ]), set(self._headers))
         self.assertEquals(402, self._status)
 
 
@@ -280,10 +294,10 @@ class HTTPGitApplicationTestCase(TestCase):
             return 'output'
 
         self._app.services = {
-            ('GET', re.compile('/foo$')): test_handler,
+          ('GET', re.compile('/foo$')): test_handler,
         }
         environ = {
-            'PATH_INFO': '/foo',
-            'REQUEST_METHOD': 'GET',
-            }
+          'PATH_INFO': '/foo',
+          'REQUEST_METHOD': 'GET',
+          }
         self.assertEquals('output', self._app(environ, None))

+ 86 - 0
dulwich/tests/utils.py

@@ -0,0 +1,86 @@
+# utils.py -- Test utilities for Dulwich.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Utility functions common to Dulwich tests."""
+
+
+import datetime
+import os
+import shutil
+import tempfile
+import time
+
+from dulwich.objects import Commit
+from dulwich.repo import Repo
+
+
+def open_repo(name):
+    """Open a copy of a repo in a temporary directory.
+
+    Use this function for accessing repos in dulwich/tests/data/repos to avoid
+    accidentally or intentionally modifying those repos in place. Use
+    tear_down_repo to delete any temp files created.
+
+    :param name: The name of the repository, relative to
+        dulwich/tests/data/repos
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    temp_dir = tempfile.mkdtemp()
+    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
+    return Repo(temp_repo_dir)
+
+
+def tear_down_repo(repo):
+    """Tear down a test repository."""
+    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
+    shutil.rmtree(temp_dir)
+
+
+def make_object(cls, **attrs):
+    """Make an object for testing and assign some members.
+
+    :param attrs: dict of attributes to set on the new object.
+    :return: A newly initialized object of type cls.
+    """
+    obj = cls()
+    for name, value in attrs.iteritems():
+        setattr(obj, name, value)
+    return obj
+
+
+def make_commit(**attrs):
+    """Make a Commit object with a default set of members.
+
+    :param attrs: dict of attributes to overwrite from the default values.
+    :return: A newly initialized Commit object.
+    """
+    default_time = int(time.mktime(datetime.datetime(2010, 1, 1).timetuple()))
+    all_attrs = {'author': 'Test Author <test@nodomain.com>',
+                 'author_time': default_time,
+                 'author_timezone': 0,
+                 'committer': 'Test Committer <test@nodomain.com>',
+                 'commit_time': default_time,
+                 'commit_timezone': 0,
+                 'message': 'Test message.',
+                 'parents': [],
+                 'tree': '0' * 40}
+    all_attrs.update(attrs)
+    return make_object(Commit, **all_attrs)
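
A sketch of how these helpers fit together in a test; 'a.git' is one of the fixtures under dulwich/tests/data/repos, and the overridden message is just an illustrative value:

    from dulwich.tests.utils import make_commit, open_repo, tear_down_repo

    repo = open_repo('a.git')  # operates on a temporary copy, never the fixture
    try:
        commit = make_commit(message='An illustrative message.')
        assert commit.message == 'An illustrative message.'
        assert commit.parents == []  # unspecified attributes keep their defaults
    finally:
        tear_down_repo(repo)  # deletes the temporary copy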

+ 90 - 75
dulwich/web.py

@@ -19,30 +19,30 @@
 """HTTP server for dulwich that implements the git smart HTTP protocol."""
 
 from cStringIO import StringIO
-import cgi
-import os
 import re
 import time
 
-from dulwich.objects import (
-    Tag,
-    num_type_map,
-    )
-from dulwich.repo import (
-    Repo,
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from dulwich.misc import parse_qs
+from dulwich.protocol import (
+    ReceivableProtocol,
     )
 from dulwich.server import (
-    GitBackend,
     ReceivePackHandler,
     UploadPackHandler,
+    DEFAULT_HANDLERS,
     )
 
+
+# HTTP error strings
 HTTP_OK = '200 OK'
 HTTP_NOT_FOUND = '404 Not Found'
 HTTP_FORBIDDEN = '403 Forbidden'
 
 
-def date_time_string(self, timestamp=None):
+def date_time_string(timestamp=None):
     # Based on BaseHTTPServer.py in python2.5
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     months = [None,
@@ -55,6 +55,22 @@ def date_time_string(self, timestamp=None):
             weekdays[wd], day, months[month], year, hh, mm, ss)
 
 
+def url_prefix(mat):
+    """Extract the URL prefix from a regex match.
+
+    :param mat: A regex match object.
+    :returns: The URL prefix, defined as the text before the match in the
+        original string. Normalized to start with one leading slash and end with
+        zero.
+    """
+    return '/' + mat.string[:mat.start()].strip('/')
+
+
+def get_repo(backend, mat):
+    """Get a Repo instance for the given backend and URL regex match."""
+    return backend.open_repository(url_prefix(mat))
+
+
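
As a quick, hypothetical illustration of the prefix rule (the '/myrepo' path below is invented): for a GET of '/myrepo/info/refs', the '/info/refs$' pattern in HTTPGitApplication.services matches the tail of the path, so url_prefix() recovers '/myrepo' and get_repo() asks the backend for that repository.

    import re

    # Roughly the kind of match object the WSGI dispatcher hands to the handlers.
    mat = re.search('/info/refs$', '/myrepo/info/refs')
    prefix = '/' + mat.string[:mat.start()].strip('/')  # what url_prefix(mat) computes
    assert prefix == '/myrepo'
    # get_repo(backend, mat) would then call backend.open_repository('/myrepo').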
 def send_file(req, f, content_type):
     """Send a file-like object to the request output.
 
@@ -67,28 +83,30 @@ def send_file(req, f, content_type):
         yield req.not_found('File not found')
         return
     try:
-        try:
-            req.respond(HTTP_OK, content_type)
-            while True:
-                data = f.read(10240)
-                if not data:
-                    break
-                yield data
-        except IOError:
-            yield req.not_found('Error reading file')
-    finally:
+        req.respond(HTTP_OK, content_type)
+        while True:
+            data = f.read(10240)
+            if not data:
+                break
+            yield data
+        f.close()
+    except IOError:
         f.close()
+        yield req.not_found('Error reading file')
+    except:
+        f.close()
+        raise
 
 
 def get_text_file(req, backend, mat):
     req.nocache()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                      'text/plain')
 
 
 def get_loose_object(req, backend, mat):
     sha = mat.group(1) + mat.group(2)
-    object_store = backend.object_store
+    object_store = get_repo(backend, mat).object_store
     if not object_store.contains_loose(sha):
         yield req.not_found('Object not found')
         return
@@ -103,33 +121,29 @@ def get_loose_object(req, backend, mat):
 
 def get_pack_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
-                     'application/x-git-packed-objects', False)
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
+                     'application/x-git-packed-objects')
 
 
 def get_idx_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
-                     'application/x-git-packed-objects-toc', False)
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
+                     'application/x-git-packed-objects-toc')
 
 
-services = {'git-upload-pack': UploadPackHandler,
-            'git-receive-pack': ReceivePackHandler}
-def get_info_refs(req, backend, mat, services=None):
-    if services is None:
-        services = services
-    params = cgi.parse_qs(req.environ['QUERY_STRING'])
+def get_info_refs(req, backend, mat):
+    params = parse_qs(req.environ['QUERY_STRING'])
     service = params.get('service', [None])[0]
-    if service:
-        handler_cls = services.get(service, None)
+    if service and not req.dumb:
+        handler_cls = req.handlers.get(service, None)
         if handler_cls is None:
             yield req.forbidden('Unsupported service %s' % service)
             return
         req.nocache()
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         output = StringIO()
-        dummy_input = StringIO()  # GET request, handler doesn't need to read
-        handler = handler_cls(backend, dummy_input.read, output.write,
+        proto = ReceivableProtocol(StringIO().read, output.write)
+        handler = handler_cls(backend, [url_prefix(mat)], proto,
                               stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
@@ -140,32 +154,27 @@ def get_info_refs(req, backend, mat, services=None):
         # TODO: select_getanyfile() (see http-backend.c)
         req.nocache()
         req.respond(HTTP_OK, 'text/plain')
-        refs = backend.get_refs()
+        repo = get_repo(backend, mat)
+        refs = repo.get_refs()
         for name in sorted(refs.iterkeys()):
             # get_refs() includes HEAD as a special case, but we don't want to
             # advertise it
             if name == 'HEAD':
                 continue
             sha = refs[name]
-            o = backend.repo[sha]
+            o = repo[sha]
             if not o:
                 continue
             yield '%s\t%s\n' % (sha, name)
-            obj_type = num_type_map[o.type]
-            if obj_type == Tag:
-                while obj_type == Tag:
-                    num_type, sha = o.object
-                    obj_type = num_type_map[num_type]
-                    o = backend.repo[sha]
-                if not o:
-                    continue
-                yield '%s\t%s^{}\n' % (o.sha(), name)
+            peeled_sha = repo.get_peeled(name)
+            if peeled_sha != sha:
+                yield '%s\t%s^{}\n' % (peeled_sha, name)
 
 
 def get_info_packs(req, backend, mat):
     req.nocache()
     req.respond(HTTP_OK, 'text/plain')
-    for pack in backend.object_store.packs:
+    for pack in get_repo(backend, mat).object_store.packs:
         yield 'P pack-%s.pack\n' % pack.name()
 
 
@@ -176,6 +185,7 @@ class _LengthLimitedFile(object):
     Content-Length bytes are read. This behavior is required by the WSGI spec
     but not implemented in wsgiref as of 2.5.
     """
+
     def __init__(self, input, max_bytes):
         self._input = input
         self._bytes_avail = max_bytes
@@ -190,11 +200,10 @@ class _LengthLimitedFile(object):
 
     # TODO: support more methods as necessary
 
-def handle_service_request(req, backend, mat, services=services):
-    if services is None:
-        services = services
+
+def handle_service_request(req, backend, mat):
     service = mat.group().lstrip('/')
-    handler_cls = services.get(service, None)
+    handler_cls = req.handlers.get(service, None)
     if handler_cls is None:
         yield req.forbidden('Unsupported service %s' % service)
         return
@@ -209,7 +218,8 @@ def handle_service_request(req, backend, mat, services=services):
     # content-length
     if 'CONTENT_LENGTH' in req.environ:
         input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    handler = handler_cls(backend, input.read, output.write, stateless_rpc=True)
+    proto = ReceivableProtocol(input.read, output.write)
+    handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler.handle()
     yield output.getvalue()
 
@@ -220,8 +230,10 @@ class HTTPGitRequest(object):
     :ivar environ: the WSGI environment for the request.
     """
 
-    def __init__(self, environ, start_response):
+    def __init__(self, environ, start_response, dumb=False, handlers=None):
         self.environ = environ
+        self.dumb = dumb
+        self.handlers = handlers and handlers or DEFAULT_HANDLERS
         self._start_response = start_response
         self._cache_headers = []
         self._headers = []
@@ -255,19 +267,19 @@ class HTTPGitRequest(object):
     def nocache(self):
         """Set the response to never be cached by the client."""
         self._cache_headers = [
-            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
-            ('Pragma', 'no-cache'),
-            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
-            ]
+          ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+          ('Pragma', 'no-cache'),
+          ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+          ]
 
     def cache_forever(self):
         """Set the response to be cached forever by the client."""
         now = time.time()
         self._cache_headers = [
-            ('Date', date_time_string(now)),
-            ('Expires', date_time_string(now + 31536000)),
-            ('Cache-Control', 'public, max-age=31536000'),
-            ]
+          ('Date', date_time_string(now)),
+          ('Expires', date_time_string(now + 31536000)),
+          ('Cache-Control', 'public, max-age=31536000'),
+          ]
 
 
 class HTTPGitApplication(object):
@@ -277,26 +289,29 @@ class HTTPGitApplication(object):
     """
 
     services = {
-        ('GET', re.compile('/HEAD$')): get_text_file,
-        ('GET', re.compile('/info/refs$')): get_info_refs,
-        ('GET', re.compile('/objects/info/alternates$')): get_text_file,
-        ('GET', re.compile('/objects/info/http-alternates$')): get_text_file,
-        ('GET', re.compile('/objects/info/packs$')): get_info_packs,
-        ('GET', re.compile('/objects/([0-9a-f]{2})/([0-9a-f]{38})$')): get_loose_object,
-        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.pack$')): get_pack_file,
-        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.idx$')): get_idx_file,
-
-        ('POST', re.compile('/git-upload-pack$')): handle_service_request,
-        ('POST', re.compile('/git-receive-pack$')): handle_service_request,
+      ('GET', re.compile('/HEAD$')): get_text_file,
+      ('GET', re.compile('/info/refs$')): get_info_refs,
+      ('GET', re.compile('/objects/info/alternates$')): get_text_file,
+      ('GET', re.compile('/objects/info/http-alternates$')): get_text_file,
+      ('GET', re.compile('/objects/info/packs$')): get_info_packs,
+      ('GET', re.compile('/objects/([0-9a-f]{2})/([0-9a-f]{38})$')): get_loose_object,
+      ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.pack$')): get_pack_file,
+      ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.idx$')): get_idx_file,
+
+      ('POST', re.compile('/git-upload-pack$')): handle_service_request,
+      ('POST', re.compile('/git-receive-pack$')): handle_service_request,
     }
 
-    def __init__(self, backend):
+    def __init__(self, backend, dumb=False, handlers=None):
         self.backend = backend
+        self.dumb = dumb
+        self.handlers = handlers
 
     def __call__(self, environ, start_response):
         path = environ['PATH_INFO']
         method = environ['REQUEST_METHOD']
-        req = HTTPGitRequest(environ, start_response)
+        req = HTTPGitRequest(environ, start_response, dumb=self.dumb,
+                             handlers=self.handlers)
         # environ['QUERY_STRING'] has qs args
         handler = None
         for smethod, spath in self.services.iterkeys():
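
For orientation, a minimal sketch of mounting this application on the standard-library wsgiref server. SingleRepoBackend and the repository path are hypothetical, written only to satisfy the backend.open_repository() lookups shown above; the real dulwich.server backend classes may expect more, so treat this as a sketch rather than a supported recipe.

    from wsgiref.simple_server import make_server

    from dulwich.repo import Repo
    from dulwich.web import HTTPGitApplication

    class SingleRepoBackend(object):
        """Hypothetical backend: hands out one repository for every URL prefix."""

        def __init__(self, repo):
            self.repo = repo

        def open_repository(self, path):
            # get_repo() above resolves repositories by URL prefix via this call.
            return self.repo

    backend = SingleRepoBackend(Repo('/path/to/a/bare/repo.git'))
    app = HTTPGitApplication(backend)  # dumb=False by default: smart HTTP enabled
    make_server('localhost', 8000, app).serve_forever()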

+ 1 - 1
setup.py

@@ -8,7 +8,7 @@ except ImportError:
     from distutils.core import setup, Extension
 from distutils.core import Distribution
 
-dulwich_version_string = '0.5.0'
+dulwich_version_string = '0.6.0'
 
 include_dirs = []
 # Windows MSVC support

Some files were not shown because too many files changed in this diff