
New upstream release.

Jelmer Vernooij, 15 years ago (commit f517a43131)
69 changed files with 5743 additions and 1545 deletions
  1. .bzrignore (+1, -0)
  2. .testr.conf (+3, -0)
  3. AUTHORS (+2, -1)
  4. HACKING (+6, -3)
  5. Makefile (+7, -2)
  6. NEWS (+93, -4)
  7. README (+4, -0)
  8. bin/dul-daemon (+4, -3)
  9. bin/dul-web (+2, -2)
  10. bin/dulwich (+27, -35)
  11. debian/changelog (+6, -0)
  12. dulwich/__init__.py (+1, -1)
  13. dulwich/_objects.c (+131, -15)
  14. dulwich/_pack.c (+50, -4)
  15. dulwich/client.py (+106, -40)
  16. dulwich/errors.py (+56, -21)
  17. dulwich/fastexport.py (+82, -0)
  18. dulwich/file.py (+41, -2)
  19. dulwich/index.py (+16, -0)
  20. dulwich/misc.py (+13, -3)
  21. dulwich/object_store.py (+128, -67)
  22. dulwich/objects.py (+662, -250)
  23. dulwich/pack.py (+409, -269)
  24. dulwich/patch.py (+67, -5)
  25. dulwich/protocol.py (+116, -2)
  26. dulwich/repo.py (+371, -79)
  27. dulwich/server.py (+251, -143)
  28. dulwich/tests/__init__.py (+30, -0)
  29. dulwich/tests/compat/server_utils.py (+157, -0)
  30. dulwich/tests/compat/test_client.py (+179, -0)
  31. dulwich/tests/compat/test_pack.py (+73, -0)
  32. dulwich/tests/compat/test_repository.py (+131, -0)
  33. dulwich/tests/compat/test_server.py (+78, -0)
  34. dulwich/tests/compat/test_web.py (+116, -0)
  35. dulwich/tests/compat/utils.py (+196, -0)
  36. dulwich/tests/data/blobs/11/11111111111111111111111111111111111111 (binary)
  37. dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8 (+0, -0)
  38. dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349 (+0, -0)
  39. dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391 (+0, -0)
  40. dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310 (+0, -0)
  41. dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc (+0, -0)
  42. dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e (+0, -0)
  43. dulwich/tests/data/repos/a.git/objects/28/237f4dc30d0d462658d6b937b08a0f0b6ef55a (+2, -0)
  44. dulwich/tests/data/repos/a.git/objects/b0/931cadc54336e78a1d980420e3268903b57a50 (+3, -0)
  45. dulwich/tests/data/repos/a.git/packed-refs (+3, -0)
  46. dulwich/tests/data/repos/a.git/refs/tags/mytag (+1, -0)
  47. dulwich/tests/data/repos/refs.git/objects/3e/c9c43c84ff242e3ef4a9fc5bc111fd780a76a8 (+3, -0)
  48. dulwich/tests/data/repos/refs.git/objects/cd/a609072918d7b70057b6bef9f4c2537843fcfe (+5, -0)
  49. dulwich/tests/data/repos/refs.git/packed-refs (+1, -0)
  50. dulwich/tests/data/repos/refs.git/refs/tags/refs-0.2 (+1, -0)
  51. dulwich/tests/data/repos/server_new.export (+99, -0)
  52. dulwich/tests/data/repos/server_old.export (+57, -0)
  53. dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023 (+0, -0)
  54. dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6 (+0, -0)
  55. dulwich/tests/test_client.py (+11, -5)
  56. dulwich/tests/test_fastexport.py (+78, -0)
  57. dulwich/tests/test_file.py (+68, -2)
  58. dulwich/tests/test_index.py (+31, -11)
  59. dulwich/tests/test_object_store.py (+33, -29)
  60. dulwich/tests/test_objects.py (+420, -104)
  61. dulwich/tests/test_pack.py (+212, -86)
  62. dulwich/tests/test_patch.py (+88, -0)
  63. dulwich/tests/test_protocol.py (+90, -7)
  64. dulwich/tests/test_repository.py (+477, -141)
  65. dulwich/tests/test_server.py (+199, -78)
  66. dulwich/tests/test_web.py (+69, -55)
  67. dulwich/tests/utils.py (+86, -0)
  68. dulwich/web.py (+90, -75)
  69. setup.py (+1, -1)

+ 1 - 0
.bzrignore

@@ -4,3 +4,4 @@ MANIFEST
 dist
 apidocs
 *,cover
+.testrepository

+ 3 - 0
.testr.conf

@@ -0,0 +1,3 @@
+[DEFAULT]
+test_command=PYTHONPATH=. python -m subunit.run $IDLIST
+test_id_list_default=dulwich.tests.test_suite

+ 2 - 1
AUTHORS

@@ -1,3 +1,4 @@
+Jelmer Vernooij <jelmer@samba.org>
 James Westby <jw+debian@jameswestby.net>
 John Carr <john.carr@unrouted.co.uk>
-Jelmer Vernooij <jelmer@samba.org>
+Dave Borowitz <dborowitz@google.com>

+ 6 - 3
HACKING

@@ -1,5 +1,8 @@
 Please follow PEP8 with regard to coding style.
 
-All functionality should be available in pure Python. Optional C implementations
-may be written for performance reasons, but should never replace the Python 
-implementation. The C implementations should follow the kernel/git coding style.
+All functionality should be available in pure Python. Optional C
+implementations may be written for performance reasons, but should never
+replace the Python implementation. The C implementations should follow the
+kernel/git coding style.
+
+Where possible please include updates to NEWS along with your improvements.

+ 7 - 2
Makefile

@@ -2,8 +2,9 @@ PYTHON = python
 SETUP = $(PYTHON) setup.py
 PYDOCTOR ?= pydoctor
 TESTRUNNER = $(shell which nosetests)
+TESTFLAGS =
 
-all: build 
+all: build
 
 doc:: pydoctor
 
@@ -19,9 +20,13 @@ install::
 
 check:: build
 	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	which git > /dev/null && PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 check-noextensions:: clean
-	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) dulwich
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) dulwich
+
+check-compat:: build
+	PYTHONPATH=. $(PYTHON) $(TESTRUNNER) $(TESTFLAGS) -i compat
 
 clean::
 	$(SETUP) clean --all

+ 93 - 4
NEWS

@@ -1,8 +1,96 @@
+0.6.0	2010-05-22
+
+note: This list is most likely incomplete for 0.6.0.
+
+ BUG FIXES
+ 
+  * Fix ReceivePackHandler to disallow removing refs without delete-refs.
+    (Dave Borowitz)
+
+  * Deal with capabilities required by the client, even if they 
+    can not be disabled in the server. (Dave Borowitz)
+
+  * Fix trailing newlines in generated patch files.
+    (Jelmer Vernooij)
+
+  * Implement RefsContainer.__contains__. (Jelmer Vernooij)
+
+  * Cope with \r in ref files on Windows. (
+	http://github.com/jelmer/dulwich/issues/#issue/13, Jelmer Vernooij)
+
+  * Fix GitFile breakage on Windows. (Anatoly Techtonik, #557585)
+
+  * Support packed ref deletion with no peeled refs. (Augie Fackler)
+
+  * Fix send pack when there is nothing to fetch. (Augie Fackler)
+
+  * Fix fetch if no progress function is specified. (Augie Fackler)
+
+  * Allow double-staging of files that are deleted in the index. 
+    (Dave Borowitz)
+
+  * Fix RefsContainer.add_if_new to support dangling symrefs.
+    (Dave Borowitz)
+
+  * Non-existant index files in non-bare repositories are now treated as 
+    empty. (Dave Borowitz)
+
+  * Always update ShaFile.id when the contents of the object get changed. 
+    (Jelmer Vernooij)
+
+  * Various Python2.4-compatibility fixes. (Dave Borowitz)
+
+  * Fix thin pack handling. (Dave Borowitz)
+ 
+ FEATURES
+
+  * Add include-tag capability to server. (Dave Borowitz)
+
+  * New dulwich.fastexport module that can generate fastexport 
+    streams. (Jelmer Vernooij)
+
+  * Implemented BaseRepo.__contains__. (Jelmer Vernooij)
+
+  * Add __setitem__ to DictRefsContainer. (Dave Borowitz)
+
+  * Overall improvements checking Git objects. (Dave Borowitz)
+
+  * Packs are now verified while they are received. (Dave Borowitz)
+
+ TESTS
+
+  * Add framework for testing compatibility with C Git. (Dave Borowitz)
+
+  * Add various tests for the use of non-bare repositories. (Dave Borowitz)
+
+  * Cope with diffstat not being available on all platforms. 
+    (Tay Ray Chuan, Jelmer Vernooij)
+
+  * Add make_object and make_commit convenience functions to test utils.
+    (Dave Borowitz)
+
+ API BREAKAGES
+
+  * The 'committer' and 'message' arguments to Repo.do_commit() have 
+    been swapped. 'committer' is now optional. (Jelmer Vernooij)
+
+  * Repo.get_blob, Repo.commit, Repo.tag and Repo.tree are now deprecated.
+    (Jelmer Vernooij)
+
+  * RefsContainer.set_ref() was renamed to RefsContainer.set_symbolic_ref(),
+    for clarity. (Jelmer Vernooij)
+
+ API CHANGES
+
+  * The primary serialization APIs in dulwich.objects now work 
+    with chunks of strings rather than with full-text strings. 
+    (Jelmer Vernooij)
+
 0.5.0	2010-03-03
 
  BUG FIXES
 
-  * Support custom fields in commits.
+  * Support custom fields in commits (readonly). (Jelmer Vernooij)
 
   * Improved ref handling. (Dave Borowitz)
 
@@ -31,13 +119,14 @@
 
  FEATURES
 
-  * Add ObjectStore.iter_tree_contents()
+  * Add ObjectStore.iter_tree_contents(). (Jelmer Vernooij)
 
-  * Add Index.changes_from_tree()
+  * Add Index.changes_from_tree(). (Jelmer Vernooij)
 
-  * Add ObjectStore.tree_changes()
+  * Add ObjectStore.tree_changes(). (Jelmer Vernooij)
 
   * Add functionality for writing patches in dulwich.patch.
+    (Jelmer Vernooij)
 
 0.4.0	2009-10-07
 

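The chunk-based serialization change called out under API CHANGES is easiest to see on a small object. A minimal, illustrative sketch; it assumes the Blob.from_string, as_raw_chunks and as_raw_string helpers behave as described in this release's NEWS, so check dulwich.objects for the exact API:

from dulwich.objects import Blob

blob = Blob.from_string("hello world\n")

# New-style: the object serializes to a list of string chunks...
chunks = blob.as_raw_chunks()

# ...which can be joined whenever a single full-text string is needed.
assert "".join(chunks) == blob.as_raw_string()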
+ 4 - 0
README

@@ -21,3 +21,7 @@ The project is named after the part of London that Mr. and Mrs. Git live in
 in the particular Monty Python sketch. It is based on the Python-Git module 
 that James Westby <jw+debian@jameswestby.net> released in 2007 and now 
 maintained by Jelmer Vernooij and John Carr.
+
+Please file bugs in the Dulwich project on Launchpad: 
+
+https://bugs.launchpad.net/dulwich/+filebug

+ 4 - 3
bin/dul-daemon

@@ -19,13 +19,14 @@
 
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend, TCPGitServer
+from dulwich.server import DictBackend, TCPGitServer
 
 if __name__ == "__main__":
-    gitdir = None
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
+    else:
+        gitdir = "."
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     server = TCPGitServer(backend, 'localhost')
     server.serve_forever()

+ 2 - 2
bin/dul-web

@@ -20,7 +20,7 @@
 import os
 import sys
 from dulwich.repo import Repo
-from dulwich.server import GitBackend
+from dulwich.server import DictBackend
 from dulwich.web import HTTPGitApplication
 from wsgiref.simple_server import make_server
 
@@ -30,7 +30,7 @@ if __name__ == "__main__":
     else:
         gitdir = os.getcwd()
 
-    backend = GitBackend(Repo(gitdir))
+    backend = DictBackend({"/": Repo(gitdir)})
     app = HTTPGitApplication(backend)
     # TODO: allow serving on other ports via command-line flag
     server = make_server('', 8000, app)
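Both scripts above switch from GitBackend to DictBackend, which maps request paths to repositories. A hedged sketch of why that matters: one daemon can expose several repositories side by side. The extra path and repository locations below are assumptions for illustration, not part of this commit.

from dulwich.repo import Repo
from dulwich.server import DictBackend, TCPGitServer

# Hypothetical layout: two repositories served under different paths.
backend = DictBackend({
    "/": Repo("/srv/git/main.git"),
    "/contrib": Repo("/srv/git/contrib.git"),
})
server = TCPGitServer(backend, 'localhost')
# server.serve_forever()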

+ 27 - 35
bin/dulwich

@@ -1,5 +1,5 @@
 #!/usr/bin/env python
-# dul-daemon - Simple git smart server client
+# dulwich - Simple command-line interface to Dulwich
 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
 # 
 # This program is free software; you can redistribute it and/or
@@ -17,21 +17,25 @@
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 
+"""Simple command-line interface to Dulwich>
+
+This is a very simple command-line wrapper for Dulwich. It is by 
+no means intended to be a full-blown Git command-line interface but just 
+a way to test Dulwich.
+"""
+
+import os
 import sys
 from getopt import getopt
 
-def get_transport_and_path(uri):
-    from dulwich.client import TCPGitClient, SSHGitClient, SubprocessGitClient
-    for handler, transport in (("git://", TCPGitClient), ("git+ssh://", SSHGitClient)):
-        if uri.startswith(handler):
-            host, path = uri[len(handler):].split("/", 1)
-            return transport(host), "/"+path
-    # if its not git or git+ssh, try a local url..
-    return SubprocessGitClient(), uri
+from dulwich.client import get_transport_and_path
+from dulwich.errors import ApplyDeltaError
+from dulwich.index import Index
+from dulwich.pack import Pack, sha_to_hex
+from dulwich.repo import Repo
 
 
 def cmd_fetch_pack(args):
-    from dulwich.repo import Repo
     opts, args = getopt(args, "", ["all"])
     opts = dict(opts)
     client, path = get_transport_and_path(args.pop(0))
@@ -45,9 +49,12 @@ def cmd_fetch_pack(args):
 
 
 def cmd_log(args):
-    from dulwich.repo import Repo
     opts, args = getopt(args, "", [])
-    r = Repo(".")
+    if len(args) > 0:
+        path = args.pop(0)
+    else:
+        path = "."
+    r = Repo(path)
     todo = [r.head()]
     done = set()
     while todo:
@@ -56,7 +63,7 @@ def cmd_log(args):
         if sha in done:
             continue
         done.add(sha)
-        commit = r.commit(sha)
+        commit = r[sha]
         print "-" * 50
         print "commit: %s" % sha
         if len(commit.parents) > 1:
@@ -70,11 +77,6 @@ def cmd_log(args):
 
 
 def cmd_dump_pack(args):
-    from dulwich.errors import ApplyDeltaError
-    from dulwich.pack import Pack, sha_to_hex
-    import os
-    import sys
-
     opts, args = getopt(args, "", [])
 
     if args == []:
@@ -98,8 +100,6 @@ def cmd_dump_pack(args):
 
 
 def cmd_dump_index(args):
-    from dulwich.index import Index
-
     opts, args = getopt(args, "", [])
 
     if args == []:
@@ -114,8 +114,6 @@ def cmd_dump_index(args):
 
 
 def cmd_init(args):
-    from dulwich.repo import Repo
-    import os
     opts, args = getopt(args, "", ["--bare"])
     opts = dict(opts)
 
@@ -134,16 +132,13 @@ def cmd_init(args):
 
 
 def cmd_clone(args):
-    from dulwich.repo import Repo
-    import os
-    import sys
     opts, args = getopt(args, "", [])
     opts = dict(opts)
 
     if args == []:
         print "usage: dulwich clone host:path [PATH]"
         sys.exit(1)
-        client, host_path = get_transport_and_path(args.pop(0))
+    client, host_path = get_transport_and_path(args.pop(0))
 
     if len(args) > 0:
         path = args.pop(0)
@@ -152,18 +147,14 @@ def cmd_clone(args):
 
     if not os.path.exists(path):
         os.mkdir(path)
-    Repo.init(path)
-    r = Repo(path)
-    graphwalker = r.get_graph_walker()
-    f, commit = r.object_store.add_pack()
-    client.fetch_pack(host_path, r.object_store.determine_wants_all, 
-                      graphwalker, f.write, sys.stdout.write)
-    commit()
+    r = Repo.init(path)
+    remote_refs = client.fetch(host_path, r,
+        determine_wants=r.object_store.determine_wants_all,
+        progress=sys.stdout.write)
+    r["HEAD"] = remote_refs["HEAD"]
 
 
 def cmd_commit(args):
-    from dulwich.repo import Repo
-    import os
     opts, args = getopt(args, "", ["message"])
     opts = dict(opts)
     r = Repo(".")
@@ -173,6 +164,7 @@ def cmd_commit(args):
                           os.getenv("GIT_AUTHOR_EMAIL"))
     r.do_commit(committer=committer, author=author, message=opts["--message"])
 
+
 commands = {
     "commit": cmd_commit,
     "fetch-pack": cmd_fetch_pack,

+ 6 - 0
debian/changelog

@@ -1,3 +1,9 @@
+dulwich (0.6.0-1) unstable; urgency=low
+
+  * New upstream release.
+
+ -- Jelmer Vernooij <jelmer@debian.org>  Sat, 22 May 2010 23:06:20 +0200
+
 dulwich (0.5.0-1) unstable; urgency=low
 
   * New upstream release.

+ 1 - 1
dulwich/__init__.py

@@ -27,4 +27,4 @@ import protocol
 import repo
 import server
 
-__version__ = (0, 5, 0)
+__version__ = (0, 6, 0)

+ 131 - 15
dulwich/_objects.c

@@ -18,6 +18,12 @@
  */
  */
 
 
 #include <Python.h>
 #include <Python.h>
+#include <stdlib.h>
+#include <sys/stat.h>
+
+#if (PY_VERSION_HEX < 0x02050000)
+typedef int Py_ssize_t;
+#endif
 
 
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
 #define bytehex(x) (((x)<0xa)?('0'+(x)):('a'-0xa+(x)))
 
 
@@ -35,33 +41,37 @@ static PyObject *sha_to_pyhex(const unsigned char *sha)
 
 
 static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 {
 {
-	char *text, *end;
+	char *text, *start, *end;
 	int len, namelen;
 	int len, namelen;
 	PyObject *ret, *item, *name;
 	PyObject *ret, *item, *name;
 
 
 	if (!PyArg_ParseTuple(args, "s#", &text, &len))
 	if (!PyArg_ParseTuple(args, "s#", &text, &len))
 		return NULL;
 		return NULL;
 
 
-	ret = PyDict_New();
+	/* TODO: currently this returns a list; if memory usage is a concern,
+	* consider rewriting as a custom iterator object */
+	ret = PyList_New(0);
+
 	if (ret == NULL) {
 	if (ret == NULL) {
 		return NULL;
 		return NULL;
 	}
 	}
 
 
+	start = text;
 	end = text + len;
 	end = text + len;
 
 
-    while (text < end) {
-        long mode;
+	while (text < end) {
+		long mode;
 		mode = strtol(text, &text, 8);
 		mode = strtol(text, &text, 8);
 
 
 		if (*text != ' ') {
 		if (*text != ' ') {
-			PyErr_SetString(PyExc_RuntimeError, "Expected space");
+			PyErr_SetString(PyExc_ValueError, "Expected space");
 			Py_DECREF(ret);
 			Py_DECREF(ret);
 			return NULL;
 			return NULL;
 		}
 		}
 
 
 		text++;
 		text++;
 
 
-        namelen = strlen(text);
+		namelen = strnlen(text, len - (text - start));
 
 
 		name = PyString_FromStringAndSize(text, namelen);
 		name = PyString_FromStringAndSize(text, namelen);
 		if (name == NULL) {
 		if (name == NULL) {
@@ -69,28 +79,134 @@ static PyObject *py_parse_tree(PyObject *self, PyObject *args)
 			return NULL;
 			return NULL;
 		}
 		}
 
 
-        item = Py_BuildValue("(lN)", mode, sha_to_pyhex((unsigned char *)text+namelen+1));
-        if (item == NULL) {
-            Py_DECREF(ret);
+		if (text + namelen + 20 >= end) {
+			PyErr_SetString(PyExc_ValueError, "SHA truncated");
+			Py_DECREF(ret);
 			Py_DECREF(name);
 			Py_DECREF(name);
-            return NULL;
-        }
-		if (PyDict_SetItem(ret, name, item) == -1) {
+			return NULL;
+		}
+
+		item = Py_BuildValue("(NlN)", name, mode,
+							 sha_to_pyhex((unsigned char *)text+namelen+1));
+		if (item == NULL) {
+			Py_DECREF(ret);
+			Py_DECREF(name);
+			return NULL;
+		}
+		if (PyList_Append(ret, item) == -1) {
 			Py_DECREF(ret);
 			Py_DECREF(ret);
 			Py_DECREF(item);
 			Py_DECREF(item);
 			return NULL;
 			return NULL;
 		}
 		}
-		Py_DECREF(name);
 		Py_DECREF(item);
 		Py_DECREF(item);
 
 
 		text += namelen+21;
 		text += namelen+21;
-    }
+	}
+
+	return ret;
+}
+
+struct tree_item {
+	const char *name;
+	int mode;
+	PyObject *tuple;
+};
+
+int cmp_tree_item(const void *_a, const void *_b)
+{
+	const struct tree_item *a = _a, *b = _b;
+	const char *remain_a, *remain_b;
+	int ret, common;
+	if (strlen(a->name) > strlen(b->name)) {
+		common = strlen(b->name);
+		remain_a = a->name + common;
+		remain_b = (S_ISDIR(b->mode)?"/":"");
+	} else if (strlen(b->name) > strlen(a->name)) { 
+		common = strlen(a->name);
+		remain_a = (S_ISDIR(a->mode)?"/":"");
+		remain_b = b->name + common;
+	} else { /* strlen(a->name) == strlen(b->name) */
+		common = 0;
+		remain_a = a->name;
+		remain_b = b->name;
+	}
+	ret = strncmp(a->name, b->name, common);
+	if (ret != 0)
+		return ret;
+	return strcmp(remain_a, remain_b);
+}
+
+static PyObject *py_sorted_tree_items(PyObject *self, PyObject *entries)
+{
+	struct tree_item *qsort_entries;
+	int num, i;
+	PyObject *ret;
+	Py_ssize_t pos = 0; 
+	PyObject *key, *value;
+
+	if (!PyDict_Check(entries)) {
+		PyErr_SetString(PyExc_TypeError, "Argument not a dictionary");
+		return NULL;
+	}
+
+	num = PyDict_Size(entries);
+	qsort_entries = malloc(num * sizeof(struct tree_item));
+	if (qsort_entries == NULL) {
+		PyErr_NoMemory();
+		return NULL;
+	}
+
+	i = 0;
+	while (PyDict_Next(entries, &pos, &key, &value)) {
+		PyObject *py_mode, *py_int_mode, *py_sha;
+		
+		if (PyTuple_Size(value) != 2) {
+			PyErr_SetString(PyExc_ValueError, "Tuple has invalid size");
+			free(qsort_entries);
+			return NULL;
+		}
+
+		py_mode = PyTuple_GET_ITEM(value, 0);
+		py_int_mode = PyNumber_Int(py_mode);
+		if (!py_int_mode) {
+			PyErr_SetString(PyExc_TypeError, "Mode is not an integral type");
+			free(qsort_entries);
+			return NULL;
+		}
+
+		py_sha = PyTuple_GET_ITEM(value, 1);
+		if (!PyString_CheckExact(key)) {
+			PyErr_SetString(PyExc_TypeError, "Name is not a string");
+			free(qsort_entries);
+			return NULL;
+		}
+		qsort_entries[i].name = PyString_AS_STRING(key);
+		qsort_entries[i].mode = PyInt_AS_LONG(py_mode);
+		qsort_entries[i].tuple = PyTuple_Pack(3, key, py_mode, py_sha);
+		i++;
+	}
+
+	qsort(qsort_entries, num, sizeof(struct tree_item), cmp_tree_item);
+
+	ret = PyList_New(num);
+	if (ret == NULL) {
+		free(qsort_entries);
+		PyErr_NoMemory();
+		return NULL;
+	}
+
+	for (i = 0; i < num; i++) {
+		PyList_SET_ITEM(ret, i, qsort_entries[i].tuple);
+	}
+
+	free(qsort_entries);
 
 
-    return ret;
+	return ret;
 }
 }
 
 
 static PyMethodDef py_objects_methods[] = {
 static PyMethodDef py_objects_methods[] = {
 	{ "parse_tree", (PyCFunction)py_parse_tree, METH_VARARGS, NULL },
 	{ "parse_tree", (PyCFunction)py_parse_tree, METH_VARARGS, NULL },
+	{ "sorted_tree_items", (PyCFunction)py_sorted_tree_items, METH_O, NULL },
 	{ NULL, NULL, 0, NULL }
 	{ NULL, NULL, 0, NULL }
 };
 };
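For readers who do not want to trace the C, here is an illustrative pure-Python sketch of the ordering that cmp_tree_item and py_sorted_tree_items implement: entries sort by name, with directory names compared as if they ended in "/". This is a rough equivalent for illustration only, not the actual dulwich implementation.

import stat

def sorted_tree_items(entries):
    # entries: dict mapping name -> (mode, hexsha); returns [(name, mode, hexsha)]
    def sort_key(item):
        name, (mode, hexsha) = item
        if stat.S_ISDIR(mode):
            return name + "/"
        return name
    return [(name, mode, hexsha)
            for name, (mode, hexsha) in sorted(entries.items(), key=sort_key)]

# The directory "a" compares as "a/", so it sorts after the file "a.c":
items = sorted_tree_items({"a": (stat.S_IFDIR, "1" * 40),
                           "a.c": (0100644, "2" * 40)})
assert [name for name, _, _ in items] == ["a.c", "a"]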
 
 

+ 50 - 4
dulwich/_pack.c

@@ -47,6 +47,29 @@ static size_t get_delta_header_size(uint8_t *delta, int *index, int length)
 	return size;
 	return size;
 }
 }
 
 
+static PyObject *py_chunked_as_string(PyObject *py_buf)
+{
+	if (PyList_Check(py_buf)) {
+		PyObject *sep = PyString_FromString("");
+		if (sep == NULL) {
+			PyErr_NoMemory();
+			return NULL;
+		}
+		py_buf = _PyString_Join(sep, py_buf);
+		Py_DECREF(sep);
+		if (py_buf == NULL) {
+			PyErr_NoMemory();
+			return NULL;
+		}
+	} else if (PyString_Check(py_buf)) {
+		Py_INCREF(py_buf);
+	} else {
+		PyErr_SetString(PyExc_TypeError,
+			"src_buf is not a string or a list of chunks");
+		return NULL;
+	}
+    return py_buf;
+}
 
 
 static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 {
 {
@@ -56,23 +79,42 @@ static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 	size_t outindex = 0;
 	size_t outindex = 0;
 	int index;
 	int index;
 	uint8_t *out;
 	uint8_t *out;
-	PyObject *ret;
+	PyObject *ret, *py_src_buf, *py_delta;
 
 
-	if (!PyArg_ParseTuple(args, "s#s#", (uint8_t *)&src_buf, &src_buf_len, 
-						  (uint8_t *)&delta, &delta_len))
+	if (!PyArg_ParseTuple(args, "OO", &py_src_buf, &py_delta))
 		return NULL;
 		return NULL;
 
 
+    py_src_buf = py_chunked_as_string(py_src_buf);
+    if (py_src_buf == NULL)
+        return NULL;
+
+    py_delta = py_chunked_as_string(py_delta);
+    if (py_delta == NULL) {
+        Py_DECREF(py_src_buf);
+        return NULL;
+    }
+
+	src_buf = (uint8_t *)PyString_AS_STRING(py_src_buf);
+	src_buf_len = PyString_GET_SIZE(py_src_buf);
+
+    delta = (uint8_t *)PyString_AS_STRING(py_delta);
+    delta_len = PyString_GET_SIZE(py_delta);
+
     index = 0;
     index = 0;
     src_size = get_delta_header_size(delta, &index, delta_len);
     src_size = get_delta_header_size(delta, &index, delta_len);
     if (src_size != src_buf_len) {
     if (src_size != src_buf_len) {
 		PyErr_Format(PyExc_ValueError, 
 		PyErr_Format(PyExc_ValueError, 
 			"Unexpected source buffer size: %lu vs %d", src_size, src_buf_len);
 			"Unexpected source buffer size: %lu vs %d", src_size, src_buf_len);
+		Py_DECREF(py_src_buf);
+		Py_DECREF(py_delta);
 		return NULL;
 		return NULL;
 	}
 	}
     dest_size = get_delta_header_size(delta, &index, delta_len);
     dest_size = get_delta_header_size(delta, &index, delta_len);
 	ret = PyString_FromStringAndSize(NULL, dest_size);
 	ret = PyString_FromStringAndSize(NULL, dest_size);
 	if (ret == NULL) {
 	if (ret == NULL) {
 		PyErr_NoMemory();
 		PyErr_NoMemory();
+		Py_DECREF(py_src_buf);
+		Py_DECREF(py_delta);
 		return NULL;
 		return NULL;
 	}
 	}
 	out = (uint8_t *)PyString_AsString(ret);
 	out = (uint8_t *)PyString_AsString(ret);
@@ -111,9 +153,13 @@ static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 		} else {
 		} else {
 			PyErr_SetString(PyExc_ValueError, "Invalid opcode 0");
 			PyErr_SetString(PyExc_ValueError, "Invalid opcode 0");
 			Py_DECREF(ret);
 			Py_DECREF(ret);
+            Py_DECREF(py_delta);
+			Py_DECREF(py_src_buf);
 			return NULL;
 			return NULL;
 		}
 		}
 	}
 	}
+	Py_DECREF(py_src_buf);
+    Py_DECREF(py_delta);
     
     
     if (index != delta_len) {
     if (index != delta_len) {
 		PyErr_SetString(PyExc_ValueError, "delta not empty");
 		PyErr_SetString(PyExc_ValueError, "delta not empty");
@@ -127,7 +173,7 @@ static PyObject *py_apply_delta(PyObject *self, PyObject *args)
 		return NULL;
 		return NULL;
 	}
 	}
 
 
-    return ret;
+    return Py_BuildValue("[N]", ret);
 }
 }
 
 
 static PyObject *py_bisect_find_sha(PyObject *self, PyObject *args)
 static PyObject *py_bisect_find_sha(PyObject *self, PyObject *args)
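py_apply_delta now accepts either plain strings or lists of chunks for both buffers, via the py_chunked_as_string helper above, and returns its result as a list of chunks. A small pure-Python sketch of that helper's behaviour, for illustration only:

def chunked_as_string(buf):
    # Accept the new chunked representation (a list of strings) as well as
    # a plain string, and return one contiguous string.
    if isinstance(buf, list):
        return "".join(buf)
    elif isinstance(buf, str):
        return buf
    raise TypeError("src_buf is not a string or a list of chunks")

assert chunked_as_string(["he", "llo"]) == "hello"
assert chunked_as_string("hello") == "hello"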

+ 106 - 40
dulwich/client.py

@@ -28,10 +28,13 @@ import subprocess
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     ChecksumMismatch,
     ChecksumMismatch,
+    SendPackError,
+    UpdateRefsError,
     )
     )
 from dulwich.protocol import (
 from dulwich.protocol import (
     Protocol,
     Protocol,
     TCP_GIT_PORT,
     TCP_GIT_PORT,
+    ZERO_SHA,
     extract_capabilities,
     extract_capabilities,
     )
     )
 from dulwich.pack import (
 from dulwich.pack import (
@@ -43,16 +46,19 @@ def _fileno_can_read(fileno):
     """Check if a file descriptor is readable."""
     """Check if a file descriptor is readable."""
     return len(select.select([fileno], [], [], 0)[0]) > 0
     return len(select.select([fileno], [], [], 0)[0]) > 0
 
 
+COMMON_CAPABILITIES = ["ofs-delta"]
+FETCH_CAPABILITIES = ["multi_ack", "side-band-64k"] + COMMON_CAPABILITIES
+SEND_CAPABILITIES = ['report-status'] + COMMON_CAPABILITIES
 
 
-CAPABILITIES = ["multi_ack", "side-band-64k", "ofs-delta"]
-
-
+# TODO(durin42): this doesn't correctly degrade if the server doesn't
+# support some capabilities. This should work properly with servers
+# that don't support side-band-64k and multi_ack.
 class GitClient(object):
 class GitClient(object):
     """Git smart server client.
     """Git smart server client.
 
 
     """
     """
 
 
-    def __init__(self, can_read, read, write, thin_packs=True, 
+    def __init__(self, can_read, read, write, thin_packs=True,
         report_activity=None):
         report_activity=None):
         """Create a new GitClient instance.
         """Create a new GitClient instance.
 
 
@@ -66,12 +72,10 @@ class GitClient(object):
         """
         """
         self.proto = Protocol(read, write, report_activity)
         self.proto = Protocol(read, write, report_activity)
         self._can_read = can_read
         self._can_read = can_read
-        self._capabilities = list(CAPABILITIES)
+        self._fetch_capabilities = list(FETCH_CAPABILITIES)
+        self._send_capabilities = list(SEND_CAPABILITIES)
         if thin_packs:
         if thin_packs:
-            self._capabilities.append("thin-pack")
-
-    def capabilities(self):
-        return " ".join(self._capabilities)
+            self._fetch_capabilities.append("thin-pack")
 
 
     def read_refs(self):
     def read_refs(self):
         server_capabilities = None
         server_capabilities = None
@@ -84,44 +88,90 @@ class GitClient(object):
             refs[ref] = sha
             refs[ref] = sha
         return refs, server_capabilities
         return refs, server_capabilities
 
 
+    # TODO(durin42): add side-band-64k capability support here and advertise it
     def send_pack(self, path, determine_wants, generate_pack_contents):
     def send_pack(self, path, determine_wants, generate_pack_contents):
         """Upload a pack to a remote repository.
         """Upload a pack to a remote repository.
 
 
         :param path: Repository path
         :param path: Repository path
-        :param generate_pack_contents: Function that can return the shas of the 
+        :param generate_pack_contents: Function that can return the shas of the
             objects to upload.
             objects to upload.
+
+        :raises SendPackError: if server rejects the pack data
+        :raises UpdateRefsError: if the server supports report-status
+                                 and rejects ref updates
         """
         """
         old_refs, server_capabilities = self.read_refs()
         old_refs, server_capabilities = self.read_refs()
+        if 'report-status' not in server_capabilities:
+            self._send_capabilities.remove('report-status')
         new_refs = determine_wants(old_refs)
         new_refs = determine_wants(old_refs)
         if not new_refs:
         if not new_refs:
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
             return {}
             return {}
         want = []
         want = []
-        have = [x for x in old_refs.values() if not x == "0" * 40]
+        have = [x for x in old_refs.values() if not x == ZERO_SHA]
         sent_capabilities = False
         sent_capabilities = False
         for refname in set(new_refs.keys() + old_refs.keys()):
         for refname in set(new_refs.keys() + old_refs.keys()):
-            old_sha1 = old_refs.get(refname, "0" * 40)
-            new_sha1 = new_refs.get(refname, "0" * 40)
+            old_sha1 = old_refs.get(refname, ZERO_SHA)
+            new_sha1 = new_refs.get(refname, ZERO_SHA)
             if old_sha1 != new_sha1:
             if old_sha1 != new_sha1:
                 if sent_capabilities:
                 if sent_capabilities:
-                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1, refname))
+                    self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1,
+                                                            refname))
                 else:
                 else:
-                    self.proto.write_pkt_line("%s %s %s\0%s" % (old_sha1, new_sha1, refname, self.capabilities()))
+                    self.proto.write_pkt_line(
+                      "%s %s %s\0%s" % (old_sha1, new_sha1, refname,
+                                        ' '.join(self._send_capabilities)))
                     sent_capabilities = True
                     sent_capabilities = True
-            if not new_sha1 in (have, "0" * 40):
+            if new_sha1 not in have and new_sha1 != ZERO_SHA:
                 want.append(new_sha1)
                 want.append(new_sha1)
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
         if not want:
         if not want:
             return new_refs
             return new_refs
         objects = generate_pack_contents(have, want)
         objects = generate_pack_contents(have, want)
-        (entries, sha) = write_pack_data(self.proto.write_file(), objects, 
+        (entries, sha) = write_pack_data(self.proto.write_file(), objects,
                                          len(objects))
                                          len(objects))
-        
-        # read the final confirmation sha
-        client_sha = self.proto.read(20)
-        if not client_sha in (None, "", sha):
-            raise ChecksumMismatch(sha, client_sha)
-            
+
+        if 'report-status' in self._send_capabilities:
+            unpack = self.proto.read_pkt_line().strip()
+            if unpack != 'unpack ok':
+                st = True
+                # flush remaining error data
+                while st is not None:
+                    st = self.proto.read_pkt_line()
+                raise SendPackError(unpack)
+            statuses = []
+            errs = False
+            ref_status = self.proto.read_pkt_line()
+            while ref_status:
+                ref_status = ref_status.strip()
+                statuses.append(ref_status)
+                if not ref_status.startswith('ok '):
+                    errs = True
+                ref_status = self.proto.read_pkt_line()
+
+            if errs:
+                ref_status = {}
+                ok = set()
+                for status in statuses:
+                    if ' ' not in status:
+                        # malformed response, move on to the next one
+                        continue
+                    status, ref = status.split(' ', 1)
+
+                    if status == 'ng':
+                        if ' ' in ref:
+                            ref, status = ref.split(' ', 1)
+                    else:
+                        ok.add(ref)
+                    ref_status[ref] = status
+                raise UpdateRefsError('%s failed to update' %
+                                      ', '.join([ref for ref in ref_status
+                                                 if ref not in ok]),
+                                      ref_status=ref_status)
+        # wait for EOF before returning
+        data = self.proto.read()
+        if data:
+            raise SendPackError('Unexpected response %r' % data)
         return new_refs
         return new_refs
 
 
     def fetch(self, path, target, determine_wants=None, progress=None):
     def fetch(self, path, target, determine_wants=None, progress=None):
@@ -129,7 +179,7 @@ class GitClient(object):
 
 
         :param path: Path to fetch from
         :param path: Path to fetch from
         :param target: Target repository to fetch into
         :param target: Target repository to fetch into
-        :param determine_wants: Optional function to determine what refs 
+        :param determine_wants: Optional function to determine what refs
             to fetch
             to fetch
         :param progress: Optional progress function
         :param progress: Optional progress function
         :return: remote refs
         :return: remote refs
@@ -138,8 +188,8 @@ class GitClient(object):
             determine_wants = target.object_store.determine_wants_all
             determine_wants = target.object_store.determine_wants_all
         f, commit = target.object_store.add_pack()
         f, commit = target.object_store.add_pack()
         try:
         try:
-            return self.fetch_pack(path, determine_wants, target.graph_walker, 
-                                   f.write, progress)
+            return self.fetch_pack(path, determine_wants,
+                target.get_graph_walker(), f.write, progress)
         finally:
         finally:
             commit()
             commit()
 
 
@@ -158,7 +208,8 @@ class GitClient(object):
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
             return refs
             return refs
         assert isinstance(wants, list) and type(wants[0]) == str
         assert isinstance(wants, list) and type(wants[0]) == str
-        self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities()))
+        self.proto.write_pkt_line("want %s %s\n" % (
+            wants[0], ' '.join(self._fetch_capabilities)))
         for want in wants[1:]:
         for want in wants[1:]:
             self.proto.write_pkt_line("want %s\n" % want)
             self.proto.write_pkt_line("want %s\n" % want)
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
@@ -181,13 +232,16 @@ class GitClient(object):
             if len(parts) < 3 or parts[2] != "continue":
             if len(parts) < 3 or parts[2] != "continue":
                 break
                 break
             pkt = self.proto.read_pkt_line()
             pkt = self.proto.read_pkt_line()
+        # TODO(durin42): this is broken if the server didn't support the
+        # side-band-64k capability.
         for pkt in self.proto.read_pkt_seq():
         for pkt in self.proto.read_pkt_seq():
             channel = ord(pkt[0])
             channel = ord(pkt[0])
             pkt = pkt[1:]
             pkt = pkt[1:]
             if channel == 1:
             if channel == 1:
                 pack_data(pkt)
                 pack_data(pkt)
             elif channel == 2:
             elif channel == 2:
-                progress(pkt)
+                if progress is not None:
+                    progress(pkt)
             else:
             else:
                 raise AssertionError("Invalid sideband channel %d" % channel)
                 raise AssertionError("Invalid sideband channel %d" % channel)
         return refs
         return refs
@@ -216,9 +270,9 @@ class TCPGitClient(GitClient):
 
 
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
     def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress):
         """Fetch a pack from the remote host.
         """Fetch a pack from the remote host.
-        
+
         :param path: Path of the reposiutory on the remote host
         :param path: Path of the reposiutory on the remote host
-        :param determine_wants: Callback that receives available refs dict and 
+        :param determine_wants: Callback that receives available refs dict and
             should return list of sha's to fetch.
             should return list of sha's to fetch.
         :param graph_walker: GraphWalker instance used to find missing shas
         :param graph_walker: GraphWalker instance used to find missing shas
         :param pack_data: Callback for writing pack data
         :param pack_data: Callback for writing pack data
@@ -254,18 +308,18 @@ class SubprocessGitClient(GitClient):
 
 
         :param path: Path to the git repository on the server
         :param path: Path to the git repository on the server
         :param changed_refs: Dictionary with new values for the refs
         :param changed_refs: Dictionary with new values for the refs
-        :param generate_pack_contents: Function that returns an iterator over 
+        :param generate_pack_contents: Function that returns an iterator over
             objects to send
             objects to send
         """
         """
         client = self._connect("git-receive-pack", path)
         client = self._connect("git-receive-pack", path)
         return client.send_pack(path, changed_refs, generate_pack_contents)
         return client.send_pack(path, changed_refs, generate_pack_contents)
 
 
-    def fetch_pack(self, path, determine_wants, graph_walker, pack_data, 
+    def fetch_pack(self, path, determine_wants, graph_walker, pack_data,
         progress):
         progress):
         """Retrieve a pack from the server
         """Retrieve a pack from the server
 
 
         :param path: Path to the git repository on the server
         :param path: Path to the git repository on the server
-        :param determine_wants: Function that receives existing refs 
+        :param determine_wants: Function that receives existing refs
             on the server and returns a list of desired shas
             on the server and returns a list of desired shas
         :param graph_walker: GraphWalker instance
         :param graph_walker: GraphWalker instance
         :param pack_data: Function that can write pack data
         :param pack_data: Function that can write pack data
@@ -281,12 +335,8 @@ class SSHSubprocess(object):
 
 
     def __init__(self, proc):
     def __init__(self, proc):
         self.proc = proc
         self.proc = proc
-
-    def send(self, data):
-        return os.write(self.proc.stdin.fileno(), data)
-
-    def recv(self, count):
-        return self.proc.stdout.read(count)
+        self.read = self.recv = proc.stdout.read
+        self.write = self.send = proc.stdin.write
 
 
     def close(self):
     def close(self):
         self.proc.stdin.close()
         self.proc.stdin.close()
@@ -323,7 +373,9 @@ class SSHGitClient(GitClient):
         self._kwargs = kwargs
         self._kwargs = kwargs
 
 
     def send_pack(self, path, determine_wants, generate_pack_contents):
     def send_pack(self, path, determine_wants, generate_pack_contents):
-        remote = get_ssh_vendor().connect_ssh(self.host, ["git-receive-pack '%s'" % path], port=self.port, username=self.username)
+        remote = get_ssh_vendor().connect_ssh(
+            self.host, ["git-receive-pack '%s'" % path],
+            port=self.port, username=self.username)
         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
         client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs)
         return client.send_pack(path, determine_wants, generate_pack_contents)
         return client.send_pack(path, determine_wants, generate_pack_contents)
 
 
@@ -334,3 +386,17 @@ class SSHGitClient(GitClient):
         return client.fetch_pack(path, determine_wants, graph_walker, pack_data,
         return client.fetch_pack(path, determine_wants, graph_walker, pack_data,
                                  progress)
                                  progress)
 
 
+
+def get_transport_and_path(uri):
+    """Obtain a git client from a URI or path.
+
+    :param uri: URI or path
+    :return: Tuple with client instance and relative path.
+    """
+    from dulwich.client import TCPGitClient, SSHGitClient, SubprocessGitClient
+    for handler, transport in (("git://", TCPGitClient), ("git+ssh://", SSHGitClient)):
+        if uri.startswith(handler):
+            host, path = uri[len(handler):].split("/", 1)
+            return transport(host), "/"+path
+    # if its not git or git+ssh, try a local url..
+    return SubprocessGitClient(), uri
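get_transport_and_path, previously a private helper in bin/dulwich, is now part of dulwich.client (added above). A short usage sketch with made-up URLs:

from dulwich.client import get_transport_and_path

# git:// URLs give a TCPGitClient plus the path on the server...
client, path = get_transport_and_path("git://example.com/project.git")

# ...while anything else falls back to a local SubprocessGitClient.
client, path = get_transport_and_path("/srv/git/project.git")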

+ 56 - 21
dulwich/errors.py

@@ -1,17 +1,17 @@
 # errors.py -- errors for dulwich
 # errors.py -- errors for dulwich
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -19,71 +19,83 @@
 
 
 """Dulwich-related exception classes and utility functions."""
 """Dulwich-related exception classes and utility functions."""
 
 
+import binascii
+
+
 class ChecksumMismatch(Exception):
 class ChecksumMismatch(Exception):
     """A checksum didn't match the expected contents."""
     """A checksum didn't match the expected contents."""
 
 
     def __init__(self, expected, got, extra=None):
     def __init__(self, expected, got, extra=None):
+        if len(expected) == 20:
+            expected = binascii.hexlify(expected)
+        if len(got) == 20:
+            got = binascii.hexlify(got)
         self.expected = expected
         self.expected = expected
         self.got = got
         self.got = got
         self.extra = extra
         self.extra = extra
         if self.extra is None:
         if self.extra is None:
-            Exception.__init__(self, 
+            Exception.__init__(self,
                 "Checksum mismatch: Expected %s, got %s" % (expected, got))
                 "Checksum mismatch: Expected %s, got %s" % (expected, got))
         else:
         else:
             Exception.__init__(self,
             Exception.__init__(self,
-                "Checksum mismatch: Expected %s, got %s; %s" % 
+                "Checksum mismatch: Expected %s, got %s; %s" %
                 (expected, got, extra))
                 (expected, got, extra))
 
 
 
 
 class WrongObjectException(Exception):
 class WrongObjectException(Exception):
     """Baseclass for all the _ is not a _ exceptions on objects.
     """Baseclass for all the _ is not a _ exceptions on objects.
-  
+
     Do not instantiate directly.
     Do not instantiate directly.
-  
-    Subclasses should define a _type attribute that indicates what
+
+    Subclasses should define a type_name attribute that indicates what
     was expected if they were raised.
     was expected if they were raised.
     """
     """
-  
+
     def __init__(self, sha, *args, **kwargs):
     def __init__(self, sha, *args, **kwargs):
-        string = "%s is not a %s" % (sha, self._type)
-        Exception.__init__(self, string)
+        Exception.__init__(self, "%s is not a %s" % (sha, self.type_name))
 
 
 
 
 class NotCommitError(WrongObjectException):
 class NotCommitError(WrongObjectException):
     """Indicates that the sha requested does not point to a commit."""
     """Indicates that the sha requested does not point to a commit."""
-  
-    _type = 'commit'
+
+    type_name = 'commit'
 
 
 
 
 class NotTreeError(WrongObjectException):
 class NotTreeError(WrongObjectException):
     """Indicates that the sha requested does not point to a tree."""
     """Indicates that the sha requested does not point to a tree."""
-  
-    _type = 'tree'
+
+    type_name = 'tree'
+
+
+class NotTagError(WrongObjectException):
+    """Indicates that the sha requested does not point to a tag."""
+
+    type_name = 'tag'
 
 
 
 
 class NotBlobError(WrongObjectException):
 class NotBlobError(WrongObjectException):
     """Indicates that the sha requested does not point to a blob."""
     """Indicates that the sha requested does not point to a blob."""
-  
-    _type = 'blob'
+
+    type_name = 'blob'
 
 
 
 
 class MissingCommitError(Exception):
 class MissingCommitError(Exception):
     """Indicates that a commit was not found in the repository"""
     """Indicates that a commit was not found in the repository"""
-  
+
     def __init__(self, sha, *args, **kwargs):
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the revision store" % sha)
         Exception.__init__(self, "%s is not in the revision store" % sha)
 
 
 
 
 class ObjectMissing(Exception):
 class ObjectMissing(Exception):
     """Indicates that a requested object is missing."""
     """Indicates that a requested object is missing."""
-  
+
     def __init__(self, sha, *args, **kwargs):
     def __init__(self, sha, *args, **kwargs):
         Exception.__init__(self, "%s is not in the pack" % sha)
         Exception.__init__(self, "%s is not in the pack" % sha)
 
 
 
 
 class ApplyDeltaError(Exception):
 class ApplyDeltaError(Exception):
     """Indicates that applying a delta failed."""
     """Indicates that applying a delta failed."""
-    
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         Exception.__init__(self, *args, **kwargs)
         Exception.__init__(self, *args, **kwargs)
 
 
@@ -97,11 +109,26 @@ class NotGitRepository(Exception):
 
 
 class GitProtocolError(Exception):
 class GitProtocolError(Exception):
     """Git protocol exception."""
     """Git protocol exception."""
-    
+
+    def __init__(self, *args, **kwargs):
+        Exception.__init__(self, *args, **kwargs)
+
+
+class SendPackError(GitProtocolError):
+    """An error occurred during send_pack."""
+
     def __init__(self, *args, **kwargs):
     def __init__(self, *args, **kwargs):
         Exception.__init__(self, *args, **kwargs)
         Exception.__init__(self, *args, **kwargs)
 
 
 
 
+class UpdateRefsError(GitProtocolError):
+    """The server reported errors updating refs."""
+
+    def __init__(self, *args, **kwargs):
+        self.ref_status = kwargs.pop('ref_status')
+        Exception.__init__(self, *args, **kwargs)
+
+
 class HangupException(GitProtocolError):
 class HangupException(GitProtocolError):
     """Hangup exception."""
     """Hangup exception."""
 
 
@@ -118,5 +145,13 @@ class PackedRefsException(FileFormatException):
     """Indicates an error parsing a packed-refs file."""
     """Indicates an error parsing a packed-refs file."""
 
 
 
 
+class ObjectFormatException(FileFormatException):
+    """Indicates an error parsing an object."""
+
+
 class NoIndexPresent(Exception):
 class NoIndexPresent(Exception):
     """No index is present."""
     """No index is present."""
+
+
+class CommitError(Exception):
+    """An error occurred while performing a commit."""

+ 82 - 0
dulwich/fastexport.py

@@ -0,0 +1,82 @@
+# __init__.py -- Fast export/import functionality
+# Copyright (C) 2010 Jelmer Vernooij <jelmer@samba.org>
+# 
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of 
+# the License.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+"""Fast export/import functionality."""
+
+from dulwich.objects import (
+    Tree,
+    format_timezone,
+    )
+
+import stat
+
+class FastExporter(object):
+    """Generate a fast-export output stream for Git objects."""
+
+    def __init__(self, outf, store):
+        self.outf = outf
+        self.store = store
+        self.markers = {}
+        self._marker_idx = 0
+
+    def _allocate_marker(self):
+        self._marker_idx+=1
+        return self._marker_idx
+
+    def _dump_blob(self, blob, marker):
+        self.outf.write("blob\nmark :%s\n" % marker)
+        self.outf.write("data %s\n" % blob.raw_length())
+        for chunk in blob.as_raw_chunks():
+            self.outf.write(chunk)
+        self.outf.write("\n")
+
+    def export_blob(self, blob):
+        i = self._allocate_marker()
+        self.markers[i] = blob.id
+        self._dump_blob(blob, i)
+        return i
+
+    def _dump_commit(self, commit, marker, ref, file_changes):
+        self.outf.write("commit %s\n" % ref)
+        self.outf.write("mark :%s\n" % marker)
+        self.outf.write("author %s %s %s\n" % (commit.author,
+            commit.author_time, format_timezone(commit.author_timezone)))
+        self.outf.write("committer %s %s %s\n" % (commit.committer,
+            commit.commit_time, format_timezone(commit.commit_timezone)))
+        self.outf.write("data %s\n" % len(commit.message))
+        self.outf.write(commit.message)
+        self.outf.write("\n")
+        self.outf.write('\n'.join(file_changes))
+        self.outf.write("\n\n")
+
+    def export_commit(self, commit, ref, base_tree=None):
+        file_changes = []
+        for (old_path, new_path), (old_mode, new_mode), (old_hexsha, new_hexsha) in \
+                self.store.tree_changes(base_tree, commit.tree):
+            if new_path is None:
+                file_changes.append("D %s" % old_path)
+                continue
+            if not stat.S_ISDIR(new_mode):
+                marker = self.export_blob(self.store[new_hexsha])
+            file_changes.append("M %o :%s %s" % (new_mode, marker, new_path))
+
+        i = self._allocate_marker()
+        self._dump_commit(commit, i, ref, file_changes)
+        return i
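A hedged usage sketch for the new FastExporter class above: write the current HEAD commit of a repository as a fast-export stream on stdout. The repository path and ref name are assumptions for illustration.

import sys

from dulwich.fastexport import FastExporter
from dulwich.repo import Repo

repo = Repo("/srv/git/project.git")      # hypothetical repository
exporter = FastExporter(sys.stdout, repo.object_store)
exporter.export_commit(repo["HEAD"], "refs/heads/master")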

+ 41 - 2
dulwich/file.py

@@ -22,6 +22,7 @@
 
 
 import errno
 import errno
 import os
 import os
+import tempfile
 
 
 def ensure_dir_exists(dirname):
 def ensure_dir_exists(dirname):
     """Ensure a directory exists, creating if necessary."""
     """Ensure a directory exists, creating if necessary."""
@@ -31,6 +32,36 @@ def ensure_dir_exists(dirname):
         if e.errno != errno.EEXIST:
         if e.errno != errno.EEXIST:
             raise
             raise
 
 
+def fancy_rename(oldname, newname):
+    """Rename file with temporary backup file to rollback if rename fails"""
+    if not os.path.exists(newname):
+        try:
+            os.rename(oldname, newname)
+        except OSError, e:
+            raise
+        return
+
+    # destination file exists
+    try:
+        (fd, tmpfile) = tempfile.mkstemp(".tmp", prefix=oldname+".", dir=".")
+        os.close(fd)
+        os.remove(tmpfile)
+    except OSError, e:
+        # either file could not be created (e.g. permission problem)
+        # or could not be deleted (e.g. rude virus scanner)
+        raise
+    try:
+        os.rename(newname, tmpfile)
+    except OSError, e:
+        raise   # no rename occurred
+    try:
+        os.rename(oldname, newname)
+    except OSError, e:
+        os.rename(tmpfile, newname)
+        raise
+    os.remove(tmpfile)
+
+
 def GitFile(filename, mode='r', bufsize=-1):
 def GitFile(filename, mode='r', bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
     """Create a file object that obeys the git file locking protocol.
 
 
@@ -89,7 +120,8 @@ class _GitFile(object):
     def __init__(self, filename, mode, bufsize):
     def __init__(self, filename, mode, bufsize):
         self._filename = filename
         self._filename = filename
         self._lockfilename = '%s.lock' % self._filename
         self._lockfilename = '%s.lock' % self._filename
-        fd = os.open(self._lockfilename, os.O_RDWR | os.O_CREAT | os.O_EXCL)
+        fd = os.open(self._lockfilename,
+            os.O_RDWR | os.O_CREAT | os.O_EXCL | getattr(os, "O_BINARY", 0))
         self._file = os.fdopen(fd, mode, bufsize)
         self._file = os.fdopen(fd, mode, bufsize)
         self._closed = False
         self._closed = False
 
 
@@ -111,6 +143,7 @@ class _GitFile(object):
             # The file may have been removed already, which is ok.
             # The file may have been removed already, which is ok.
             if e.errno != errno.ENOENT:
             if e.errno != errno.ENOENT:
                 raise
                 raise
+            self._closed = True
 
 
     def close(self):
     def close(self):
         """Close this file, saving the lockfile over the original.
         """Close this file, saving the lockfile over the original.
@@ -127,7 +160,13 @@ class _GitFile(object):
             return
             return
         self._file.close()
         self._file.close()
         try:
         try:
-            os.rename(self._lockfilename, self._filename)
+            try:
+                os.rename(self._lockfilename, self._filename)
+            except OSError, e:
+                # Windows versions prior to Vista don't support atomic renames
+                if e.errno != errno.EEXIST:
+                    raise
+                fancy_rename(self._lockfilename, self._filename)
         finally:
         finally:
             self.abort()
             self.abort()
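
A short usage sketch of the locking protocol implemented by GitFile/_GitFile above: writes go to "<name>.lock" (created with O_EXCL, now opened in binary mode on Windows), and close() renames the lock file over the original, falling back to fancy_rename() on platforms that cannot overwrite during rename. The file name below is illustrative.

    from dulwich.file import GitFile

    f = GitFile('description', 'wb')   # actually opens 'description.lock'
    try:
        f.write('my repository\n')
    except Exception:
        f.abort()    # drop the lock file, leave 'description' untouched
        raise
    else:
        f.close()    # rename 'description.lock' over 'description'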
 
 

+ 16 - 0
dulwich/index.py

@@ -204,6 +204,8 @@ class Index(object):
 
 
     def read(self):
     def read(self):
         """Read current contents of index from disk."""
         """Read current contents of index from disk."""
+        if not os.path.exists(self._filename):
+            return
         f = GitFile(self._filename, 'rb')
         f = GitFile(self._filename, 'rb')
         try:
         try:
             f = SHA1Reader(f)
             f = SHA1Reader(f)
@@ -254,6 +256,10 @@ class Index(object):
         # Remove the old entry if any
         # Remove the old entry if any
         self._byname[name] = x
         self._byname[name] = x
 
 
+    def __delitem__(self, name):
+        assert isinstance(name, str)
+        del self._byname[name]
+
     def iteritems(self):
     def iteritems(self):
         return self._byname.iteritems()
         return self._byname.iteritems()
 
 
@@ -283,6 +289,14 @@ class Index(object):
         for name in mine:
         for name in mine:
             yield ((None, name), (None, self.get_mode(name)), (None, self.get_sha1(name)))
             yield ((None, name), (None, self.get_mode(name)), (None, self.get_sha1(name)))
 
 
+    def commit(self, object_store):
+        """Create a new tree from an index.
+
+        :param object_store: Object store to save the tree in
+        :return: Root tree SHA
+        """
+        return commit_tree(object_store, self.iterblobs())
+
 
 
 def commit_tree(object_store, blobs):
 def commit_tree(object_store, blobs):
     """Commit a new tree.
     """Commit a new tree.
@@ -327,5 +341,7 @@ def commit_index(object_store, index):
 
 
     :param object_store: Object store to save the tree in
     :param object_store: Object store to save the tree in
     :param index: Index file
     :param index: Index file
+    :note: This function is deprecated, use index.commit() instead.
+    :return: Root tree sha.
     """
     """
     return commit_tree(object_store, index.iterblobs())
     return commit_tree(object_store, index.iterblobs())
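
A minimal usage sketch for the new Index.commit() helper that supersedes commit_index(), assuming an existing non-bare repository in the current directory and the usual Repo.open_index()/object_store accessors (which are not part of this hunk):

    from dulwich.repo import Repo

    repo = Repo('.')                           # existing working-tree repository
    index = repo.open_index()                  # dulwich Index object
    tree_id = index.commit(repo.object_store)  # write trees, return the root SHA
    print "root tree:", tree_id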

+ 13 - 3
dulwich/misc.py

@@ -15,15 +15,26 @@
 # along with this program; if not, write to the Free Software
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA  02110-1301, USA.
 # MA  02110-1301, USA.
-"""Misc utilities to work with python2.4.
+"""Misc utilities to work with python <2.6.
 
 
 These utilities can all be deleted when dulwich decides it wants to stop
 These utilities can all be deleted when dulwich decides it wants to stop
-support for python 2.4.
+support for python <2.6.
 """
 """
 try:
 try:
     import hashlib
     import hashlib
 except ImportError:
 except ImportError:
     import sha
     import sha
+
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from cgi import parse_qs
+
+try:
+    from os import SEEK_END
+except ImportError:
+    SEEK_END = 2
+
 import struct
 import struct
 
 
 
 
@@ -87,4 +98,3 @@ def unpack_from(fmt, buf, offset=0):
     except AttributeError:
     except AttributeError:
         b = buf[offset:offset+struct.calcsize(fmt)]
         b = buf[offset:offset+struct.calcsize(fmt)]
         return struct.unpack(fmt, b)
         return struct.unpack(fmt, b)
-

+ 128 - 67
dulwich/object_store.py

@@ -23,6 +23,7 @@
 import errno
 import errno
 import itertools
 import itertools
 import os
 import os
+import posixpath
 import stat
 import stat
 import tempfile
 import tempfile
 import urllib2
 import urllib2
@@ -38,11 +39,13 @@ from dulwich.objects import (
     Tree,
     Tree,
     hex_to_sha,
     hex_to_sha,
     sha_to_hex,
     sha_to_hex,
+    hex_to_filename,
     S_ISGITLINK,
     S_ISGITLINK,
     )
     )
 from dulwich.pack import (
 from dulwich.pack import (
     Pack,
     Pack,
     PackData,
     PackData,
+    ThinPackData,
     iter_sha1,
     iter_sha1,
     load_pack_index,
     load_pack_index,
     write_pack,
     write_pack,
@@ -57,7 +60,8 @@ class BaseObjectStore(object):
     """Object store interface."""
     """Object store interface."""
 
 
     def determine_wants_all(self, refs):
     def determine_wants_all(self, refs):
-	    return [sha for (ref, sha) in refs.iteritems() if not sha in self and not ref.endswith("^{}")]
+        return [sha for (ref, sha) in refs.iteritems()
+                if not sha in self and not ref.endswith("^{}")]
 
 
     def iter_shas(self, shas):
     def iter_shas(self, shas):
         """Iterate over the objects for the specified shas.
         """Iterate over the objects for the specified shas.
@@ -91,14 +95,14 @@ class BaseObjectStore(object):
         """Obtain the raw text for an object.
         """Obtain the raw text for an object.
 
 
         :param name: sha for the object.
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         """
         raise NotImplementedError(self.get_raw)
         raise NotImplementedError(self.get_raw)
 
 
     def __getitem__(self, sha):
     def __getitem__(self, sha):
         """Obtain an object by SHA1."""
         """Obtain an object by SHA1."""
-        type, uncomp = self.get_raw(sha)
-        return ShaFile.from_raw_string(type, uncomp)
+        type_num, uncomp = self.get_raw(sha)
+        return ShaFile.from_raw_string(type_num, uncomp)
 
 
     def __iter__(self):
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
         """Iterate over the SHAs that are present in this store."""
@@ -137,10 +141,7 @@ class BaseObjectStore(object):
             else:
             else:
                 ttree = {}
                 ttree = {}
             for name, oldmode, oldhexsha in stree.iteritems():
             for name, oldmode, oldhexsha in stree.iteritems():
-                if path == "":
-                    oldchildpath = name
-                else:
-                    oldchildpath = "%s/%s" % (path, name)
+                oldchildpath = posixpath.join(path, name)
                 try:
                 try:
                     (newmode, newhexsha) = ttree[name]
                     (newmode, newhexsha) = ttree[name]
                     newchildpath = oldchildpath
                     newchildpath = oldchildpath
@@ -148,7 +149,7 @@ class BaseObjectStore(object):
                     newmode = None
                     newmode = None
                     newhexsha = None
                     newhexsha = None
                     newchildpath = None
                     newchildpath = None
-                if (want_unchanged or oldmode != newmode or 
+                if (want_unchanged or oldmode != newmode or
                     oldhexsha != newhexsha):
                     oldhexsha != newhexsha):
                     if stat.S_ISDIR(oldmode):
                     if stat.S_ISDIR(oldmode):
                         if newmode is None or stat.S_ISDIR(newmode):
                         if newmode is None or stat.S_ISDIR(newmode):
@@ -166,10 +167,7 @@ class BaseObjectStore(object):
                             yield ((oldchildpath, newchildpath), (oldmode, newmode), (oldhexsha, newhexsha))
                             yield ((oldchildpath, newchildpath), (oldmode, newmode), (oldhexsha, newhexsha))
 
 
             for name, newmode, newhexsha in ttree.iteritems():
             for name, newmode, newhexsha in ttree.iteritems():
-                if path == "":
-                    childpath = name
-                else:
-                    childpath = "%s/%s" % (path, name)
+                childpath = posixpath.join(path, name)
                 if not name in stree:
                 if not name in stree:
                     if not stat.S_ISDIR(newmode):
                     if not stat.S_ISDIR(newmode):
                         yield ((None, childpath), (None, newmode), (None, newhexsha))
                         yield ((None, childpath), (None, newmode), (None, newhexsha))
@@ -185,26 +183,27 @@ class BaseObjectStore(object):
         while todo:
         while todo:
             (tid, tpath) = todo.pop()
             (tid, tpath) = todo.pop()
             tree = self[tid]
             tree = self[tid]
-            for name, mode, hexsha in tree.iteritems(): 
-                if tpath == "":
-                    path = name
-                else:
-                    path = "%s/%s" % (tpath, name)
+            for name, mode, hexsha in tree.iteritems():
+                path = posixpath.join(tpath, name)
                 if stat.S_ISDIR(mode):
                 if stat.S_ISDIR(mode):
                     todo.add((hexsha, path))
                     todo.add((hexsha, path))
                 else:
                 else:
                     yield path, mode, hexsha
                     yield path, mode, hexsha
 
 
-    def find_missing_objects(self, haves, wants, progress=None):
+    def find_missing_objects(self, haves, wants, progress=None,
+                             get_tagged=None):
         """Find the missing objects required for a set of revisions.
         """Find the missing objects required for a set of revisions.
 
 
         :param haves: Iterable over SHAs already in common.
         :param haves: Iterable over SHAs already in common.
         :param wants: Iterable over SHAs of objects to fetch.
         :param wants: Iterable over SHAs of objects to fetch.
-        :param progress: Simple progress function that will be called with 
+        :param progress: Simple progress function that will be called with
             updated progress strings.
             updated progress strings.
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         :return: Iterator over (sha, path) pairs.
         :return: Iterator over (sha, path) pairs.
         """
         """
-        return iter(MissingObjectFinder(self, haves, wants, progress).next, None)
+        finder = MissingObjectFinder(self, haves, wants, progress, get_tagged)
+        return iter(finder.next, None)
 
 
     def find_common_revisions(self, graphwalker):
     def find_common_revisions(self, graphwalker):
         """Find which revisions this store has in common using graphwalker.
         """Find which revisions this store has in common using graphwalker.
@@ -223,19 +222,20 @@ class BaseObjectStore(object):
 
 
     def get_graph_walker(self, heads):
     def get_graph_walker(self, heads):
         """Obtain a graph walker for this object store.
         """Obtain a graph walker for this object store.
-        
+
         :param heads: Local heads to start search with
         :param heads: Local heads to start search with
         :return: GraphWalker object
         :return: GraphWalker object
         """
         """
         return ObjectStoreGraphWalker(heads, lambda sha: self[sha].parents)
         return ObjectStoreGraphWalker(heads, lambda sha: self[sha].parents)
 
 
-    def generate_pack_contents(self, have, want):
+    def generate_pack_contents(self, have, want, progress=None):
         """Iterate over the contents of a pack file.
         """Iterate over the contents of a pack file.
 
 
         :param have: List of SHA1s of objects that should not be sent
         :param have: List of SHA1s of objects that should not be sent
         :param want: List of SHA1s of objects that should be sent
         :param want: List of SHA1s of objects that should be sent
+        :param progress: Optional progress reporting method
         """
         """
-        return self.iter_shas(self.find_missing_objects(have, want))
+        return self.iter_shas(self.find_missing_objects(have, want, progress))
 
 
 
 
 class PackBasedObjectStore(BaseObjectStore):
 class PackBasedObjectStore(BaseObjectStore):
@@ -253,6 +253,10 @@ class PackBasedObjectStore(BaseObjectStore):
     def _load_packs(self):
     def _load_packs(self):
         raise NotImplementedError(self._load_packs)
         raise NotImplementedError(self._load_packs)
 
 
+    def _pack_cache_stale(self):
+        """Check whether the pack cache is stale."""
+        raise NotImplementedError(self._pack_cache_stale)
+
     def _add_known_pack(self, pack):
     def _add_known_pack(self, pack):
         """Add a newly appeared pack to the cache by path.
         """Add a newly appeared pack to the cache by path.
 
 
@@ -263,7 +267,7 @@ class PackBasedObjectStore(BaseObjectStore):
     @property
     @property
     def packs(self):
     def packs(self):
         """List with pack objects."""
         """List with pack objects."""
-        if self._pack_cache is None:
+        if self._pack_cache is None or self._pack_cache_stale():
             self._pack_cache = self._load_packs()
             self._pack_cache = self._load_packs()
         return self._pack_cache
         return self._pack_cache
 
 
@@ -284,9 +288,9 @@ class PackBasedObjectStore(BaseObjectStore):
 
 
     def get_raw(self, name):
     def get_raw(self, name):
         """Obtain the raw text for an object.
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         """
         if len(name) == 40:
         if len(name) == 40:
             sha = hex_to_sha(name)
             sha = hex_to_sha(name)
@@ -301,24 +305,25 @@ class PackBasedObjectStore(BaseObjectStore):
                 return pack.get_raw(sha)
                 return pack.get_raw(sha)
             except KeyError:
             except KeyError:
                 pass
                 pass
-        if hexsha is None: 
+        if hexsha is None:
             hexsha = sha_to_hex(name)
             hexsha = sha_to_hex(name)
         ret = self._get_loose_object(hexsha)
         ret = self._get_loose_object(hexsha)
         if ret is not None:
         if ret is not None:
-            return ret.type, ret.as_raw_string()
+            return ret.type_num, ret.as_raw_string()
         raise KeyError(hexsha)
         raise KeyError(hexsha)
 
 
     def add_objects(self, objects):
     def add_objects(self, objects):
         """Add a set of objects to this object store.
         """Add a set of objects to this object store.
 
 
         :param objects: Iterable over objects, should support __len__.
         :param objects: Iterable over objects, should support __len__.
+        :return: Pack object of the objects written.
         """
         """
         if len(objects) == 0:
         if len(objects) == 0:
             # Don't bother writing an empty pack file
             # Don't bother writing an empty pack file
             return
             return
         f, commit = self.add_pack()
         f, commit = self.add_pack()
         write_pack_data(f, objects, len(objects))
         write_pack_data(f, objects, len(objects))
-        commit()
+        return commit()
 
 
 
 
 class DiskObjectStore(PackBasedObjectStore):
 class DiskObjectStore(PackBasedObjectStore):
@@ -332,11 +337,14 @@ class DiskObjectStore(PackBasedObjectStore):
         super(DiskObjectStore, self).__init__()
         super(DiskObjectStore, self).__init__()
         self.path = path
         self.path = path
         self.pack_dir = os.path.join(self.path, PACKDIR)
         self.pack_dir = os.path.join(self.path, PACKDIR)
+        self._pack_cache_time = 0
 
 
     def _load_packs(self):
     def _load_packs(self):
         pack_files = []
         pack_files = []
         try:
         try:
-            for name in os.listdir(self.pack_dir):
+            self._pack_cache_time = os.stat(self.pack_dir).st_mtime
+            pack_dir_contents = os.listdir(self.pack_dir)
+            for name in pack_dir_contents:
                 # TODO: verify that idx exists first
                 # TODO: verify that idx exists first
                 if name.startswith("pack-") and name.endswith(".pack"):
                 if name.startswith("pack-") and name.endswith(".pack"):
                     filename = os.path.join(self.pack_dir, name)
                     filename = os.path.join(self.pack_dir, name)
@@ -349,11 +357,17 @@ class DiskObjectStore(PackBasedObjectStore):
         suffix_len = len(".pack")
         suffix_len = len(".pack")
         return [Pack(f[:-suffix_len]) for _, f in pack_files]
         return [Pack(f[:-suffix_len]) for _, f in pack_files]
 
 
+    def _pack_cache_stale(self):
+        try:
+            return os.stat(self.pack_dir).st_mtime > self._pack_cache_time
+        except OSError, e:
+            if e.errno == errno.ENOENT:
+                return True
+            raise
+
     def _get_shafile_path(self, sha):
     def _get_shafile_path(self, sha):
-        dir = sha[:2]
-        file = sha[2:]
         # Check from object dir
         # Check from object dir
-        return os.path.join(self.path, dir, file)
+        return hex_to_filename(self.path, sha)
 
 
     def _iter_loose_objects(self):
     def _iter_loose_objects(self):
         for base in os.listdir(self.path):
         for base in os.listdir(self.path):
@@ -365,8 +379,8 @@ class DiskObjectStore(PackBasedObjectStore):
     def _get_loose_object(self, sha):
     def _get_loose_object(self, sha):
         path = self._get_shafile_path(sha)
         path = self._get_shafile_path(sha)
         try:
         try:
-            return ShaFile.from_file(path)
-        except OSError, e:
+            return ShaFile.from_path(path)
+        except (OSError, IOError), e:
             if e.errno == errno.ENOENT:
             if e.errno == errno.ENOENT:
                 return None
                 return None
             raise
             raise
@@ -374,51 +388,54 @@ class DiskObjectStore(PackBasedObjectStore):
     def move_in_thin_pack(self, path):
     def move_in_thin_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
         """Move a specific file containing a pack into the pack directory.
 
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
             packs directory.
 
 
         :param path: Path to the pack file.
         :param path: Path to the pack file.
         """
         """
-        data = PackData(path)
+        data = ThinPackData(self.get_raw, path)
 
 
         # Write index for the thin pack (do we really need this?)
         # Write index for the thin pack (do we really need this?)
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".tempidx")
             sha_to_hex(urllib2.randombytes(20))+".tempidx")
-        data.create_index_v2(temppath, self.get_raw)
+        data.create_index_v2(temppath)
         p = Pack.from_objects(data, load_pack_index(temppath))
         p = Pack.from_objects(data, load_pack_index(temppath))
 
 
         # Write a full pack version
         # Write a full pack version
-        temppath = os.path.join(self.pack_dir, 
+        temppath = os.path.join(self.pack_dir,
             sha_to_hex(urllib2.randombytes(20))+".temppack")
             sha_to_hex(urllib2.randombytes(20))+".temppack")
-        write_pack(temppath, ((o, None) for o in p.iterobjects(self.get_raw)), 
-                len(p))
+        write_pack(temppath, ((o, None) for o in p.iterobjects()), len(p))
         pack_sha = load_pack_index(temppath+".idx").objects_sha1()
         pack_sha = load_pack_index(temppath+".idx").objects_sha1()
         newbasename = os.path.join(self.pack_dir, "pack-%s" % pack_sha)
         newbasename = os.path.join(self.pack_dir, "pack-%s" % pack_sha)
         os.rename(temppath+".pack", newbasename+".pack")
         os.rename(temppath+".pack", newbasename+".pack")
         os.rename(temppath+".idx", newbasename+".idx")
         os.rename(temppath+".idx", newbasename+".idx")
-        self._add_known_pack(Pack(newbasename))
+        final_pack = Pack(newbasename)
+        self._add_known_pack(final_pack)
+        return final_pack
 
 
     def move_in_pack(self, path):
     def move_in_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
         """Move a specific file containing a pack into the pack directory.
 
 
-        :note: The file should be on the same file system as the 
+        :note: The file should be on the same file system as the
             packs directory.
             packs directory.
 
 
         :param path: Path to the pack file.
         :param path: Path to the pack file.
         """
         """
         p = PackData(path)
         p = PackData(path)
         entries = p.sorted_entries()
         entries = p.sorted_entries()
-        basename = os.path.join(self.pack_dir, 
+        basename = os.path.join(self.pack_dir,
             "pack-%s" % iter_sha1(entry[0] for entry in entries))
             "pack-%s" % iter_sha1(entry[0] for entry in entries))
         write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum())
         write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum())
         p.close()
         p.close()
         os.rename(path, basename + ".pack")
         os.rename(path, basename + ".pack")
-        self._add_known_pack(Pack(basename))
+        final_pack = Pack(basename)
+        self._add_known_pack(final_pack)
+        return final_pack
 
 
     def add_thin_pack(self):
     def add_thin_pack(self):
         """Add a new thin pack to this object store.
         """Add a new thin pack to this object store.
 
 
-        Thin packs are packs that contain deltas with parents that exist 
+        Thin packs are packs that contain deltas with parents that exist
         in a different pack.
         in a different pack.
         """
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -427,13 +444,15 @@ class DiskObjectStore(PackBasedObjectStore):
             os.fsync(fd)
             os.fsync(fd)
             f.close()
             f.close()
             if os.path.getsize(path) > 0:
             if os.path.getsize(path) > 0:
-                self.move_in_thin_pack(path)
+                return self.move_in_thin_pack(path)
+            else:
+                return None
         return f, commit
         return f, commit
 
 
     def add_pack(self):
     def add_pack(self):
-        """Add a new pack to this object store. 
+        """Add a new pack to this object store.
 
 
-        :return: Fileobject to write to and a commit function to 
+        :return: Fileobject to write to and a commit function to
             call when the pack is finished.
             call when the pack is finished.
         """
         """
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
         fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack")
@@ -442,7 +461,9 @@ class DiskObjectStore(PackBasedObjectStore):
             os.fsync(fd)
             os.fsync(fd)
             f.close()
             f.close()
             if os.path.getsize(path) > 0:
             if os.path.getsize(path) > 0:
-                self.move_in_pack(path)
+                return self.move_in_pack(path)
+            else:
+                return None
         return f, commit
         return f, commit
 
 
     def add_object(self, obj):
     def add_object(self, obj):
@@ -451,8 +472,11 @@ class DiskObjectStore(PackBasedObjectStore):
         :param obj: Object to add
         :param obj: Object to add
         """
         """
         dir = os.path.join(self.path, obj.id[:2])
         dir = os.path.join(self.path, obj.id[:2])
-        if not os.path.isdir(dir):
+        try:
             os.mkdir(dir)
             os.mkdir(dir)
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
         path = os.path.join(dir, obj.id[2:])
         path = os.path.join(dir, obj.id[2:])
         if os.path.exists(path):
         if os.path.exists(path):
             return # Already there, no need to write again
             return # Already there, no need to write again
@@ -462,6 +486,17 @@ class DiskObjectStore(PackBasedObjectStore):
         finally:
         finally:
             f.close()
             f.close()
 
 
+    @classmethod
+    def init(cls, path):
+        try:
+            os.mkdir(path)
+        except OSError, e:
+            if e.errno != errno.EEXIST:
+                raise
+        os.mkdir(os.path.join(path, "info"))
+        os.mkdir(os.path.join(path, PACKDIR))
+        return cls(path)
+
 
 
 class MemoryObjectStore(BaseObjectStore):
 class MemoryObjectStore(BaseObjectStore):
     """Object store that keeps all objects in memory."""
     """Object store that keeps all objects in memory."""
@@ -489,9 +524,9 @@ class MemoryObjectStore(BaseObjectStore):
 
 
     def get_raw(self, name):
     def get_raw(self, name):
         """Obtain the raw text for an object.
         """Obtain the raw text for an object.
-        
+
         :param name: sha for the object.
         :param name: sha for the object.
-        :return: tuple with object type and object contents.
+        :return: tuple with numeric type and object contents.
         """
         """
         return self[name].as_raw_string()
         return self[name].as_raw_string()
 
 
@@ -573,7 +608,7 @@ class ObjectStoreIterator(ObjectIterator):
     def __contains__(self, needle):
     def __contains__(self, needle):
         """Check if an object is present.
         """Check if an object is present.
 
 
-        :note: This checks if the object is present in 
+        :note: This checks if the object is present in
             the underlying object store, not if it would
             the underlying object store, not if it would
             be yielded by the iterator.
             be yielded by the iterator.
 
 
@@ -583,7 +618,7 @@ class ObjectStoreIterator(ObjectIterator):
 
 
     def __getitem__(self, key):
     def __getitem__(self, key):
         """Find an object by SHA1.
         """Find an object by SHA1.
-        
+
         :note: This retrieves the object from the underlying
         :note: This retrieves the object from the underlying
             object store. It will also succeed if the object would
             object store. It will also succeed if the object would
             not be returned by the iterator.
             not be returned by the iterator.
@@ -604,9 +639,10 @@ def tree_lookup_path(lookup_obj, root_sha, path):
     """
     """
     parts = path.split("/")
     parts = path.split("/")
     sha = root_sha
     sha = root_sha
+    mode = None
     for p in parts:
     for p in parts:
         obj = lookup_obj(sha)
         obj = lookup_obj(sha)
-        if type(obj) is not Tree:
+        if not isinstance(obj, Tree):
             raise NotTreeError(sha)
             raise NotTreeError(sha)
         if p == '':
         if p == '':
             continue
             continue
@@ -617,14 +653,18 @@ def tree_lookup_path(lookup_obj, root_sha, path):
 class MissingObjectFinder(object):
 class MissingObjectFinder(object):
     """Find the objects missing from another object store.
     """Find the objects missing from another object store.
 
 
-    :param object_store: Object store containing at least all objects to be 
+    :param object_store: Object store containing at least all objects to be
         sent
         sent
     :param haves: SHA1s of commits not to send (already present in target)
     :param haves: SHA1s of commits not to send (already present in target)
     :param wants: SHA1s of commits to send
     :param wants: SHA1s of commits to send
     :param progress: Optional function to report progress to.
     :param progress: Optional function to report progress to.
+    :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+        sha for including tags.
+    :param tagged: dict of pointed-to sha -> tag sha for including tags
     """
     """
 
 
-    def __init__(self, object_store, haves, wants, progress=None):
+    def __init__(self, object_store, haves, wants, progress=None,
+                 get_tagged=None):
         self.sha_done = set(haves)
         self.sha_done = set(haves)
         self.objects_to_send = set([(w, None, False) for w in wants if w not in haves])
         self.objects_to_send = set([(w, None, False) for w in wants if w not in haves])
         self.object_store = object_store
         self.object_store = object_store
@@ -632,6 +672,7 @@ class MissingObjectFinder(object):
             self.progress = lambda x: None
             self.progress = lambda x: None
         else:
         else:
             self.progress = progress
             self.progress = progress
+        self._tagged = get_tagged and get_tagged() or {}
 
 
     def add_todo(self, entries):
     def add_todo(self, entries):
         self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
         self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done])
@@ -658,13 +699,19 @@ class MissingObjectFinder(object):
                 self.parse_tree(o)
                 self.parse_tree(o)
             elif isinstance(o, Tag):
             elif isinstance(o, Tag):
                 self.parse_tag(o)
                 self.parse_tag(o)
+        if sha in self._tagged:
+            self.add_todo([(self._tagged[sha], None, True)])
         self.sha_done.add(sha)
         self.sha_done.add(sha)
         self.progress("counting objects: %d\r" % len(self.sha_done))
         self.progress("counting objects: %d\r" % len(self.sha_done))
         return (sha, name)
         return (sha, name)
 
 
 
 
 class ObjectStoreGraphWalker(object):
 class ObjectStoreGraphWalker(object):
-    """Graph walker that finds out what commits are missing from an object store."""
+    """Graph walker that finds what commits are missing from an object store.
+
+    :ivar heads: Revisions without descendants in the local repo
+    :ivar get_parents: Function to retrieve parents in the local repo
+    """
 
 
     def __init__(self, local_heads, get_parents):
     def __init__(self, local_heads, get_parents):
         """Create a new instance.
         """Create a new instance.
@@ -677,12 +724,26 @@ class ObjectStoreGraphWalker(object):
         self.parents = {}
         self.parents = {}
 
 
     def ack(self, sha):
     def ack(self, sha):
-        """Ack that a particular revision and its ancestors are present in the source."""
-        if sha in self.heads:
-            self.heads.remove(sha)
-        if sha in self.parents:
-            for p in self.parents[sha]:
-                self.ack(p)
+        """Ack that a revision and its ancestors are present in the source."""
+        ancestors = set([sha])
+
+        # stop if we run out of heads to remove
+        while self.heads:
+            for a in ancestors:
+                if a in self.heads:
+                    self.heads.remove(a)
+
+            # collect all ancestors
+            new_ancestors = set()
+            for a in ancestors:
+                if a in self.parents:
+                    new_ancestors.update(self.parents[a])
+
+            # no more ancestors; stop
+            if not new_ancestors:
+                break
+
+            ancestors = new_ancestors
 
 
     def next(self):
     def next(self):
         """Iterate over ancestors of heads in the target."""
         """Iterate over ancestors of heads in the target."""

File diff suppressed because it is too large
+ 662 - 250
dulwich/objects.py


File diff suppressed because it is too large
+ 409 - 269
dulwich/pack.py


+ 67 - 5
dulwich/patch.py

@@ -22,10 +22,14 @@ These patches are basically unified diffs with some extra metadata tacked
 on.
 on.
 """
 """
 
 
-import difflib
+from difflib import SequenceMatcher
+import rfc822
 import subprocess
 import subprocess
 import time
 import time
 
 
+from dulwich.objects import (
+    Commit,
+    )
 
 
 def write_commit_patch(f, commit, contents, progress, version=None):
 def write_commit_patch(f, commit, contents, progress, version=None):
     """Write a individual file patch.
     """Write a individual file patch.
@@ -68,6 +72,36 @@ def get_summary(commit):
     return commit.message.splitlines()[0].replace(" ", "-")
     return commit.message.splitlines()[0].replace(" ", "-")
 
 
 
 
+def unified_diff(a, b, fromfile='', tofile='', n=3):
+    """difflib.unified_diff that doesn't write any dates or trailing spaces.
+
+    Based on the same function in Python2.6.5-rc2's difflib.py
+    """
+    started = False
+    for group in SequenceMatcher(None, a, b).get_grouped_opcodes(n):
+        if not started:
+            yield '--- %s\n' % fromfile
+            yield '+++ %s\n' % tofile
+            started = True
+        i1, i2, j1, j2 = group[0][1], group[-1][2], group[0][3], group[-1][4]
+        yield "@@ -%d,%d +%d,%d @@\n" % (i1+1, i2-i1, j1+1, j2-j1)
+        for tag, i1, i2, j1, j2 in group:
+            if tag == 'equal':
+                for line in a[i1:i2]:
+                    yield ' ' + line
+                continue
+            if tag == 'replace' or tag == 'delete':
+                for line in a[i1:i2]:
+                    if not line[-1] == '\n':
+                        line += '\n\\ No newline at end of file\n'
+                    yield '-' + line
+            if tag == 'replace' or tag == 'insert':
+                for line in b[j1:j2]:
+                    if not line[-1] == '\n':
+                        line += '\n\\ No newline at end of file\n'
+                    yield '+' + line
+
+
 def write_blob_diff(f, (old_path, old_mode, old_blob), 
 def write_blob_diff(f, (old_path, old_mode, old_blob), 
                        (new_path, new_mode, new_blob)):
                        (new_path, new_mode, new_blob)):
     """Write diff file header.
     """Write diff file header.
@@ -98,13 +132,41 @@ def write_blob_diff(f, (old_path, old_mode, old_blob),
     if old_mode != new_mode:
     if old_mode != new_mode:
         if new_mode is not None:
         if new_mode is not None:
             if old_mode is not None:
             if old_mode is not None:
-                f.write("old file mode %o\n" % old_mode)
-            f.write("new file mode %o\n" % new_mode) 
+                f.write("old mode %o\n" % old_mode)
+            f.write("new mode %o\n" % new_mode) 
         else:
         else:
-            f.write("deleted file mode %o\n" % old_mode)
+            f.write("deleted mode %o\n" % old_mode)
     f.write("index %s..%s %o\n" % (
     f.write("index %s..%s %o\n" % (
         blob_id(old_blob), blob_id(new_blob), new_mode))
         blob_id(old_blob), blob_id(new_blob), new_mode))
     old_contents = lines(old_blob)
     old_contents = lines(old_blob)
     new_contents = lines(new_blob)
     new_contents = lines(new_blob)
-    f.writelines(difflib.unified_diff(old_contents, new_contents, 
+    f.writelines(unified_diff(old_contents, new_contents, 
         old_path, new_path))
         old_path, new_path))
+
+
+def git_am_patch_split(f):
+    """Parse a git-am-style patch and split it up into bits.
+
+    :param f: File-like object to parse
+    :return: Tuple with commit object, diff contents and git version
+    """
+    msg = rfc822.Message(f)
+    c = Commit()
+    c.author = msg["from"]
+    c.committer = msg["from"]
+    if msg["subject"].startswith("[PATCH"):
+        subject = msg["subject"].split("]", 1)[1][1:]
+    else:
+        subject = msg["subject"]
+    c.message = subject
+    for l in f:
+        if l == "---\n":
+            break
+        c.message += l
+    diff = ""
+    for l in f:
+        if l == "-- \n":
+            break
+        diff += l
+    version = f.next().rstrip("\n")
+    return c, diff, version
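
A hedged sketch of feeding a git-format-patch style message through git_am_patch_split(); the patch text is invented for the example, and note that the diff terminator is "-- " with a trailing space, exactly as the function expects:

    from cStringIO import StringIO
    from dulwich.patch import git_am_patch_split

    text = ("From: Jane Doe <jane@example.com>\n"
            "Subject: [PATCH 1/1] Add hello\n"
            "\n"
            "Body of the commit message.\n"
            "---\n"
            " hello.txt | 1 +\n"
            " 1 file changed, 1 insertion(+)\n"
            "-- \n"
            "1.7.0\n")
    commit, diff, version = git_am_patch_split(StringIO(text))
    print commit.author    # Jane Doe <jane@example.com>
    print version          # 1.7.0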

+ 116 - 2
dulwich/protocol.py

@@ -19,19 +19,27 @@
 
 
 """Generic functions for talking the git smart server protocol."""
 """Generic functions for talking the git smart server protocol."""
 
 
+from cStringIO import StringIO
+import os
 import socket
 import socket
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     HangupException,
     HangupException,
     GitProtocolError,
     GitProtocolError,
     )
     )
+from dulwich.misc import (
+    SEEK_END,
+    )
 
 
 TCP_GIT_PORT = 9418
 TCP_GIT_PORT = 9418
 
 
+ZERO_SHA = "0" * 40
+
 SINGLE_ACK = 0
 SINGLE_ACK = 0
 MULTI_ACK = 1
 MULTI_ACK = 1
 MULTI_ACK_DETAILED = 2
 MULTI_ACK_DETAILED = 2
 
 
+
 class ProtocolFile(object):
 class ProtocolFile(object):
     """
     """
     Some network ops are like file ops. The file ops expect to operate on
     Some network ops are like file ops. The file ops expect to operate on
@@ -160,6 +168,112 @@ class Protocol(object):
         return cmd, args[:-1].split(chr(0))
         return cmd, args[:-1].split(chr(0))
 
 
 
 
+_RBUFSIZE = 8192  # Default read buffer size.
+
+
+class ReceivableProtocol(Protocol):
+    """Variant of Protocol that allows reading up to a size without blocking.
+
+    This class has a recv() method that behaves like socket.recv() in addition
+    to a read() method.
+
+    If you want to read n bytes from the wire and block until exactly n bytes
+    (or EOF) are read, use read(n). If you want to read at most n bytes from the
+    wire but don't care if you get less, use recv(n). Note that recv(n) will
+    still block until at least one byte is read.
+    """
+
+    def __init__(self, recv, write, report_activity=None, rbufsize=_RBUFSIZE):
+        super(ReceivableProtocol, self).__init__(self.read, write,
+                                                 report_activity)
+        self._recv = recv
+        self._rbuf = StringIO()
+        self._rbufsize = rbufsize
+
+    def read(self, size):
+        # From _fileobj.read in socket.py in the Python 2.6.5 standard library,
+        # with the following modifications:
+        #  - omit the size <= 0 branch
+        #  - seek back to start rather than 0 in case some buffer has been
+        #    consumed.
+        #  - use SEEK_END instead of the magic number.
+        # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+        # Licensed under the Python Software Foundation License.
+        # TODO: see if buffer is more efficient than cStringIO.
+        assert size > 0
+
+        # Our use of StringIO rather than lists of string objects returned by
+        # recv() minimizes memory usage and fragmentation that occurs when
+        # rbufsize is large compared to the typical return value of recv().
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, SEEK_END)
+        # buffer may have been partially consumed by recv()
+        buf_len = buf.tell() - start
+        if buf_len >= size:
+            # Already have size bytes in our buffer?  Extract and return.
+            buf.seek(start)
+            rv = buf.read(size)
+            self._rbuf = StringIO()
+            self._rbuf.write(buf.read())
+            self._rbuf.seek(0)
+            return rv
+
+        self._rbuf = StringIO()  # reset _rbuf.  we consume it via buf.
+        while True:
+            left = size - buf_len
+            # recv() will malloc the amount of memory given as its
+            # parameter even though it often returns much less data
+            # than that.  The returned data string is short lived
+            # as we copy it into a StringIO and free it.  This avoids
+            # fragmentation issues on many platforms.
+            data = self._recv(left)
+            if not data:
+                break
+            n = len(data)
+            if n == size and not buf_len:
+                # Shortcut.  Avoid buffer data copies when:
+                # - We have no data in our buffer.
+                # AND
+                # - Our call to recv returned exactly the
+                #   number of bytes we were asked to read.
+                return data
+            if n == left:
+                buf.write(data)
+                del data  # explicit free
+                break
+            assert n <= left, "_recv(%d) returned %d bytes" % (left, n)
+            buf.write(data)
+            buf_len += n
+            del data  # explicit free
+            #assert buf_len == buf.tell()
+        buf.seek(start)
+        return buf.read()
+
+    def recv(self, size):
+        assert size > 0
+
+        buf = self._rbuf
+        start = buf.tell()
+        buf.seek(0, SEEK_END)
+        buf_len = buf.tell()
+        buf.seek(start)
+
+        left = buf_len - start
+        if not left:
+            # only read from the wire if our read buffer is exhausted
+            data = self._recv(self._rbufsize)
+            if len(data) == size:
+                # shortcut: skip the buffer if we read exactly size bytes
+                return data
+            buf = StringIO()
+            buf.write(data)
+            buf.seek(0)
+            del data  # explicit free
+            self._rbuf = buf
+        return buf.read(size)
+
+
 def extract_capabilities(text):
 def extract_capabilities(text):
     """Extract a capabilities list from a string, if present.
     """Extract a capabilities list from a string, if present.
 
 
@@ -169,7 +283,7 @@ def extract_capabilities(text):
     if not "\0" in text:
     if not "\0" in text:
         return text, []
         return text, []
     text, capabilities = text.rstrip().split("\0")
     text, capabilities = text.rstrip().split("\0")
-    return (text, capabilities.split(" "))
+    return (text, capabilities.strip().split(" "))
 
 
 
 
 def extract_want_line_capabilities(text):
 def extract_want_line_capabilities(text):
@@ -192,7 +306,7 @@ def extract_want_line_capabilities(text):
 def ack_type(capabilities):
 def ack_type(capabilities):
     """Extract the ack type from a capabilities list."""
     """Extract the ack type from a capabilities list."""
     if 'multi_ack_detailed' in capabilities:
     if 'multi_ack_detailed' in capabilities:
-      return MULTI_ACK_DETAILED
+        return MULTI_ACK_DETAILED
     elif 'multi_ack' in capabilities:
     elif 'multi_ack' in capabilities:
         return MULTI_ACK
         return MULTI_ACK
     return SINGLE_ACK
     return SINGLE_ACK
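
To make the read()/recv() distinction described in the ReceivableProtocol docstring concrete, here is a sketch with an in-memory stand-in for socket.recv() that hands out at most three bytes per call:

    from cStringIO import StringIO
    from dulwich.protocol import ReceivableProtocol

    data = StringIO("0123456789")

    def fake_recv(size):
        # like socket.recv(): may return fewer bytes than requested
        return data.read(min(size, 3))

    p = ReceivableProtocol(fake_recv, lambda chunk: None)
    print p.read(5)   # '01234' -- keeps calling fake_recv until 5 bytes arrive
    print p.recv(5)   # '567'   -- returns whatever one fake_recv call produced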

+ 371 - 79
dulwich/repo.py

@@ -1,18 +1,18 @@
 # repo.py -- For dealing wih git repositories.
 # repo.py -- For dealing wih git repositories.
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # as published by the Free Software Foundation; version 2
-# of the License or (at your option) any later version of 
+# of the License or (at your option) any later version of
 # the License.
 # the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -26,13 +26,15 @@ import errno
 import os
 import os
 
 
 from dulwich.errors import (
 from dulwich.errors import (
-    MissingCommitError, 
+    MissingCommitError,
     NoIndexPresent,
     NoIndexPresent,
-    NotBlobError, 
-    NotCommitError, 
+    NotBlobError,
+    NotCommitError,
     NotGitRepository,
     NotGitRepository,
-    NotTreeError, 
+    NotTreeError,
+    NotTagError,
     PackedRefsException,
     PackedRefsException,
+    CommitError,
     )
     )
 from dulwich.file import (
 from dulwich.file import (
     ensure_dir_exists,
     ensure_dir_exists,
@@ -48,7 +50,10 @@ from dulwich.objects import (
     Tag,
     Tag,
     Tree,
     Tree,
     hex_to_sha,
     hex_to_sha,
+    object_class,
     )
     )
+import warnings
+
 
 
 OBJECTDIR = 'objects'
 OBJECTDIR = 'objects'
 SYMREF = 'ref: '
 SYMREF = 'ref: '
@@ -58,9 +63,6 @@ REFSDIR_HEADS = 'heads'
 INDEX_FILENAME = "index"
 INDEX_FILENAME = "index"
 
 
 BASE_DIRECTORIES = [
 BASE_DIRECTORIES = [
-    [OBJECTDIR], 
-    [OBJECTDIR, "info"], 
-    [OBJECTDIR, "pack"],
     ["branches"],
     ["branches"],
     [REFSDIR],
     [REFSDIR],
     [REFSDIR, REFSDIR_TAGS],
     [REFSDIR, REFSDIR_TAGS],
@@ -73,7 +75,7 @@ BASE_DIRECTORIES = [
 def read_info_refs(f):
 def read_info_refs(f):
     ret = {}
     ret = {}
     for l in f.readlines():
     for l in f.readlines():
-        (sha, name) = l.rstrip("\n").split("\t", 1)
+        (sha, name) = l.rstrip("\r\n").split("\t", 1)
         ret[name] = sha
         ret[name] = sha
     return ret
     return ret
 
 
@@ -114,12 +116,18 @@ class RefsContainer(object):
     """A container for refs."""
     """A container for refs."""
 
 
     def set_ref(self, name, other):
     def set_ref(self, name, other):
+        warnings.warn("RefsContainer.set_ref() is deprecated."
+            "Use set_symblic_ref instead.",
+            category=DeprecationWarning, stacklevel=2)
+        return self.set_symbolic_ref(name, other)
+
+    def set_symbolic_ref(self, name, other):
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
         :param name: Name of the ref to set
         :param name: Name of the ref to set
         :param other: Name of the ref to point at
         :param other: Name of the ref to point at
         """
         """
-        self[name] = SYMREF + other + '\n'
+        raise NotImplementedError(self.set_symbolic_ref)
 
 
     def get_packed_refs(self):
     def get_packed_refs(self):
         """Get contents of the packed-refs file.
         """Get contents of the packed-refs file.
@@ -131,14 +139,28 @@ class RefsContainer(object):
         """
         """
         raise NotImplementedError(self.get_packed_refs)
         raise NotImplementedError(self.get_packed_refs)
 
 
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
+
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to point to a
+            tag, this will be the SHA the ref refers to. If the ref may point to
+            a tag, but no cached information is available, None is returned.
+        """
+        return None
+
     def import_refs(self, base, other):
     def import_refs(self, base, other):
         for name, value in other.iteritems():
         for name, value in other.iteritems():
             self["%s/%s" % (base, name)] = value
             self["%s/%s" % (base, name)] = value
 
 
+    def allkeys(self):
+        """All refs present in this container."""
+        raise NotImplementedError(self.allkeys)
+
     def keys(self, base=None):
     def keys(self, base=None):
         """Refs present in this container.
         """Refs present in this container.
 
 
-        :param base: An optional base to return refs under
+        :param base: An optional base to return refs under.
         :return: An unsorted set of valid refs in this container, including
         :return: An unsorted set of valid refs in this container, including
             packed refs.
             packed refs.
         """
         """
@@ -148,10 +170,17 @@ class RefsContainer(object):
             return self.allkeys()
             return self.allkeys()
 
 
     def subkeys(self, base):
     def subkeys(self, base):
+        """Refs present in this container under a base.
+
+        :param base: The base to return refs under.
+        :return: A set of valid refs in this container under the base; the base
+            prefix is stripped from the ref names returned.
+        """
         keys = set()
         keys = set()
+        base_len = len(base) + 1
         for refname in self.allkeys():
         for refname in self.allkeys():
             if refname.startswith(base):
             if refname.startswith(base):
-                keys.add(refname)
+                keys.add(refname[base_len:])
         return keys
         return keys
 
 
     def as_dict(self, base=None):
     def as_dict(self, base=None):
@@ -186,11 +215,23 @@ class RefsContainer(object):
         if not name.startswith('refs/') or not check_ref_format(name[5:]):
         if not name.startswith('refs/') or not check_ref_format(name[5:]):
             raise KeyError(name)
             raise KeyError(name)
 
 
+    def read_ref(self, refname):
+        """Read a reference without following any references.
+
+        :param refname: The name of the reference
+        :return: The contents of the ref file, or None if it does
+            not exist.
+        """
+        contents = self.read_loose_ref(refname)
+        if not contents:
+            contents = self.get_packed_refs().get(refname, None)
+        return contents
+
     def read_loose_ref(self, name):
     def read_loose_ref(self, name):
         """Read a loose reference and return its contents.
         """Read a loose reference and return its contents.
 
 
         :param name: the refname to read
         :param name: the refname to read
-        :return: The contents of the ref file, or None if it does 
+        :return: The contents of the ref file, or None if it does
             not exist.
             not exist.
         """
         """
         raise NotImplementedError(self.read_loose_ref)
         raise NotImplementedError(self.read_loose_ref)
@@ -206,16 +247,19 @@ class RefsContainer(object):
         depth = 0
         depth = 0
         while contents.startswith(SYMREF):
         while contents.startswith(SYMREF):
             refname = contents[len(SYMREF):]
             refname = contents[len(SYMREF):]
-            contents = self.read_loose_ref(refname)
+            contents = self.read_ref(refname)
             if not contents:
             if not contents:
-                contents = self.get_packed_refs().get(refname, None)
-                if not contents:
-                    break
+                break
             depth += 1
             depth += 1
             if depth > 5:
             if depth > 5:
                 raise KeyError(name)
                 raise KeyError(name)
         return refname, contents
         return refname, contents
 
 
+    def __contains__(self, refname):
+        if self.read_ref(refname):
+            return True
+        return False
+
     def __getitem__(self, name):
     def __getitem__(self, name):
         """Get the SHA1 for a reference name.
         """Get the SHA1 for a reference name.
 
 
@@ -226,8 +270,74 @@ class RefsContainer(object):
             raise KeyError(name)
             raise KeyError(name)
         return sha
         return sha
 
 
+    def set_if_equals(self, name, old_ref, new_ref):
+        """Set a refname to new_ref only if it currently equals old_ref.
+
+        This method follows all symbolic references if applicable for the
+        subclass, and can be used to perform an atomic compare-and-swap
+        operation.
+
+        :param name: The refname to set.
+        :param old_ref: The old sha the refname must refer to, or None to set
+            unconditionally.
+        :param new_ref: The new sha the refname will refer to.
+        :return: True if the set was successful, False otherwise.
+        """
+        raise NotImplementedError(self.set_if_equals)
+
+    def add_if_new(self, name, ref):
+        """Add a new reference only if it does not already exist."""
+        raise NotImplementedError(self.add_if_new)
+
+    def __setitem__(self, name, ref):
+        """Set a reference name to point to the given SHA1.
+
+        This method follows all symbolic references if applicable for the
+        subclass.
+
+        :note: This method unconditionally overwrites the contents of a
+            reference. To update atomically only if the reference has not
+            changed, use set_if_equals().
+        :param name: The refname to set.
+        :param ref: The new sha the refname will refer to.
+        """
+        self.set_if_equals(name, None, ref)
+
+    def remove_if_equals(self, name, old_ref):
+        """Remove a refname only if it currently equals old_ref.
+
+        This method does not follow symbolic references, even if applicable for
+        the subclass. It can be used to perform an atomic compare-and-delete
+        operation.
+
+        :param name: The refname to delete.
+        :param old_ref: The old sha the refname must refer to, or None to delete
+            unconditionally.
+        :return: True if the delete was successful, False otherwise.
+        """
+        raise NotImplementedError(self.remove_if_equals)
+
+    def __delitem__(self, name):
+        """Remove a refname.
+
+        This method does not follow symbolic references, even if applicable for
+        the subclass.
+
+        :note: This method unconditionally deletes the contents of a reference.
+            To delete atomically only if the reference has not changed, use
+            remove_if_equals().
+
+        :param name: The refname to delete.
+        """
+        self.remove_if_equals(name, None)
+
 
 
 class DictRefsContainer(RefsContainer):
 class DictRefsContainer(RefsContainer):
+    """RefsContainer backed by a simple dict.
+
+    This container does not support symbolic or packed references and is not
+    threadsafe.
+    """
 
 
     def __init__(self, refs):
     def __init__(self, refs):
         self._refs = refs
         self._refs = refs
@@ -236,7 +346,32 @@ class DictRefsContainer(RefsContainer):
         return self._refs.keys()
         return self._refs.keys()
 
 
     def read_loose_ref(self, name):
     def read_loose_ref(self, name):
-        return self._refs[name]
+        return self._refs.get(name, None)
+
+    def get_packed_refs(self):
+        return {}
+
+    def set_symbolic_ref(self, name, other):
+        self._refs[name] = SYMREF + other
+
+    def set_if_equals(self, name, old_ref, new_ref):
+        if old_ref is not None and self._refs.get(name, None) != old_ref:
+            return False
+        realname, _ = self._follow(name)
+        self._refs[realname] = new_ref
+        return True
+
+    def add_if_new(self, name, ref):
+        if name in self._refs:
+            return False
+        self._refs[name] = ref
+        return True
+
+    def remove_if_equals(self, name, old_ref):
+        if old_ref is not None and self._refs.get(name, None) != old_ref:
+            return False
+        del self._refs[name]
+        return True
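
As a rough sketch of the compare-and-swap contract above, using the
dict-backed container (the import path is assumed to be dulwich.repo in this
release; the shas are placeholders):

    from dulwich.repo import DictRefsContainer

    refs = DictRefsContainer({"refs/heads/master": "a" * 40})
    # Succeeds: the stored value still equals old_ref.
    assert refs.set_if_equals("refs/heads/master", "a" * 40, "b" * 40)
    # Refused: old_ref no longer matches, so the swap is not performed.
    assert not refs.set_if_equals("refs/heads/master", "a" * 40, "c" * 40)
    # add_if_new() only writes refs that do not exist yet.
    assert refs.add_if_new("refs/heads/feature", "d" * 40)
    assert not refs.add_if_new("refs/heads/feature", "e" * 40)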
 
 
 
 
 class DiskRefsContainer(RefsContainer):
 class DiskRefsContainer(RefsContainer):
@@ -245,7 +380,7 @@ class DiskRefsContainer(RefsContainer):
     def __init__(self, path):
     def __init__(self, path):
         self.path = path
         self.path = path
         self._packed_refs = None
         self._packed_refs = None
-        self._peeled_refs = {}
+        self._peeled_refs = None
 
 
     def __repr__(self):
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.path)
         return "%s(%r)" % (self.__class__.__name__, self.path)
@@ -298,7 +433,10 @@ class DiskRefsContainer(RefsContainer):
         """
         """
         # TODO: invalidate the cache on repacking
         # TODO: invalidate the cache on repacking
         if self._packed_refs is None:
         if self._packed_refs is None:
+            # set both to empty because we want _peeled_refs to be
+            # None if and only if _packed_refs is also None.
             self._packed_refs = {}
             self._packed_refs = {}
+            self._peeled_refs = {}
             path = os.path.join(self.path, 'packed-refs')
             path = os.path.join(self.path, 'packed-refs')
             try:
             try:
                 f = GitFile(path, 'rb')
                 f = GitFile(path, 'rb')
@@ -322,6 +460,24 @@ class DiskRefsContainer(RefsContainer):
                 f.close()
                 f.close()
         return self._packed_refs
         return self._packed_refs
 
 
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
+
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to point
+            to a tag, this will be the SHA the ref refers to. If the ref may
+            point to a tag, but no cached information is available, None is
+            returned.
+        """
+        self.get_packed_refs()
+        if self._peeled_refs is None or name not in self._packed_refs:
+            # No cache: no peeled refs were read, or this ref is loose
+            return None
+        if name in self._peeled_refs:
+            return self._peeled_refs[name]
+        else:
+            # Known not peelable
+            return self[name]
+
     def read_loose_ref(self, name):
     def read_loose_ref(self, name):
         """Read a reference file and return its contents.
         """Read a reference file and return its contents.
 
 
@@ -340,7 +496,7 @@ class DiskRefsContainer(RefsContainer):
                 header = f.read(len(SYMREF))
                 header = f.read(len(SYMREF))
                 if header == SYMREF:
                 if header == SYMREF:
                     # Read only the first line
                     # Read only the first line
-                    return header + iter(f).next().rstrip("\n")
+                    return header + iter(f).next().rstrip("\r\n")
                 else:
                 else:
                     # Read only the first 40 bytes
                     # Read only the first 40 bytes
                     return header + f.read(40-len(SYMREF))
                     return header + f.read(40-len(SYMREF))
@@ -372,6 +528,25 @@ class DiskRefsContainer(RefsContainer):
         finally:
         finally:
             f.abort()
             f.abort()
 
 
+    def set_symbolic_ref(self, name, other):
+        """Make a ref point at another ref.
+
+        :param name: Name of the ref to set
+        :param other: Name of the ref to point at
+        """
+        self._check_refname(name)
+        self._check_refname(other)
+        filename = self.refpath(name)
+        try:
+            f = GitFile(filename, 'wb')
+            try:
+                f.write(SYMREF + other + '\n')
+            except (IOError, OSError):
+                f.abort()
+                raise
+        finally:
+            f.close()
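
For illustration, given a DiskRefsContainer ``refs`` for an existing control
directory (the branch name is a placeholder), set_symbolic_ref() writes a
"ref: " pointer file that read_loose_ref() reports back verbatim:

    refs.set_symbolic_ref("HEAD", "refs/heads/develop")
    # The loose HEAD file now contains the line "ref: refs/heads/develop".
    print refs.read_loose_ref("HEAD")    # -> "ref: refs/heads/develop"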
+
     def set_if_equals(self, name, old_ref, new_ref):
     def set_if_equals(self, name, old_ref, new_ref):
         """Set a refname to new_ref only if it currently equals old_ref.
         """Set a refname to new_ref only if it currently equals old_ref.
 
 
@@ -414,9 +589,23 @@ class DiskRefsContainer(RefsContainer):
         return True
         return True
 
 
     def add_if_new(self, name, ref):
     def add_if_new(self, name, ref):
-        """Add a new reference only if it does not already exist."""
-        self._check_refname(name)
-        filename = self.refpath(name)
+        """Add a new reference only if it does not already exist.
+
+        This method follows symrefs, and only ensures that the last ref in the
+        chain does not exist.
+
+        :param name: The refname to set.
+        :param ref: The new sha the refname will refer to.
+        :return: True if the add was successful, False otherwise.
+        """
+        try:
+            realname, contents = self._follow(name)
+            if contents is not None:
+                return False
+        except KeyError:
+            realname = name
+        self._check_refname(realname)
+        filename = self.refpath(realname)
         ensure_dir_exists(os.path.dirname(filename))
         ensure_dir_exists(os.path.dirname(filename))
         f = GitFile(filename, 'wb')
         f = GitFile(filename, 'wb')
         try:
         try:
@@ -432,17 +621,6 @@ class DiskRefsContainer(RefsContainer):
             f.close()
             f.close()
         return True
         return True
 
 
-    def __setitem__(self, name, ref):
-        """Set a reference name to point to the given SHA1.
-
-        This method follows all symbolic references.
-
-        :note: This method unconditionally overwrites the contents of a reference
-            on disk. To update atomically only if the reference has not changed
-            on disk, use set_if_equals().
-        """
-        self.set_if_equals(name, None, ref)
-
     def remove_if_equals(self, name, old_ref):
     def remove_if_equals(self, name, old_ref):
         """Remove a refname only if it currently equals old_ref.
         """Remove a refname only if it currently equals old_ref.
 
 
@@ -477,16 +655,6 @@ class DiskRefsContainer(RefsContainer):
             f.abort()
             f.abort()
         return True
         return True
 
 
-    def __delitem__(self, name):
-        """Remove a refname.
-
-        This method does not follow symbolic references.
-        :note: This method unconditionally deletes the contents of a reference
-            on disk. To delete atomically only if the reference has not changed
-            on disk, use set_if_equals().
-        """
-        self.remove_if_equals(name, None)
-
 
 
 def _split_ref_line(line):
 def _split_ref_line(line):
     """Split a single ref line into a tuple of SHA1 and name."""
     """Split a single ref line into a tuple of SHA1 and name."""
@@ -516,7 +684,7 @@ def read_packed_refs(f):
             continue
             continue
         if l[0] == "^":
         if l[0] == "^":
             raise PackedRefsException(
             raise PackedRefsException(
-                "found peeled ref in packed-refs without peeled")
+              "found peeled ref in packed-refs without peeled")
         yield _split_ref_line(l)
         yield _split_ref_line(l)
 
 
 
 
@@ -532,7 +700,7 @@ def read_packed_refs_with_peeled(f):
     for l in f:
     for l in f:
         if l[0] == "#":
         if l[0] == "#":
             continue
             continue
-        l = l.rstrip("\n")
+        l = l.rstrip("\r\n")
         if l[0] == "^":
         if l[0] == "^":
             if not last:
             if not last:
                 raise PackedRefsException("unexpected peeled ref line")
                 raise PackedRefsException("unexpected peeled ref line")
@@ -558,6 +726,7 @@ def write_packed_refs(f, packed_refs, peeled_refs=None):
 
 
     :param f: empty file-like object to write to
     :param f: empty file-like object to write to
     :param packed_refs: dict of refname to sha of packed refs to write
     :param packed_refs: dict of refname to sha of packed refs to write
+    :param peeled_refs: dict of refname to peeled value of sha
     """
     """
     if peeled_refs is None:
     if peeled_refs is None:
         peeled_refs = {}
         peeled_refs = {}
@@ -595,7 +764,7 @@ class BaseRepo(object):
 
 
     def open_index(self):
     def open_index(self):
         """Open the index for this repository.
         """Open the index for this repository.
-        
+
         :raises NoIndexPresent: If no index is present
         :raises NoIndexPresent: If no index is present
         :return: Index instance
         :return: Index instance
         """
         """
@@ -605,33 +774,39 @@ class BaseRepo(object):
         """Fetch objects into another repository.
         """Fetch objects into another repository.
 
 
         :param target: The target repository
         :param target: The target repository
-        :param determine_wants: Optional function to determine what refs to 
+        :param determine_wants: Optional function to determine what refs to
             fetch.
             fetch.
         :param progress: Optional progress function
         :param progress: Optional progress function
         """
         """
         if determine_wants is None:
         if determine_wants is None:
             determine_wants = lambda heads: heads.values()
             determine_wants = lambda heads: heads.values()
         target.object_store.add_objects(
         target.object_store.add_objects(
-            self.fetch_objects(determine_wants, target.get_graph_walker(),
-                progress))
+          self.fetch_objects(determine_wants, target.get_graph_walker(),
+                             progress))
         return self.get_refs()
         return self.get_refs()
 
 
-    def fetch_objects(self, determine_wants, graph_walker, progress):
+    def fetch_objects(self, determine_wants, graph_walker, progress,
+                      get_tagged=None):
         """Fetch the missing objects required for a set of revisions.
         """Fetch the missing objects required for a set of revisions.
 
 
-        :param determine_wants: Function that takes a dictionary with heads 
+        :param determine_wants: Function that takes a dictionary with heads
             and returns the list of heads to fetch.
             and returns the list of heads to fetch.
-        :param graph_walker: Object that can iterate over the list of revisions 
-            to fetch and has an "ack" method that will be called to acknowledge 
+        :param graph_walker: Object that can iterate over the list of revisions
+            to fetch and has an "ack" method that will be called to acknowledge
             that a revision is present.
             that a revision is present.
-        :param progress: Simple progress function that will be called with 
+        :param progress: Simple progress function that will be called with
             updated progress strings.
             updated progress strings.
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         :return: iterator over objects, with __len__ implemented
         :return: iterator over objects, with __len__ implemented
         """
         """
         wants = determine_wants(self.get_refs())
         wants = determine_wants(self.get_refs())
+        if not wants:
+            return []
         haves = self.object_store.find_common_revisions(graph_walker)
         haves = self.object_store.find_common_revisions(graph_walker)
         return self.object_store.iter_shas(
         return self.object_store.iter_shas(
-            self.object_store.find_missing_objects(haves, wants, progress))
+          self.object_store.find_missing_objects(haves, wants, progress,
+                                                 get_tagged))
 
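
A rough sketch of the fetch_objects() contract above, copying the objects
reachable from one branch into a second local repository (the paths are
placeholders and progress output is simply discarded):

    from dulwich.repo import Repo

    source = Repo("/path/to/source.git")
    target = Repo("/path/to/target.git")

    def wants_master(refs):
        # refs is a dict of refname -> sha; return the shas we want to fetch.
        return [refs["refs/heads/master"]]

    objects = source.fetch_objects(wants_master, target.get_graph_walker(),
                                   lambda msg: None)
    target.object_store.add_objects(objects)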
 
     def get_graph_walker(self, heads=None):
     def get_graph_walker(self, heads=None):
         if heads is None:
         if heads is None:
@@ -653,15 +828,18 @@ class BaseRepo(object):
     def _get_object(self, sha, cls):
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
         assert len(sha) in (20, 40)
         ret = self.get_object(sha)
         ret = self.get_object(sha)
-        if ret._type != cls._type:
+        if not isinstance(ret, cls):
             if cls is Commit:
             if cls is Commit:
                 raise NotCommitError(ret)
                 raise NotCommitError(ret)
             elif cls is Blob:
             elif cls is Blob:
                 raise NotBlobError(ret)
                 raise NotBlobError(ret)
             elif cls is Tree:
             elif cls is Tree:
                 raise NotTreeError(ret)
                 raise NotTreeError(ret)
+            elif cls is Tag:
+                raise NotTagError(ret)
             else:
             else:
-                raise Exception("Type invalid: %r != %r" % (ret._type, cls._type))
+                raise Exception("Type invalid: %r != %r" % (
+                  ret.type_name, cls.type_name))
         return ret
         return ret
 
 
     def get_object(self, sha):
     def get_object(self, sha):
@@ -678,17 +856,71 @@ class BaseRepo(object):
                     for section in p.sections())
                     for section in p.sections())
 
 
     def commit(self, sha):
     def commit(self, sha):
+        """Retrieve the commit with a particular SHA.
+
+        :param sha: SHA of the commit to retrieve
+        :raise NotCommitError: If the SHA provided doesn't point at a Commit
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Commit` object
+        """
+        warnings.warn("Repo.commit(sha) is deprecated. Use Repo[sha] instead.",
+            category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Commit)
         return self._get_object(sha, Commit)
 
 
     def tree(self, sha):
     def tree(self, sha):
+        """Retrieve the tree with a particular SHA.
+
+        :param sha: SHA of the tree to retrieve
+        :raise NotTreeError: If the SHA provided doesn't point at a Tree
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Tree` object
+        """
+        warnings.warn("Repo.tree(sha) is deprecated. Use Repo[sha] instead.",
+            category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Tree)
         return self._get_object(sha, Tree)
 
 
     def tag(self, sha):
     def tag(self, sha):
+        """Retrieve the tag with a particular SHA.
+
+        :param sha: SHA of the tag to retrieve
+        :raise NotTagError: If the SHA provided doesn't point at a Tag
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Tag` object
+        """
+        warnings.warn("Repo.tag(sha) is deprecated. Use Repo[sha] instead.",
+            category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Tag)
         return self._get_object(sha, Tag)
 
 
     def get_blob(self, sha):
     def get_blob(self, sha):
+        """Retrieve the blob with a particular SHA.
+
+        :param sha: SHA of the blob to retrieve
+        :raise NotBlobError: If the SHA provided doesn't point at a Blob
+        :raise KeyError: If the SHA provided didn't exist
+        :return: A `Blob` object
+        """
+        warnings.warn("Repo.get_blob(sha) is deprecated. Use Repo[sha] "
+            "instead.", category=DeprecationWarning, stacklevel=2)
         return self._get_object(sha, Blob)
         return self._get_object(sha, Blob)
 
 
+    def get_peeled(self, ref):
+        """Get the peeled value of a ref.
+
+        :param ref: the refname to peel
+        :return: the fully-peeled SHA1 of a tag object, after peeling all
+            intermediate tags; if the original ref does not point to a tag, this
+            will equal the original SHA1.
+        """
+        cached = self.refs.get_peeled(ref)
+        if cached is not None:
+            return cached
+        obj = self[ref]
+        obj_class = object_class(obj.type_name)
+        while obj_class is Tag:
+            obj_class, sha = obj.object
+            obj = self.get_object(sha)
+        return obj.id
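
A short sketch of the difference this makes, given an open repository ``repo``
that has an annotated tag (the ref names are placeholders):

    print repo.refs["refs/tags/v0.1"]           # sha of the tag object itself
    print repo.get_peeled("refs/tags/v0.1")     # sha of the commit it points at
    print repo.get_peeled("refs/heads/master")  # not a tag: same sha as the ref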
+
     def revision_history(self, head):
     def revision_history(self, head):
         """Returns a list of the commits reachable from head.
         """Returns a list of the commits reachable from head.
 
 
@@ -707,9 +939,11 @@ class BaseRepo(object):
         while pending_commits != []:
         while pending_commits != []:
             head = pending_commits.pop(0)
             head = pending_commits.pop(0)
             try:
             try:
-                commit = self.commit(head)
+                commit = self[head]
             except KeyError:
             except KeyError:
                 raise MissingCommitError(head)
                 raise MissingCommitError(head)
+            if type(commit) != Commit:
+                raise NotCommitError(commit)
             if commit in history:
             if commit in history:
                 continue
                 continue
             i = 0
             i = 0
@@ -718,16 +952,24 @@ class BaseRepo(object):
                     break
                     break
                 i += 1
                 i += 1
             history.insert(i, commit)
             history.insert(i, commit)
-            parents = commit.parents
-            pending_commits += parents
+            pending_commits += commit.parents
         history.reverse()
         history.reverse()
         return history
         return history
 
 
     def __getitem__(self, name):
     def __getitem__(self, name):
         if len(name) in (20, 40):
         if len(name) in (20, 40):
-            return self.object_store[name]
+            try:
+                return self.object_store[name]
+            except KeyError:
+                pass
         return self.object_store[self.refs[name]]
         return self.object_store[self.refs[name]]
 
 
+    def __contains__(self, name):
+        if len(name) in (20, 40):
+            return name in self.object_store or name in self.refs
+        else:
+            return name in self.refs
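
A quick sketch of the mapping-style access that replaces the deprecated
accessors above (the path is a placeholder and the repository is assumed to
have at least one commit):

    from dulwich.repo import Repo

    r = Repo("/path/to/repo")
    head = r["HEAD"]            # ref names are resolved through r.refs
    same = r[head.id]           # 40-character hex shas hit the object store
    print "HEAD" in r, head.id in r    # __contains__ checks refs and objects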
+
     def __setitem__(self, name, value):
     def __setitem__(self, name, value):
         if name.startswith("refs/") or name == "HEAD":
         if name.startswith("refs/") or name == "HEAD":
             if isinstance(value, ShaFile):
             if isinstance(value, ShaFile):
@@ -736,43 +978,47 @@ class BaseRepo(object):
                 self.refs[name] = value
                 self.refs[name] = value
             else:
             else:
                 raise TypeError(value)
                 raise TypeError(value)
-        raise ValueError(name)
+        else:
+            raise ValueError(name)
 
 
     def __delitem__(self, name):
     def __delitem__(self, name):
         if name.startswith("refs") or name == "HEAD":
         if name.startswith("refs") or name == "HEAD":
             del self.refs[name]
             del self.refs[name]
         raise ValueError(name)
         raise ValueError(name)
 
 
-    def do_commit(self, committer, message,
+    def do_commit(self, message, committer=None,
                   author=None, commit_timestamp=None,
                   author=None, commit_timestamp=None,
-                  commit_timezone=None, author_timestamp=None, 
+                  commit_timezone=None, author_timestamp=None,
                   author_timezone=None, tree=None):
                   author_timezone=None, tree=None):
         """Create a new commit.
         """Create a new commit.
 
 
-        :param committer: Committer fullname
         :param message: Commit message
         :param message: Commit message
+        :param committer: Committer fullname
         :param author: Author fullname (defaults to committer)
         :param author: Author fullname (defaults to committer)
         :param commit_timestamp: Commit timestamp (defaults to now)
         :param commit_timestamp: Commit timestamp (defaults to now)
         :param commit_timezone: Commit timestamp timezone (defaults to GMT)
         :param commit_timezone: Commit timestamp timezone (defaults to GMT)
         :param author_timestamp: Author timestamp (defaults to commit timestamp)
         :param author_timestamp: Author timestamp (defaults to commit timestamp)
-        :param author_timezone: Author timestamp timezone 
+        :param author_timezone: Author timestamp timezone
             (defaults to commit timestamp timezone)
             (defaults to commit timestamp timezone)
         :param tree: SHA1 of the tree root to use (if not specified the current index will be committed).
         :param tree: SHA1 of the tree root to use (if not specified the current index will be committed).
         :return: New commit SHA1
         :return: New commit SHA1
         """
         """
-        from dulwich.index import commit_index
         import time
         import time
         index = self.open_index()
         index = self.open_index()
         c = Commit()
         c = Commit()
         if tree is None:
         if tree is None:
-            c.tree = commit_index(self.object_store, index)
+            c.tree = index.commit(self.object_store)
         else:
         else:
             c.tree = tree
             c.tree = tree
+        # TODO: Allow username to be missing, and get it from .git/config
+        if committer is None:
+            raise ValueError("committer not set")
         c.committer = committer
         c.committer = committer
         if commit_timestamp is None:
         if commit_timestamp is None:
             commit_timestamp = time.time()
             commit_timestamp = time.time()
         c.commit_time = int(commit_timestamp)
         c.commit_time = int(commit_timestamp)
         if commit_timezone is None:
         if commit_timezone is None:
+            # FIXME: Use current user timezone rather than UTC
             commit_timezone = 0
             commit_timezone = 0
         c.commit_timezone = commit_timezone
         c.commit_timezone = commit_timezone
         if author is None:
         if author is None:
@@ -785,8 +1031,20 @@ class BaseRepo(object):
             author_timezone = commit_timezone
             author_timezone = commit_timezone
         c.author_timezone = author_timezone
         c.author_timezone = author_timezone
         c.message = message
         c.message = message
-        self.object_store.add_object(c)
-        self.refs["HEAD"] = c.id
+        try:
+            old_head = self.refs["HEAD"]
+            c.parents = [old_head]
+            self.object_store.add_object(c)
+            ok = self.refs.set_if_equals("HEAD", old_head, c.id)
+        except KeyError:
+            c.parents = []
+            self.object_store.add_object(c)
+            ok = self.refs.add_if_new("HEAD", c.id)
+        if not ok:
+            # The atomic compare-and-swap on HEAD failed; give up and leave the
+            # commit and all its objects behind as garbage.
+            raise CommitError("HEAD changed during commit")
+
         return c.id
         return c.id
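
With the reordered signature the message now comes first and the committer is
a keyword argument; a minimal sketch, assuming an open non-bare repository
``repo`` with a populated index (the identity string is a placeholder):

    commit_id = repo.do_commit("Fix the frobnicator",
                               committer="Jane Doe <jane@example.com>")

Because HEAD is updated through set_if_equals()/add_if_new(), a concurrent
update of HEAD now makes do_commit() raise CommitError instead of silently
overwriting the other commit.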
 
 
 
 
@@ -804,8 +1062,8 @@ class Repo(BaseRepo):
         else:
         else:
             raise NotGitRepository(root)
             raise NotGitRepository(root)
         self.path = root
         self.path = root
-        object_store = DiskObjectStore(
-            os.path.join(self.controldir(), OBJECTDIR))
+        object_store = DiskObjectStore(os.path.join(self.controldir(),
+                                                    OBJECTDIR))
         refs = DiskRefsContainer(self.controldir())
         refs = DiskRefsContainer(self.controldir())
         BaseRepo.__init__(self, object_store, refs)
         BaseRepo.__init__(self, object_store, refs)
 
 
@@ -852,7 +1110,40 @@ class Repo(BaseRepo):
 
 
     def has_index(self):
     def has_index(self):
         """Check if an index is present."""
         """Check if an index is present."""
-        return os.path.exists(self.index_path())
+        # Bare repos must never have index files; non-bare repos may have a
+        # missing index file, which is treated as empty.
+        return not self.bare
+
+    def stage(self, paths):
+        """Stage a set of paths.
+
+        :param paths: List of paths, relative to the repository path
+        """
+        from dulwich.index import cleanup_mode
+        index = self.open_index()
+        for path in paths:
+            full_path = os.path.join(self.path, path)
+            blob = Blob()
+            try:
+                st = os.stat(full_path)
+            except OSError:
+                # File no longer exists
+                try:
+                    del index[path]
+                except KeyError:
+                    pass  # Doesn't exist in the index either
+            else:
+                f = open(full_path, 'rb')
+                try:
+                    blob.data = f.read()
+                finally:
+                    f.close()
+                self.object_store.add_object(blob)
+                # XXX: Cleanup some of the other file properties as well?
+                index[path] = (st.st_ctime, st.st_mtime, st.st_dev, st.st_ino,
+                    cleanup_mode(st.st_mode), st.st_uid, st.st_gid, st.st_size,
+                    blob.id, 0)
+        index.write()
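
A minimal sketch of staging and committing with the new helper, assuming an
existing non-bare repository (paths and identity are placeholders):

    from dulwich.repo import Repo

    r = Repo("/path/to/checkout")
    r.stage(["README", "docs/usage.txt"])    # paths relative to the repo root
    r.do_commit("Update docs",
                committer="Jane Doe <jane@example.com>")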
 
 
     def __repr__(self):
     def __repr__(self):
         return "<Repo at %r>" % self.path
         return "<Repo at %r>" % self.path
@@ -868,8 +1159,9 @@ class Repo(BaseRepo):
     def init_bare(cls, path, mkdir=True):
     def init_bare(cls, path, mkdir=True):
         for d in BASE_DIRECTORIES:
         for d in BASE_DIRECTORIES:
             os.mkdir(os.path.join(path, *d))
             os.mkdir(os.path.join(path, *d))
+        DiskObjectStore.init(os.path.join(path, OBJECTDIR))
         ret = cls(path)
         ret = cls(path)
-        ret.refs.set_ref("HEAD", "refs/heads/master")
+        ret.refs.set_symbolic_ref("HEAD", "refs/heads/master")
         ret._put_named_file('description', "Unnamed repository")
         ret._put_named_file('description', "Unnamed repository")
         ret._put_named_file('config', """[core]
         ret._put_named_file('config', """[core]
     repositoryformatversion = 0
     repositoryformatversion = 0
@@ -877,7 +1169,7 @@ class Repo(BaseRepo):
     bare = false
     bare = false
     logallrefupdates = true
     logallrefupdates = true
 """)
 """)
-        ret._put_named_file(os.path.join('info', 'excludes'), '')
+        ret._put_named_file(os.path.join('info', 'exclude'), '')
         return ret
         return ret
 
 
     create = init_bare
     create = init_bare
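
A rough sketch of bootstrapping a bare repository with the fixed init_bare()
(the path is a placeholder; the directory is created first here, since
init_bare() is assumed to only populate an existing directory):

    import os
    from dulwich.repo import Repo

    path = "/tmp/example.git"
    os.mkdir(path)
    repo = Repo.init_bare(path)
    print repo.refs.read_loose_ref("HEAD")    # -> "ref: refs/heads/master"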

+ 251 - 143
dulwich/server.py

@@ -27,36 +27,55 @@ Documentation/technical directory in the cgit distribution, and in particular:
 
 
 
 
 import collections
 import collections
+import socket
+import zlib
 import SocketServer
 import SocketServer
-import tempfile
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     ApplyDeltaError,
     ApplyDeltaError,
     ChecksumMismatch,
     ChecksumMismatch,
     GitProtocolError,
     GitProtocolError,
+    ObjectFormatException,
     )
     )
 from dulwich.objects import (
 from dulwich.objects import (
     hex_to_sha,
     hex_to_sha,
     )
     )
+from dulwich.pack import (
+    PackStreamReader,
+    write_pack_data,
+    )
 from dulwich.protocol import (
 from dulwich.protocol import (
-    Protocol,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
     ProtocolFile,
     ProtocolFile,
+    ReceivableProtocol,
+    SINGLE_ACK,
     TCP_GIT_PORT,
     TCP_GIT_PORT,
+    ZERO_SHA,
+    ack_type,
     extract_capabilities,
     extract_capabilities,
     extract_want_line_capabilities,
     extract_want_line_capabilities,
-    SINGLE_ACK,
-    MULTI_ACK,
-    MULTI_ACK_DETAILED,
-    ack_type,
-    )
-from dulwich.repo import (
-    Repo,
-    )
-from dulwich.pack import (
-    write_pack_data,
     )
     )
 
 
+
+
 class Backend(object):
 class Backend(object):
+    """A backend for the Git smart server implementation."""
+
+    def open_repository(self, path):
+        """Open the repository at a path."""
+        raise NotImplementedError(self.open_repository)
+
+
+class BackendRepo(object):
+    """Repository abstraction used by the Git server.
+    
+    Please note that the methods required here are a 
+    subset of those provided by dulwich.repo.Repo.
+    """
+
+    object_store = None
+    refs = None
 
 
     def get_refs(self):
     def get_refs(self):
         """
         """
@@ -66,144 +85,177 @@ class Backend(object):
         """
         """
         raise NotImplementedError
         raise NotImplementedError
 
 
-    def apply_pack(self, refs, read):
-        """ Import a set of changes into a repository and update the refs
+    def get_peeled(self, name):
+        """Return the cached peeled value of a ref, if available.
 
 
-        :param refs: list of tuple(name, sha)
-        :param read: callback to read from the incoming pack
+        :param name: Name of the ref to peel
+        :return: The peeled value of the ref. If the ref is known not to point
+            to a tag, this will be the SHA the ref refers to. If no cached
+            information about a tag is available, this method may return None,
+            but it should attempt to peel the tag if possible.
         """
         """
-        raise NotImplementedError
+        return None
 
 
-    def fetch_objects(self, determine_wants, graph_walker, progress):
+    def fetch_objects(self, determine_wants, graph_walker, progress,
+                      get_tagged=None):
         """
         """
         Yield the objects required for a list of commits.
         Yield the objects required for a list of commits.
 
 
         :param progress: is a callback to send progress messages to the client
         :param progress: is a callback to send progress messages to the client
+        :param get_tagged: Function that returns a dict of pointed-to sha -> tag
+            sha for including tags.
         """
         """
         raise NotImplementedError
         raise NotImplementedError
 
 
 
 
-class GitBackend(Backend):
+class PackStreamCopier(PackStreamReader):
+    """Class to verify a pack stream as it is being read.
 
 
-    def __init__(self, repo=None):
-        if repo is None:
-            repo = Repo(tmpfile.mkdtemp())
-        self.repo = repo
-        self.object_store = self.repo.object_store
-        self.fetch_objects = self.repo.fetch_objects
-        self.get_refs = self.repo.get_refs
+    The pack is read from a ReceivableProtocol using read() or recv() as
+    appropriate and written out to the given file-like object.
+    """
 
 
-    def apply_pack(self, refs, read):
-        f, commit = self.repo.object_store.add_thin_pack()
-        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
-        status = []
-        unpack_error = None
-        # TODO: more informative error messages than just the exception string
-        try:
-            # TODO: decode the pack as we stream to avoid blocking reads beyond
-            # the end of data (when using HTTP/1.1 chunked encoding)
-            while True:
-                data = read(10240)
-                if not data:
-                    break
-                f.write(data)
-        except all_exceptions, e:
-            unpack_error = str(e).replace('\n', '')
-        try:
-            commit()
-        except all_exceptions, e:
-            if not unpack_error:
-                unpack_error = str(e).replace('\n', '')
+    def __init__(self, read_all, read_some, outfile):
+        super(PackStreamCopier, self).__init__(read_all, read_some)
+        self.outfile = outfile
 
 
-        if unpack_error:
-            status.append(('unpack', unpack_error))
-        else:
-            status.append(('unpack', 'ok'))
+    def _read(self, read, size):
+        data = super(PackStreamCopier, self)._read(read, size)
+        self.outfile.write(data)
+        return data
 
 
-        for oldsha, sha, ref in refs:
-            # TODO: check refname
-            ref_error = None
-            try:
-                if ref == "0" * 40:
-                    try:
-                        del self.repo.refs[ref]
-                    except all_exceptions:
-                        ref_error = 'failed to delete'
-                else:
-                    try:
-                        self.repo.refs[ref] = sha
-                    except all_exceptions:
-                        ref_error = 'failed to write'
-            except KeyError, e:
-                ref_error = 'bad ref'
-            if ref_error:
-                status.append((ref, ref_error))
-            else:
-                status.append((ref, 'ok'))
+    def verify(self):
+        """Verify a pack stream and write it to the output file.
 
 
+        See PackStreamReader.iterobjects for a list of exceptions this may
+        throw.
+        """
+        for _, _, _ in self.read_objects():
+            pass
 
 
-        print "pack applied"
-        return status
+
+class DictBackend(Backend):
+    """Trivial backend that looks up Git repositories in a dictionary."""
+
+    def __init__(self, repos):
+        self.repos = repos
+
+    def open_repository(self, path):
+        # FIXME: What to do in case there is no repo ?
+        return self.repos[path]
 
 
 
 
 class Handler(object):
 class Handler(object):
     """Smart protocol command handler base class."""
     """Smart protocol command handler base class."""
 
 
-    def __init__(self, backend, read, write):
+    def __init__(self, backend, proto):
         self.backend = backend
         self.backend = backend
-        self.proto = Protocol(read, write)
+        self.proto = proto
+        self._client_capabilities = None
+
+    def capability_line(self):
+        return " ".join(self.capabilities())
 
 
     def capabilities(self):
     def capabilities(self):
-        return " ".join(self.default_capabilities())
+        raise NotImplementedError(self.capabilities)
+
+    def innocuous_capabilities(self):
+        return ("include-tag", "thin-pack", "no-progress", "ofs-delta")
+
+    def required_capabilities(self):
+        """Return a list of capabilities that we require the client to have."""
+        return []
+
+    def set_client_capabilities(self, caps):
+        allowable_caps = set(self.innocuous_capabilities())
+        allowable_caps.update(self.capabilities())
+        for cap in caps:
+            if cap not in allowable_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        for cap in self.required_capabilities():
+            if cap not in caps:
+                raise GitProtocolError('Client does not support required '
+                                       'capability %s.' % cap)
+        self._client_capabilities = set(caps)
+
+    def has_capability(self, cap):
+        if self._client_capabilities is None:
+            raise GitProtocolError('Server attempted to access capability %s '
+                                   'before asking client' % cap)
+        return cap in self._client_capabilities
 
 
 
 
 class UploadPackHandler(Handler):
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
     """Protocol handler for uploading a pack to the server."""
 
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
-        self._client_capabilities = None
+        Handler.__init__(self, backend, proto)
+        self.repo = backend.open_repository(args[0])
         self._graph_walker = None
         self._graph_walker = None
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
 
 
-    def default_capabilities(self):
+    def capabilities(self):
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
         return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
-                "ofs-delta")
+                "ofs-delta", "no-progress", "include-tag")
 
 
-    def set_client_capabilities(self, caps):
-        my_caps = self.default_capabilities()
-        for cap in caps:
-            if '_ack' in cap and cap not in my_caps:
-                raise GitProtocolError('Client asked for capability %s that '
-                                       'was not advertised.' % cap)
-        self._client_capabilities = caps
+    def required_capabilities(self):
+        return ("side-band-64k", "thin-pack", "ofs-delta")
 
 
-    def get_client_capabilities(self):
-        return self._client_capabilities
+    def progress(self, message):
+        if self.has_capability("no-progress"):
+            return
+        self.proto.write_sideband(2, message)
 
 
-    client_capabilities = property(get_client_capabilities,
-                                   set_client_capabilities)
+    def get_tagged(self, refs=None, repo=None):
+        """Get a dict of peeled values of tags to their original tag shas.
 
 
-    def handle(self):
+        :param refs: dict of refname -> sha of possible tags; defaults to all of
+            the backend's refs.
+        :param repo: optional Repo instance for getting peeled refs; defaults to
+            the backend's repo, if available
+        :return: dict of peeled_sha -> tag_sha, where tag_sha is the sha of a
+            tag whose peeled value is peeled_sha.
+        """
+        if not self.has_capability("include-tag"):
+            return {}
+        if refs is None:
+            refs = self.repo.get_refs()
+        if repo is None:
+            repo = getattr(self.repo, "repo", None)
+            if repo is None:
+                # Bail if we don't have a Repo available; this is ok since
+                # clients must be able to cope when the server does not include
+                # all relevant tags.
+                # TODO: fix behavior when missing
+                return {}
+        tagged = {}
+        for name, sha in refs.iteritems():
+            peeled_sha = repo.get_peeled(name)
+            if peeled_sha != sha:
+                tagged[peeled_sha] = sha
+        return tagged
 
 
-        progress = lambda x: self.proto.write_sideband(2, x)
+    def handle(self):
         write = lambda x: self.proto.write_sideband(1, x)
         write = lambda x: self.proto.write_sideband(1, x)
 
 
-        graph_walker = ProtocolGraphWalker(self)
-        objects_iter = self.backend.fetch_objects(
-          graph_walker.determine_wants, graph_walker, progress)
+        graph_walker = ProtocolGraphWalker(self, self.repo.object_store,
+            self.repo.get_peeled)
+        objects_iter = self.repo.fetch_objects(
+          graph_walker.determine_wants, graph_walker, self.progress,
+          get_tagged=self.get_tagged)
 
 
         # Do they want any objects?
         # Do they want any objects?
         if len(objects_iter) == 0:
         if len(objects_iter) == 0:
             return
             return
 
 
-        progress("dul-daemon says what\n")
-        progress("counting objects: %d, done.\n" % len(objects_iter))
+        self.progress("dul-daemon says what\n")
+        self.progress("counting objects: %d, done.\n" % len(objects_iter))
         write_pack_data(ProtocolFile(None, write), objects_iter, 
         write_pack_data(ProtocolFile(None, write), objects_iter, 
                         len(objects_iter))
                         len(objects_iter))
-        progress("how was that, then?\n")
+        self.progress("how was that, then?\n")
         # we are done
         # we are done
         self.proto.write("0000")
         self.proto.write("0000")
 
 
@@ -211,9 +263,9 @@ class UploadPackHandler(Handler):
 class ProtocolGraphWalker(object):
 class ProtocolGraphWalker(object):
     """A graph walker that knows the git protocol.
     """A graph walker that knows the git protocol.
 
 
-    As a graph walker, this class implements ack(), next(), and reset(). It also
-    contains some base methods for interacting with the wire and walking the
-    commit tree.
+    As a graph walker, this class implements ack(), next(), and reset(). It
+    also contains some base methods for interacting with the wire and walking
+    the commit tree.
 
 
     The work of determining which acks to send is passed on to the
     The work of determining which acks to send is passed on to the
     implementation instance stored in _impl. The reason for this is that we do
     implementation instance stored in _impl. The reason for this is that we do
@@ -221,9 +273,10 @@ class ProtocolGraphWalker(object):
     call to set_ack_level() is required to set up the implementation, before any
     call to set_ack_level() is required to set up the implementation, before any
     calls to next() or ack() are made.
     calls to next() or ack() are made.
     """
     """
-    def __init__(self, handler):
+    def __init__(self, handler, object_store, get_peeled):
         self.handler = handler
         self.handler = handler
-        self.store = handler.backend.object_store
+        self.store = object_store
+        self.get_peeled = get_peeled
         self.proto = handler.proto
         self.proto = handler.proto
         self.stateless_rpc = handler.stateless_rpc
         self.stateless_rpc = handler.stateless_rpc
         self.advertise_refs = handler.advertise_refs
         self.advertise_refs = handler.advertise_refs
@@ -251,9 +304,12 @@ class ProtocolGraphWalker(object):
             for i, (ref, sha) in enumerate(heads.iteritems()):
             for i, (ref, sha) in enumerate(heads.iteritems()):
                 line = "%s %s" % (sha, ref)
                 line = "%s %s" % (sha, ref)
                 if not i:
                 if not i:
-                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                    line = "%s\x00%s" % (line, self.handler.capability_line())
                 self.proto.write_pkt_line("%s\n" % line)
                 self.proto.write_pkt_line("%s\n" % line)
-                # TODO: include peeled value of any tags
+                peeled_sha = self.get_peeled(ref)
+                if peeled_sha != sha:
+                    self.proto.write_pkt_line('%s %s^{}\n' %
+                                              (peeled_sha, ref))
 
 
             # i'm done..
             # i'm done..
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
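
With the peeled values included, the advertisement a client receives now looks
roughly like this (pkt-line length prefixes omitted, shas and the capability
list abbreviated; the "^{}" line carries the peeled value of the preceding tag
ref):

    <commit-sha> refs/heads/master\x00multi_ack_detailed side-band-64k thin-pack ofs-delta ...
    <tag-sha> refs/tags/v0.1
    <commit-sha> refs/tags/v0.1^{}
    0000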
@@ -266,7 +322,7 @@ class ProtocolGraphWalker(object):
         if not want:
         if not want:
             return []
             return []
         line, caps = extract_want_line_capabilities(want)
         line, caps = extract_want_line_capabilities(want)
-        self.handler.client_capabilities = caps
+        self.handler.set_client_capabilities(caps)
         self.set_ack_type(ack_type(caps))
         self.set_ack_type(ack_type(caps))
         command, sha = self._split_proto_line(line)
         command, sha = self._split_proto_line(line)
 
 
@@ -274,10 +330,10 @@ class ProtocolGraphWalker(object):
         while command != None:
         while command != None:
             if command != 'want':
             if command != 'want':
                 raise GitProtocolError(
                 raise GitProtocolError(
-                    'Protocol got unexpected command %s' % command)
+                  'Protocol got unexpected command %s' % command)
             if sha not in values:
             if sha not in values:
                 raise GitProtocolError(
                 raise GitProtocolError(
-                    'Client wants invalid object %s' % sha)
+                  'Client wants invalid object %s' % sha)
             want_revs.append(sha)
             want_revs.append(sha)
             command, sha = self.read_proto_line()
             command, sha = self.read_proto_line()
 
 
@@ -359,10 +415,10 @@ class ProtocolGraphWalker(object):
             commit = pending.popleft()
             commit = pending.popleft()
             if commit.id in haves:
             if commit.id in haves:
                 return True
                 return True
-            if not getattr(commit, 'get_parents', None):
+            if commit.type_name != "commit":
                 # non-commit wants are assumed to be satisfied
                 # non-commit wants are assumed to be satisfied
                 continue
                 continue
-            for parent in commit.get_parents():
+            for parent in commit.parents:
                 parent_obj = self.store[parent]
                 parent_obj = self.store[parent]
                 # TODO: handle parents with later commit times than children
                 # TODO: handle parents with later commit times than children
                 if parent_obj.commit_time >= earliest:
                 if parent_obj.commit_time >= earliest:
@@ -385,10 +441,10 @@ class ProtocolGraphWalker(object):
 
 
     def set_ack_type(self, ack_type):
     def set_ack_type(self, ack_type):
         impl_classes = {
         impl_classes = {
-            MULTI_ACK: MultiAckGraphWalkerImpl,
-            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
-            SINGLE_ACK: SingleAckGraphWalkerImpl,
-            }
+          MULTI_ACK: MultiAckGraphWalkerImpl,
+          MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
+          SINGLE_ACK: SingleAckGraphWalkerImpl,
+          }
         self._impl = impl_classes[ack_type](self)
         self._impl = impl_classes[ack_type](self)
 
 
 
 
@@ -497,32 +553,72 @@ class MultiAckDetailedGraphWalkerImpl(object):
 class ReceivePackHandler(Handler):
 class ReceivePackHandler(Handler):
     """Protocol handler for downloading a pack from the client."""
     """Protocol handler for downloading a pack from the client."""
 
 
-    def __init__(self, backend, read, write,
+    def __init__(self, backend, args, proto,
                  stateless_rpc=False, advertise_refs=False):
                  stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
+        Handler.__init__(self, backend, proto)
+        self.repo = backend.open_repository(args[0])
         self.stateless_rpc = stateless_rpc
         self.stateless_rpc = stateless_rpc
         self.advertise_refs = advertise_refs
         self.advertise_refs = advertise_refs
 
 
-    def __init__(self, backend, read, write,
-                 stateless_rpc=False, advertise_refs=False):
-        Handler.__init__(self, backend, read, write)
-        self._stateless_rpc = stateless_rpc
-        self._advertise_refs = advertise_refs
-
-    def default_capabilities(self):
+    def capabilities(self):
         return ("report-status", "delete-refs")
         return ("report-status", "delete-refs")
 
 
+    def _apply_pack(self, refs):
+        f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError,
+                          AssertionError, socket.error, zlib.error,
+                          ObjectFormatException)
+        status = []
+        # TODO: more informative error messages than just the exception string
+        try:
+            PackStreamCopier(self.proto.read, self.proto.recv, f).verify()
+            p = commit()
+            if not p:
+                raise IOError('Failed to write pack')
+            p.check()
+            status.append(('unpack', 'ok'))
+        except all_exceptions, e:
+            status.append(('unpack', str(e).replace('\n', '')))
+            # The pack may still have been moved in, but it may contain broken
+            # objects. We trust a later GC to clean it up.
+
+        for oldsha, sha, ref in refs:
+            ref_status = 'ok'
+            try:
+                if sha == ZERO_SHA:
+                    if not 'delete-refs' in self.capabilities():
+                        raise GitProtocolError(
+                          'Attempted to delete refs without delete-refs '
+                          'capability.')
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_status = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_status = 'failed to write'
+            except KeyError, e:
+                ref_status = 'bad ref'
+            status.append((ref, ref_status))
+
+        return status
+
     def handle(self):
     def handle(self):
-        refs = self.backend.get_refs().items()
+        refs = self.repo.get_refs().items()
 
 
         if self.advertise_refs or not self.stateless_rpc:
         if self.advertise_refs or not self.stateless_rpc:
             if refs:
             if refs:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                self.proto.write_pkt_line(
+                  "%s %s\x00%s\n" % (refs[0][1], refs[0][0],
+                                     self.capability_line()))
                 for i in range(1, len(refs)):
                 for i in range(1, len(refs)):
                     ref = refs[i]
                     ref = refs[i]
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
                     self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
             else:
             else:
-                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+                self.proto.write_pkt_line("%s capabilities^{} %s" % (
+                  ZERO_SHA, self.capability_line()))
 
 
             self.proto.write("0000")
             self.proto.write("0000")
             if self.advertise_refs:
             if self.advertise_refs:
@@ -535,7 +631,8 @@ class ReceivePackHandler(Handler):
         if ref is None:
         if ref is None:
             return
             return
 
 
-        ref, client_capabilities = extract_capabilities(ref)
+        ref, caps = extract_capabilities(ref)
+        self.set_client_capabilities(caps)
 
 
         # client will now send us a list of (oldsha, newsha, ref)
         # client will now send us a list of (oldsha, newsha, ref)
         while ref:
         while ref:
@@ -543,11 +640,11 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
             ref = self.proto.read_pkt_line()
 
 
         # backend can now deal with this refs and read a pack using self.read
         # backend can now deal with this refs and read a pack using self.read
-        status = self.backend.apply_pack(client_refs, self.proto.read)
+        status = self._apply_pack(client_refs)
 
 
         # when we have read all the pack from the client, send a status report
         # when we have read all the pack from the client, send a status report
         # if the client asked for it
         # if the client asked for it
-        if 'report-status' in client_capabilities:
+        if self.has_capability('report-status'):
             for name, msg in status:
             for name, msg in status:
                 if name == 'unpack':
                 if name == 'unpack':
                     self.proto.write_pkt_line('unpack %s\n' % msg)
                     self.proto.write_pkt_line('unpack %s\n' % msg)
@@ -558,21 +655,27 @@ class ReceivePackHandler(Handler):
             self.proto.write_pkt_line(None)
             self.proto.write_pkt_line(None)
 
 
 
 
+# Default handler classes for git services.
+DEFAULT_HANDLERS = {
+  'git-upload-pack': UploadPackHandler,
+  'git-receive-pack': ReceivePackHandler,
+  }
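
A rough sketch of wiring the pieces in this module together to serve a single
repository over git:// (the path is a placeholder; the port defaults to
TCP_GIT_PORT):

    from dulwich.repo import Repo
    from dulwich.server import DictBackend, TCPGitServer

    backend = DictBackend({"/": Repo("/path/to/repo.git")})
    server = TCPGitServer(backend, "localhost")
    server.serve()    # clients can now run: git clone git://localhost/

Passing a custom ``handlers`` dict to TCPGitServer (defined below) replaces or
restricts these defaults; commands missing from the dict are rejected with a
GitProtocolError.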
+
+
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 
 
+    def __init__(self, handlers, *args, **kwargs):
+        self.handlers = handlers and handlers or DEFAULT_HANDLERS
+        SocketServer.StreamRequestHandler.__init__(self, *args, **kwargs)
+
     def handle(self):
     def handle(self):
-        proto = Protocol(self.rfile.read, self.wfile.write)
+        proto = ReceivableProtocol(self.connection.recv, self.wfile.write)
         command, args = proto.read_cmd()
         command, args = proto.read_cmd()
 
 
-        # switch case to handle the specific git command
-        if command == 'git-upload-pack':
-            cls = UploadPackHandler
-        elif command == 'git-receive-pack':
-            cls = ReceivePackHandler
-        else:
-            return
-
-        h = cls(self.server.backend, self.rfile.read, self.wfile.write)
+        cls = self.handlers.get(command, None)
+        if not callable(cls):
+            raise GitProtocolError('Invalid service %s' % command)
+        h = cls(self.server.backend, args, proto)
         h.handle()
         h.handle()
 
 
 
 
@@ -581,6 +684,11 @@ class TCPGitServer(SocketServer.TCPServer):
     allow_reuse_address = True
     allow_reuse_address = True
     serve = SocketServer.TCPServer.serve_forever
     serve = SocketServer.TCPServer.serve_forever
 
 
-    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
+    def _make_handler(self, *args, **kwargs):
+        return TCPGitRequestHandler(self.handlers, *args, **kwargs)
+
+    def __init__(self, backend, listen_addr, port=TCP_GIT_PORT, handlers=None):
         self.backend = backend
         self.backend = backend
-        SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)
+        self.handlers = handlers
+        SocketServer.TCPServer.__init__(self, (listen_addr, port),
+                                        self._make_handler)

+ 30 - 0
dulwich/tests/__init__.py

@@ -18,3 +18,33 @@
 # MA  02110-1301, USA.
 # MA  02110-1301, USA.
 
 
 """Tests for Dulwich."""
 """Tests for Dulwich."""
+
+import unittest
+
+# XXX: Ideally we should allow other test runners as well, 
+# but unfortunately unittest doesn't have a SkipTest/TestSkipped
+# exception.
+from nose import SkipTest as TestSkipped
+
+def test_suite():
+    names = [
+        'client',
+        'fastexport',
+        'file',
+        'index',
+        'lru_cache',
+        'objects',
+        'object_store',
+        'pack',
+        'patch',
+        'protocol',
+        'repository',
+        'server',
+        'web',
+        ]
+    module_names = ['dulwich.tests.test_' + name for name in names]
+    result = unittest.TestSuite()
+    loader = unittest.TestLoader()
+    suite = loader.loadTestsFromNames(module_names)
+    result.addTests(suite)
+    return result
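
For example, the aggregated suite can be run with the stock unittest runner
(nose only needs to be importable for its SkipTest exception):

    import unittest
    from dulwich.tests import test_suite

    unittest.TextTestRunner(verbosity=2).run(test_suite())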

+ 157 - 0
dulwich/tests/compat/server_utils.py

@@ -0,0 +1,157 @@
+# server_utils.py -- Git server compatibility utilities
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Utilities for testing git server compatibility."""
+
+
+import select
+import socket
+import threading
+
+from dulwich.tests.utils import (
+    tear_down_repo,
+    )
+from utils import (
+    import_repo,
+    run_git,
+    )
+
+
+class ServerTests(object):
+    """Base tests for testing servers.
+
+    Does not inherit from TestCase so tests are not automatically run.
+    """
+
+    def setUp(self):
+        self._old_repo = import_repo('server_old.export')
+        self._new_repo = import_repo('server_new.export')
+        self._server = None
+
+    def tearDown(self):
+        if self._server is not None:
+            self._server.shutdown()
+            self._server = None
+        tear_down_repo(self._old_repo)
+        tear_down_repo(self._new_repo)
+
+    def test_push_to_dulwich(self):
+        self.assertReposNotEqual(self._old_repo, self._new_repo)
+        port = self._start_server(self._old_repo)
+
+        all_branches = ['master', 'branch']
+        branch_args = ['%s:%s' % (b, b) for b in all_branches]
+        url = '%s://localhost:%s/' % (self.protocol, port)
+        returncode, _ = run_git(['push', url] + branch_args,
+                                cwd=self._new_repo.path)
+        self.assertEqual(0, returncode)
+        self.assertReposEqual(self._old_repo, self._new_repo)
+
+    def test_fetch_from_dulwich(self):
+        self.assertReposNotEqual(self._old_repo, self._new_repo)
+        port = self._start_server(self._new_repo)
+
+        all_branches = ['master', 'branch']
+        branch_args = ['%s:%s' % (b, b) for b in all_branches]
+        url = '%s://localhost:%s/' % (self.protocol, port)
+        returncode, _ = run_git(['fetch', url] + branch_args,
+                                cwd=self._old_repo.path)
+        # flush the pack cache so any new packs are picked up
+        self._old_repo.object_store._pack_cache = None
+        self.assertEqual(0, returncode)
+        self.assertReposEqual(self._old_repo, self._new_repo)
+
+
+class ShutdownServerMixIn:
+    """Mixin that allows serve_forever to be shut down.
+
+    The methods in this mixin are backported from SocketServer.py in the Python
+    2.6.4 standard library. The mixin is unnecessary in 2.6 and later, when
+    BaseServer supports the shutdown method directly.
+    """
+
+    def __init__(self):
+        self.__is_shut_down = threading.Event()
+        self.__serving = False
+
+    def serve_forever(self, poll_interval=0.5):
+        """Handle one request at a time until shutdown.
+
+        Polls for shutdown every poll_interval seconds. Ignores
+        self.timeout. If you need to do periodic tasks, do them in
+        another thread.
+        """
+        self.__serving = True
+        self.__is_shut_down.clear()
+        while self.__serving:
+            # XXX: Consider using another file descriptor or
+            # connecting to the socket to wake this up instead of
+            # polling. Polling reduces our responsiveness to a
+            # shutdown request and wastes cpu at all other times.
+            r, w, e = select.select([self], [], [], poll_interval)
+            if r:
+                self._handle_request_noblock()
+        self.__is_shut_down.set()
+
+    serve = serve_forever  # override alias from TCPGitServer
+
+    def shutdown(self):
+        """Stops the serve_forever loop.
+
+        Blocks until the loop has finished. This must be called while
+        serve_forever() is running in another thread, or it will deadlock.
+        """
+        self.__serving = False
+        self.__is_shut_down.wait()
+
+    def handle_request(self):
+        """Handle one request, possibly blocking.
+
+        Respects self.timeout.
+        """
+        # Support people who used socket.settimeout() to escape
+        # handle_request before self.timeout was available.
+        timeout = self.socket.gettimeout()
+        if timeout is None:
+            timeout = self.timeout
+        elif self.timeout is not None:
+            timeout = min(timeout, self.timeout)
+        fd_sets = select.select([self], [], [], timeout)
+        if not fd_sets[0]:
+            self.handle_timeout()
+            return
+        self._handle_request_noblock()
+
+    def _handle_request_noblock(self):
+        """Handle one request, without blocking.
+
+        I assume that select.select has returned that the socket is
+        readable before this function was called, so there should be
+        no risk of blocking in get_request().
+        """
+        try:
+            request, client_address = self.get_request()
+        except socket.error:
+            return
+        if self.verify_request(request, client_address):
+            try:
+                self.process_request(request, client_address)
+            except:
+                self.handle_error(request, client_address)
+                self.close_request(request)
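The intended usage pattern is to run the server loop in a background thread and call shutdown() from the test thread once the work is done; a rough sketch (StoppableTCPServer is an illustrative name, and any old-style SocketServer class can be composed the same way):

    import threading
    import SocketServer

    from dulwich.tests.compat.server_utils import ShutdownServerMixIn

    class StoppableTCPServer(ShutdownServerMixIn, SocketServer.TCPServer):
        # TCPServer is an old-style class, so both __init__s are called
        # explicitly rather than via super().
        def __init__(self, *args, **kwargs):
            ShutdownServerMixIn.__init__(self)
            SocketServer.TCPServer.__init__(self, *args, **kwargs)

    server = StoppableTCPServer(('localhost', 0), SocketServer.BaseRequestHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.start()
    # ... exercise the server here ...
    server.shutdown()  # blocks until serve_forever() has exited
    thread.join()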

+ 179 - 0
dulwich/tests/compat/test_client.py

@@ -0,0 +1,179 @@
+# test_client.py -- Compatibility tests for git client.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibilty tests between the Dulwich client and the cgit server."""
+
+import os
+import shutil
+import signal
+import tempfile
+
+from dulwich import client
+from dulwich import errors
+from dulwich import file
+from dulwich import index
+from dulwich import protocol
+from dulwich import object_store
+from dulwich import objects
+from dulwich import repo
+from dulwich.tests import (
+    TestSkipped,
+    )
+
+from utils import (
+    CompatTestCase,
+    check_for_daemon,
+    import_repo_to_dir,
+    run_git,
+    )
+
+class DulwichClientTest(CompatTestCase):
+    """Tests for client/server compatibility."""
+
+    def setUp(self):
+        if check_for_daemon(limit=1):
+            raise TestSkipped('git-daemon was already running on port %s' %
+                              protocol.TCP_GIT_PORT)
+        CompatTestCase.setUp(self)
+        fd, self.pidfile = tempfile.mkstemp(prefix='dulwich-test-git-client',
+                                            suffix=".pid")
+        os.fdopen(fd).close()
+        self.gitroot = os.path.dirname(import_repo_to_dir('server_new.export'))
+        dest = os.path.join(self.gitroot, 'dest')
+        file.ensure_dir_exists(dest)
+        run_git(['init', '--bare'], cwd=dest)
+        run_git(
+            ['daemon', '--verbose', '--export-all',
+             '--pid-file=%s' % self.pidfile, '--base-path=%s' % self.gitroot,
+             '--detach', '--reuseaddr', '--enable=receive-pack',
+             '--listen=localhost', self.gitroot], cwd=self.gitroot)
+        if not check_for_daemon():
+            raise TestSkipped('git-daemon failed to start')
+
+    def tearDown(self):
+        CompatTestCase.tearDown(self)
+        try:
+            os.kill(int(open(self.pidfile).read().strip()), signal.SIGKILL)
+            os.unlink(self.pidfile)
+        except (OSError, IOError):
+            pass
+        shutil.rmtree(self.gitroot)
+
+    def assertDestEqualsSrc(self):
+        src = repo.Repo(os.path.join(self.gitroot, 'server_new.export'))
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        self.assertReposEqual(src, dest)
+
+    def _do_send_pack(self):
+        c = client.TCPGitClient('localhost')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+
+    def test_send_pack(self):
+        self._do_send_pack()
+        self.assertDestEqualsSrc()
+
+    def test_send_pack_nothing_to_send(self):
+        self._do_send_pack()
+        self.assertDestEqualsSrc()
+        # nothing to send, but shouldn't raise either.
+        self._do_send_pack()
+
+    def test_send_without_report_status(self):
+        c = client.TCPGitClient('localhost')
+        c._send_capabilities.remove('report-status')
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        c.send_pack('/dest', lambda _: sendrefs,
+                    src.object_store.generate_pack_contents)
+        self.assertDestEqualsSrc()
+
+    def disable_ff_and_make_dummy_commit(self):
+        # disable non-fast-forward pushes to the server
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        run_git(['config', 'receive.denyNonFastForwards', 'true'], cwd=dest.path)
+        b = objects.Blob.from_string('hi')
+        dest.object_store.add_object(b)
+        t = index.commit_tree(dest.object_store, [('hi', b.id, 0100644)])
+        c = objects.Commit()
+        c.author = c.committer = 'Foo Bar <foo@example.com>'
+        c.author_time = c.commit_time = 0
+        c.author_timezone = c.commit_timezone = 0
+        c.message = 'hi'
+        c.tree = t
+        dest.object_store.add_object(c)
+        return dest, c.id
+
+    def compute_send(self):
+        srcpath = os.path.join(self.gitroot, 'server_new.export')
+        src = repo.Repo(srcpath)
+        sendrefs = dict(src.get_refs())
+        del sendrefs['HEAD']
+        return sendrefs, src.object_store.generate_pack_contents
+
+    def test_send_pack_one_error(self):
+        dest, dummy_commit = self.disable_ff_and_make_dummy_commit()
+        dest.refs['refs/heads/master'] = dummy_commit
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/master failed to update', str(e))
+            self.assertEqual({'refs/heads/branch': 'ok',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)
+
+    def test_send_pack_multiple_errors(self):
+        dest, dummy = self.disable_ff_and_make_dummy_commit()
+        # set up for two non-ff errors
+        dest.refs['refs/heads/branch'] = dest.refs['refs/heads/master'] = dummy
+        sendrefs, gen_pack = self.compute_send()
+        c = client.TCPGitClient('localhost')
+        try:
+            c.send_pack('/dest', lambda _: sendrefs, gen_pack)
+        except errors.UpdateRefsError, e:
+            self.assertEqual('refs/heads/branch, refs/heads/master failed to '
+                             'update', str(e))
+            self.assertEqual({'refs/heads/branch': 'non-fast-forward',
+                              'refs/heads/master': 'non-fast-forward'},
+                             e.ref_status)
+
+    def test_fetch_pack(self):
+        c = client.TCPGitClient('localhost')
+        dest = repo.Repo(os.path.join(self.gitroot, 'dest'))
+        refs = c.fetch('/server_new.export', dest)
+        map(lambda r: dest.refs.set_if_equals(r[0], None, r[1]), refs.items())
+        self.assertDestEqualsSrc()
+
+    def test_incremental_fetch_pack(self):
+        self.test_fetch_pack()
+        dest, dummy = self.disable_ff_and_make_dummy_commit()
+        dest.refs['refs/heads/master'] = dummy
+        c = client.TCPGitClient('localhost')
+        dest = repo.Repo(os.path.join(self.gitroot, 'server_new.export'))
+        refs = c.fetch('/dest', dest)
+        map(lambda r: dest.refs.set_if_equals(r[0], None, r[1]), refs.items())
+        self.assertDestEqualsSrc()

+ 73 - 0
dulwich/tests/compat/test_pack.py

@@ -0,0 +1,73 @@
+# test_pack.py -- Compatibility tests for git packs.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibilty tests for git packs."""
+
+
+import binascii
+import os
+import shutil
+import tempfile
+
+from dulwich.pack import (
+    write_pack,
+    )
+from dulwich.tests.test_pack import (
+    pack1_sha,
+    PackTests,
+    )
+from utils import (
+    require_git_version,
+    run_git,
+    )
+
+
+class TestPack(PackTests):
+    """Compatibility tests for reading and writing pack files."""
+
+    def setUp(self):
+        require_git_version((1, 5, 0))
+        PackTests.setUp(self)
+        self._tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self._tempdir)
+        PackTests.tearDown(self)
+
+    def test_copy(self):
+        origpack = self.get_pack(pack1_sha)
+        self.assertSucceeds(origpack.index.check)
+        pack_path = os.path.join(self._tempdir, "Elch")
+        write_pack(pack_path, [(x, "") for x in origpack.iterobjects()],
+                   len(origpack))
+
+        returncode, output = run_git(['verify-pack', '-v', pack_path],
+                                     capture_stdout=True)
+        self.assertEquals(0, returncode)
+
+        pack_shas = set()
+        for line in output.splitlines():
+            sha = line[:40]
+            try:
+                binascii.unhexlify(sha)
+            except TypeError:
+                continue  # non-sha line
+            pack_shas.add(sha)
+        orig_shas = set(o.id for o in origpack.iterobjects())
+        self.assertEquals(orig_shas, pack_shas)

+ 131 - 0
dulwich/tests/compat/test_repository.py

@@ -0,0 +1,131 @@
+# test_repository.py -- Git repo compatibility tests
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibility tests for dulwich repositories."""
+
+
+from cStringIO import StringIO
+import itertools
+import os
+
+from dulwich.objects import (
+    hex_to_sha,
+    )
+from dulwich.repo import (
+    check_ref_format,
+    )
+from dulwich.tests.utils import (
+    tear_down_repo,
+    )
+
+from utils import (
+    run_git,
+    import_repo,
+    CompatTestCase,
+    )
+
+
+class ObjectStoreTestCase(CompatTestCase):
+    """Tests for git repository compatibility."""
+
+    def setUp(self):
+        CompatTestCase.setUp(self)
+        self._repo = import_repo('server_new.export')
+
+    def tearDown(self):
+        CompatTestCase.tearDown(self)
+        tear_down_repo(self._repo)
+
+    def _run_git(self, args):
+        returncode, output = run_git(args, capture_stdout=True,
+                                     cwd=self._repo.path)
+        self.assertEqual(0, returncode)
+        return output
+
+    def _parse_refs(self, output):
+        refs = {}
+        for line in StringIO(output):
+            fields = line.rstrip('\n').split(' ')
+            self.assertEqual(3, len(fields))
+            refname, type_name, sha = fields
+            check_ref_format(refname[5:])
+            hex_to_sha(sha)
+            refs[refname] = (type_name, sha)
+        return refs
+
+    def _parse_objects(self, output):
+        return set(s.rstrip('\n').split(' ')[0] for s in StringIO(output))
+
+    def test_bare(self):
+        self.assertTrue(self._repo.bare)
+        self.assertFalse(os.path.exists(os.path.join(self._repo.path, '.git')))
+
+    def test_head(self):
+        output = self._run_git(['rev-parse', 'HEAD'])
+        head_sha = output.rstrip('\n')
+        hex_to_sha(head_sha)
+        self.assertEqual(head_sha, self._repo.refs['HEAD'])
+
+    def test_refs(self):
+        output = self._run_git(
+          ['for-each-ref', '--format=%(refname) %(objecttype) %(objectname)'])
+        expected_refs = self._parse_refs(output)
+
+        actual_refs = {}
+        for refname, sha in self._repo.refs.as_dict().iteritems():
+            if refname == 'HEAD':
+                continue  # handled in test_head
+            obj = self._repo[sha]
+            self.assertEqual(sha, obj.id)
+            actual_refs[refname] = (obj.type_name, obj.id)
+        self.assertEqual(expected_refs, actual_refs)
+
+    # TODO(dborowitz): peeled ref tests
+
+    def _get_loose_shas(self):
+        output = self._run_git(['rev-list', '--all', '--objects', '--unpacked'])
+        return self._parse_objects(output)
+
+    def _get_all_shas(self):
+        output = self._run_git(['rev-list', '--all', '--objects'])
+        return self._parse_objects(output)
+
+    def assertShasMatch(self, expected_shas, actual_shas_iter):
+        actual_shas = set()
+        for sha in actual_shas_iter:
+            obj = self._repo[sha]
+            self.assertEqual(sha, obj.id)
+            actual_shas.add(sha)
+        self.assertEqual(expected_shas, actual_shas)
+
+    def test_loose_objects(self):
+        # TODO(dborowitz): This is currently not very useful since fast-imported
+        # repos only contained packed objects.
+        expected_shas = self._get_loose_shas()
+        self.assertShasMatch(expected_shas,
+                             self._repo.object_store._iter_loose_objects())
+
+    def test_packed_objects(self):
+        expected_shas = self._get_all_shas() - self._get_loose_shas()
+        self.assertShasMatch(expected_shas,
+                             itertools.chain(*self._repo.object_store.packs))
+
+    def test_all_objects(self):
+        expected_shas = self._get_all_shas()
+        self.assertShasMatch(expected_shas, iter(self._repo.object_store))

+ 78 - 0
dulwich/tests/compat/test_server.py

@@ -0,0 +1,78 @@
+# test_server.py -- Compatibility tests for git server.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibilty tests between Dulwich and the cgit server.
+
+Warning: these tests should be fairly stable, but when writing/debugging new
+tests, deadlocks may freeze the test process such that it cannot be Ctrl-C'ed.
+On *nix, you can kill the tests with Ctrl-Z, "kill %".
+"""
+
+import threading
+
+from dulwich.server import (
+    DictBackend,
+    TCPGitServer,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
+from server_utils import (
+    ServerTests,
+    ShutdownServerMixIn,
+    )
+from utils import (
+    CompatTestCase,
+    )
+
+
+if not getattr(TCPGitServer, 'shutdown', None):
+    _TCPGitServer = TCPGitServer
+
+    class TCPGitServer(ShutdownServerMixIn, TCPGitServer):
+        """Subclass of TCPGitServer that can be shut down."""
+
+        def __init__(self, *args, **kwargs):
+            # BaseServer is old-style so we have to call both __init__s
+            ShutdownServerMixIn.__init__(self)
+            _TCPGitServer.__init__(self, *args, **kwargs)
+
+        serve = ShutdownServerMixIn.serve_forever
+
+
+class GitServerTestCase(ServerTests, CompatTestCase):
+    """Tests for client/server compatibility."""
+
+    protocol = 'git'
+
+    def setUp(self):
+        ServerTests.setUp(self)
+        CompatTestCase.setUp(self)
+
+    def tearDown(self):
+        ServerTests.tearDown(self)
+        CompatTestCase.tearDown(self)
+
+    def _start_server(self, repo):
+        backend = DictBackend({'/': repo})
+        dul_server = TCPGitServer(backend, 'localhost', 0)
+        threading.Thread(target=dul_server.serve).start()
+        self._server = dul_server
+        _, port = self._server.socket.getsockname()
+        return port

+ 116 - 0
dulwich/tests/compat/test_web.py

@@ -0,0 +1,116 @@
+# test_web.py -- Compatibility tests for the git web server.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Compatibilty tests between Dulwich and the cgit HTTP server.
+
+Warning: these tests should be fairly stable, but when writing/debugging new
+tests, deadlocks may freeze the test process such that it cannot be Ctrl-C'ed.
+On *nix, you can kill the tests with Ctrl-Z, "kill %".
+"""
+
+import threading
+from wsgiref import simple_server
+
+from dulwich.server import (
+    DictBackend,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
+from dulwich.web import (
+    HTTPGitApplication,
+    )
+
+from server_utils import (
+    ServerTests,
+    ShutdownServerMixIn,
+    )
+from utils import (
+    CompatTestCase,
+    )
+
+
+if getattr(simple_server.WSGIServer, 'shutdown', None):
+    WSGIServer = simple_server.WSGIServer
+else:
+    class WSGIServer(ShutdownServerMixIn, simple_server.WSGIServer):
+        """Subclass of WSGIServer that can be shut down."""
+
+        def __init__(self, *args, **kwargs):
+            # BaseServer is old-style so we have to call both __init__s
+            ShutdownServerMixIn.__init__(self)
+            simple_server.WSGIServer.__init__(self, *args, **kwargs)
+
+        serve = ShutdownServerMixIn.serve_forever
+
+
+class WebTests(ServerTests):
+    """Base tests for web server tests.
+
+    Contains utility and setUp/tearDown methods, but does not inherit from
+    TestCase so tests are not automatically run.
+    """
+
+    protocol = 'http'
+
+    def _start_server(self, repo):
+        backend = DictBackend({'/': repo})
+        app = self._make_app(backend)
+        dul_server = simple_server.make_server('localhost', 0, app,
+                                               server_class=WSGIServer)
+        threading.Thread(target=dul_server.serve_forever).start()
+        self._server = dul_server
+        _, port = dul_server.socket.getsockname()
+        return port
+
+
+class SmartWebTestCase(WebTests, CompatTestCase):
+    """Test cases for smart HTTP server."""
+
+    min_git_version = (1, 6, 6)
+
+    def setUp(self):
+        WebTests.setUp(self)
+        CompatTestCase.setUp(self)
+
+    def tearDown(self):
+        WebTests.tearDown(self)
+        CompatTestCase.tearDown(self)
+
+    def _make_app(self, backend):
+        return HTTPGitApplication(backend)
+
+
+class DumbWebTestCase(WebTests, CompatTestCase):
+    """Test cases for dumb HTTP server."""
+
+    def setUp(self):
+        WebTests.setUp(self)
+        CompatTestCase.setUp(self)
+
+    def tearDown(self):
+        WebTests.tearDown(self)
+        CompatTestCase.tearDown(self)
+
+    def _make_app(self, backend):
+        return HTTPGitApplication(backend, dumb=True)
+
+    def test_push_to_dulwich(self):
+        # Note: remove this if dumb pushing is supported
+        raise TestSkipped('Dumb web pushing not supported.')

+ 196 - 0
dulwich/tests/compat/utils.py

@@ -0,0 +1,196 @@
+# utils.py -- Git compatibility utilities
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Utilities for interacting with cgit."""
+
+import errno
+import os
+import socket
+import subprocess
+import tempfile
+import time
+import unittest
+
+from dulwich.repo import Repo
+from dulwich.protocol import TCP_GIT_PORT
+
+from dulwich.tests import (
+    TestSkipped,
+    )
+
+_DEFAULT_GIT = 'git'
+
+
+def git_version(git_path=_DEFAULT_GIT):
+    """Attempt to determine the version of git currently installed.
+
+    :param git_path: Path to the git executable; defaults to the version in
+        the system path.
+    :return: A tuple of ints of the form (major, minor, point), or None if no
+        git installation was found.
+    """
+    try:
+        _, output = run_git(['--version'], git_path=git_path,
+                            capture_stdout=True)
+    except OSError:
+        return None
+    version_prefix = 'git version '
+    if not output.startswith(version_prefix):
+        return None
+    output = output[len(version_prefix):]
+    nums = output.split('.')
+    if len(nums) == 2:
+        nums.append('0')
+    else:
+        nums = nums[:3]
+    try:
+        return tuple(int(x) for x in nums)
+    except ValueError:
+        return None
+
+
+def require_git_version(required_version, git_path=_DEFAULT_GIT):
+    """Require git version >= version, or skip the calling test."""
+    found_version = git_version(git_path=git_path)
+    if found_version < required_version:
+        required_version = '.'.join(map(str, required_version))
+        found_version = '.'.join(map(str, found_version))
+        raise TestSkipped('Test requires git >= %s, found %s' %
+                         (required_version, found_version))
+
+
+def run_git(args, git_path=_DEFAULT_GIT, input=None, capture_stdout=False,
+            **popen_kwargs):
+    """Run a git command.
+
+    Input is piped from the input parameter and output is sent to the standard
+    streams, unless capture_stdout is set.
+
+    :param args: A list of args to the git command.
+    :param git_path: Path to the git executable.
+    :param input: Input data to be sent to stdin.
+    :param capture_stdout: Whether to capture and return stdout.
+    :param popen_kwargs: Additional kwargs for subprocess.Popen;
+        stdin/stdout args are ignored.
+    :return: A tuple of (returncode, stdout contents). If capture_stdout is
+        False, None will be returned as stdout contents.
+    :raise OSError: if the git executable was not found.
+    """
+    args = [git_path] + args
+    popen_kwargs['stdin'] = subprocess.PIPE
+    if capture_stdout:
+        popen_kwargs['stdout'] = subprocess.PIPE
+    else:
+        popen_kwargs.pop('stdout', None)
+    p = subprocess.Popen(args, **popen_kwargs)
+    stdout, stderr = p.communicate(input=input)
+    return (p.returncode, stdout)
+
+
+def run_git_or_fail(args, git_path=_DEFAULT_GIT, input=None, **popen_kwargs):
+    """Run a git command, capture stdout/stderr, and fail if git fails."""
+    popen_kwargs['stderr'] = subprocess.STDOUT
+    returncode, stdout = run_git(args, git_path=git_path, input=input,
+                                 capture_stdout=True, **popen_kwargs)
+    assert returncode == 0
+    return stdout
+
+
+def import_repo_to_dir(name):
+    """Import a repo from a fast-export file in a temporary directory.
+
+    These are used rather than binary repos for compat tests because they are
+    more compact and human-editable, and we already depend on git.
+
+    :param name: The name of the repository export file, relative to
+        dulwich/tests/data/repos.
+    :returns: The path to the imported repository.
+    """
+    temp_dir = tempfile.mkdtemp()
+    export_path = os.path.join(os.path.dirname(__file__), os.pardir, 'data',
+                               'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    export_file = open(export_path, 'rb')
+    run_git_or_fail(['init', '--bare', temp_repo_dir])
+    run_git_or_fail(['fast-import'], input=export_file.read(),
+                    cwd=temp_repo_dir)
+    export_file.close()
+    return temp_repo_dir
+
+def import_repo(name):
+    """Import a repo from a fast-export file in a temporary directory.
+
+    :param name: The name of the repository export file, relative to
+        dulwich/tests/data/repos.
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    return Repo(import_repo_to_dir(name))
+
+
+def check_for_daemon(limit=10, delay=0.1, timeout=0.1, port=TCP_GIT_PORT):
+    """Check for a running TCP daemon.
+
+    Defaults to checking 10 times with a delay of 0.1 sec between tries.
+
+    :param limit: Number of attempts before deciding no daemon is running.
+    :param delay: Delay between connection attempts.
+    :param timeout: Socket timeout for connection attempts.
+    :param port: Port on which we expect the daemon to appear.
+    :returns: A boolean, true if a daemon is running on the specified port,
+        false if not.
+    """
+    for _ in xrange(limit):
+        time.sleep(delay)
+        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+        s.settimeout(delay)
+        try:
+            s.connect(('localhost', port))
+            s.close()
+            return True
+        except socket.error, e:
+            if getattr(e, 'errno', False) and e.errno != errno.ECONNREFUSED:
+                raise
+            elif e.args[0] != errno.ECONNREFUSED:
+                raise
+    return False
+
+
+class CompatTestCase(unittest.TestCase):
+    """Test case that requires git for compatibility checks.
+
+    Subclasses can change the git version required by overriding
+    min_git_version.
+    """
+
+    min_git_version = (1, 5, 0)
+
+    def setUp(self):
+        require_git_version(self.min_git_version)
+
+    def assertReposEqual(self, repo1, repo2):
+        self.assertEqual(repo1.get_refs(), repo2.get_refs())
+        self.assertEqual(sorted(set(repo1.object_store)),
+                         sorted(set(repo2.object_store)))
+
+    def assertReposNotEqual(self, repo1, repo2):
+        refs1 = repo1.get_refs()
+        objs1 = set(repo1.object_store)
+        refs2 = repo2.get_refs()
+        objs2 = set(repo2.object_store)
+        self.assertFalse(refs1 == refs2 and objs1 == objs2)

BIN
dulwich/tests/data/blobs/11/11111111111111111111111111111111111111


+ 0 - 0
dulwich/tests/data/blobs/6f670c0fb53f9463760b7295fbb814e965fb20c8 → dulwich/tests/data/blobs/6f/670c0fb53f9463760b7295fbb814e965fb20c8


+ 0 - 0
dulwich/tests/data/blobs/954a536f7819d40e6f637f849ee187dd10066349 → dulwich/tests/data/blobs/95/4a536f7819d40e6f637f849ee187dd10066349


+ 0 - 0
dulwich/tests/data/blobs/e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 → dulwich/tests/data/blobs/e6/9de29bb2d1d6434b8b29ae775ad8c2e48c5391


+ 0 - 0
dulwich/tests/data/commits/0d89f20333fbb1d2f3a94da77f4981373d8f4310 → dulwich/tests/data/commits/0d/89f20333fbb1d2f3a94da77f4981373d8f4310


+ 0 - 0
dulwich/tests/data/commits/5dac377bdded4c9aeb8dff595f0faeebcc8498cc → dulwich/tests/data/commits/5d/ac377bdded4c9aeb8dff595f0faeebcc8498cc


+ 0 - 0
dulwich/tests/data/commits/60dacdc733de308bb77bb76ce0fb0f9b44c9769e → dulwich/tests/data/commits/60/dacdc733de308bb77bb76ce0fb0f9b44c9769e


+ 2 - 0
dulwich/tests/data/repos/a.git/objects/28/237f4dc30d0d462658d6b937b08a0f0b6ef55a

@@ -0,0 +1,2 @@
+[binary git loose object contents not reproduced here]

+ 3 - 0
dulwich/tests/data/repos/a.git/objects/b0/931cadc54336e78a1d980420e3268903b57a50

@@ -0,0 +1,3 @@
+[binary git loose object contents not reproduced here]

+ 3 - 0
dulwich/tests/data/repos/a.git/packed-refs

@@ -0,0 +1,3 @@
+# pack-refs with: peeled 
+b0931cadc54336e78a1d980420e3268903b57a50 refs/tags/mytag-packed
+^2a72d929692c41d8554c07f6301757ba18a65d91

+ 1 - 0
dulwich/tests/data/repos/a.git/refs/tags/mytag

@@ -0,0 +1 @@
+28237f4dc30d0d462658d6b937b08a0f0b6ef55a

+ 3 - 0
dulwich/tests/data/repos/refs.git/objects/3e/c9c43c84ff242e3ef4a9fc5bc111fd780a76a8

@@ -0,0 +1,3 @@
+[binary git loose object contents not reproduced here]

+ 5 - 0
dulwich/tests/data/repos/refs.git/objects/cd/a609072918d7b70057b6bef9f4c2537843fcfe

@@ -0,0 +1,5 @@
+[binary git loose object contents not reproduced here]

+ 1 - 0
dulwich/tests/data/repos/refs.git/packed-refs

@@ -1,3 +1,4 @@
 # pack-refs with: peeled 
 df6800012397fb85c56e7418dd4eb9405dee075c refs/tags/refs-0.1
 ^42d06bd4b77fed026b154d16493e5deab78f02ec
+42d06bd4b77fed026b154d16493e5deab78f02ec refs/heads/packed

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/tags/refs-0.2

@@ -0,0 +1 @@
+3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8

+ 99 - 0
dulwich/tests/data/repos/server_new.export

@@ -0,0 +1,99 @@
+blob
+mark :1
+data 13
+foo contents
+
+reset refs/heads/master
+commit refs/heads/master
+mark :2
+author Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+data 16
+initial checkin
+M 100644 :1 foo
+
+blob
+mark :3
+data 13
+baz contents
+
+blob
+mark :4
+data 21
+updated foo contents
+
+commit refs/heads/master
+mark :5
+author Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+data 15
+master checkin
+from :2
+M 100644 :3 baz
+M 100644 :4 foo
+
+blob
+mark :6
+data 24
+updated foo contents v2
+
+commit refs/heads/master
+mark :7
+author Dave Borowitz <dborowitz@google.com> 1265755287 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755287 -0800
+data 17
+master checkin 2
+from :5
+M 100644 :6 foo
+
+blob
+mark :8
+data 24
+updated foo contents v3
+
+commit refs/heads/master
+mark :9
+author Dave Borowitz <dborowitz@google.com> 1265755295 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755295 -0800
+data 17
+master checkin 3
+from :7
+M 100644 :8 foo
+
+blob
+mark :10
+data 22
+branched bar contents
+
+blob
+mark :11
+data 22
+branched foo contents
+
+commit refs/heads/branch
+mark :12
+author Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+data 15
+branch checkin
+from :2
+M 100644 :10 bar
+M 100644 :11 foo
+
+blob
+mark :13
+data 25
+branched bar contents v2
+
+commit refs/heads/branch
+mark :14
+author Dave Borowitz <dborowitz@google.com> 1265755319 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755319 -0800
+data 17
+branch checkin 2
+from :12
+M 100644 :13 bar
+
+reset refs/heads/master
+from :9
+

+ 57 - 0
dulwich/tests/data/repos/server_old.export

@@ -0,0 +1,57 @@
+blob
+mark :1
+data 13
+foo contents
+
+reset refs/heads/master
+commit refs/heads/master
+mark :2
+author Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755064 -0800
+data 16
+initial checkin
+M 100644 :1 foo
+
+blob
+mark :3
+data 22
+branched bar contents
+
+blob
+mark :4
+data 22
+branched foo contents
+
+commit refs/heads/branch
+mark :5
+author Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755111 -0800
+data 15
+branch checkin
+from :2
+M 100644 :3 bar
+M 100644 :4 foo
+
+blob
+mark :6
+data 13
+baz contents
+
+blob
+mark :7
+data 21
+updated foo contents
+
+commit refs/heads/master
+mark :8
+author Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+committer Dave Borowitz <dborowitz@google.com> 1265755140 -0800
+data 15
+master checkin
+from :2
+M 100644 :6 baz
+M 100644 :7 foo
+
+reset refs/heads/master
+from :8
+

+ 0 - 0
dulwich/tests/data/tags/71033db03a03c6a36721efcf1968dd8f8e0cf023 → dulwich/tests/data/tags/71/033db03a03c6a36721efcf1968dd8f8e0cf023


+ 0 - 0
dulwich/tests/data/trees/70c190eb48fa8bbb50ddc692a17b44cb781af7f6 → dulwich/tests/data/trees/70/c190eb48fa8bbb50ddc692a17b44cb781af7f6


+ 11 - 5
dulwich/tests/test_client.py

@@ -1,16 +1,16 @@
 # test_client.py -- Tests for the git protocol, client side
 # Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -23,16 +23,22 @@ from dulwich.client import (
     GitClient,
     )
 
+
+# TODO(durin42): add unit-level tests of GitClient
 class GitClientTests(TestCase):
 
     def setUp(self):
         self.rout = StringIO()
         self.rin = StringIO()
-        self.client = GitClient(lambda x: True, self.rin.read, 
+        self.client = GitClient(lambda x: True, self.rin.read,
             self.rout.write)
 
     def test_caps(self):
-        self.assertEquals(['multi_ack', 'side-band-64k', 'ofs-delta', 'thin-pack'], self.client._capabilities)
+        self.assertEquals(set(['multi_ack', 'side-band-64k', 'ofs-delta',
+                               'thin-pack']),
+                          set(self.client._fetch_capabilities))
+        self.assertEquals(set(['ofs-delta', 'report-status']),
+                          set(self.client._send_capabilities))
 
 
     def test_fetch_pack_none(self):
         self.rin.write(

+ 78 - 0
dulwich/tests/test_fastexport.py

@@ -0,0 +1,78 @@
+# test_fastexport.py -- Fast export/import functionality
+# Copyright (C) 2010 Jelmer Vernooij <jelmer@samba.org>
+# 
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of 
+# the License.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+from cStringIO import StringIO
+import stat
+from unittest import TestCase
+
+from dulwich.fastexport import (
+    FastExporter,
+    )
+from dulwich.object_store import (
+    MemoryObjectStore,
+    )
+from dulwich.objects import (
+    Blob,
+    Commit,
+    Tree,
+    )
+
+
+class FastExporterTests(TestCase):
+
+    def setUp(self):
+        super(FastExporterTests, self).setUp()
+        self.store = MemoryObjectStore()
+        self.stream = StringIO()
+        self.fastexporter = FastExporter(self.stream, self.store)
+
+    def test_export_blob(self):
+        b = Blob()
+        b.data = "fooBAR"
+        self.assertEquals(1, self.fastexporter.export_blob(b))
+        self.assertEquals('blob\nmark :1\ndata 6\nfooBAR\n',
+            self.stream.getvalue())
+
+    def test_export_commit(self):
+        b = Blob()
+        b.data = "FOO"
+        t = Tree()
+        t.add(stat.S_IFREG | 0644, "foo", b.id)
+        c = Commit()
+        c.committer = c.author = "Jelmer <jelmer@host>"
+        c.author_time = c.commit_time = 1271345553.47
+        c.author_timezone = c.commit_timezone = 0
+        c.message = "msg"
+        c.tree = t.id
+        self.store.add_objects([(b, None), (t, None), (c, None)])
+        self.assertEquals(2,
+                self.fastexporter.export_commit(c, "refs/heads/master"))
+        self.assertEquals("""blob
+mark :1
+data 3
+FOO
+commit refs/heads/master
+mark :2
+author Jelmer <jelmer@host> 1271345553.47 +0000
+committer Jelmer <jelmer@host> 1271345553.47 +0000
+data 3
+msg
+M 100644 :1 foo
+
+""", self.stream.getvalue())

+ 68 - 2
dulwich/tests/test_file.py

@@ -20,12 +20,71 @@
 import errno
 import os
 import shutil
+import sys
 import tempfile
 import unittest
 
-from dulwich.file import GitFile
+from dulwich.file import GitFile, fancy_rename
+from dulwich.tests import TestSkipped
+
+
+class FancyRenameTests(unittest.TestCase):
+
+    def setUp(self):
+        self._tempdir = tempfile.mkdtemp()
+        self.foo = self.path('foo')
+        self.bar = self.path('bar')
+        self.create(self.foo, 'foo contents')
+
+    def tearDown(self):
+        shutil.rmtree(self._tempdir)
+
+    def path(self, filename):
+        return os.path.join(self._tempdir, filename)
+
+    def create(self, path, contents):
+        f = open(path, 'wb')
+        f.write(contents)
+        f.close()
+
+    def test_no_dest_exists(self):
+        self.assertFalse(os.path.exists(self.bar))
+        fancy_rename(self.foo, self.bar)
+        self.assertFalse(os.path.exists(self.foo))
+
+        new_f = open(self.bar, 'rb')
+        self.assertEquals('foo contents', new_f.read())
+        new_f.close()
+         
+    def test_dest_exists(self):
+        self.create(self.bar, 'bar contents')
+        fancy_rename(self.foo, self.bar)
+        self.assertFalse(os.path.exists(self.foo))
+
+        new_f = open(self.bar, 'rb')
+        self.assertEquals('foo contents', new_f.read())
+        new_f.close()
+
+    def test_dest_opened(self):
+        if sys.platform != "win32":
+            raise TestSkipped("platform allows overwriting open files")
+        self.create(self.bar, 'bar contents')
+        dest_f = open(self.bar, 'rb')
+        self.assertRaises(OSError, fancy_rename, self.foo, self.bar)
+        dest_f.close()
+        self.assertTrue(os.path.exists(self.path('foo')))
+
+        new_f = open(self.foo, 'rb')
+        self.assertEquals('foo contents', new_f.read())
+        new_f.close()
+
+        new_f = open(self.bar, 'rb')
+        self.assertEquals('bar contents', new_f.read())
+        new_f.close()
+
 
 
 class GitFileTests(unittest.TestCase):
+
     def setUp(self):
         self._tempdir = tempfile.mkdtemp()
         f = open(self.path('foo'), 'wb')
@@ -85,7 +144,7 @@ class GitFileTests(unittest.TestCase):
         f1.write('new')
         try:
             f2 = GitFile(foo, 'wb')
-            fail()
+            self.fail()
         except OSError, e:
             self.assertEquals(errno.EEXIST, e.errno)
         f1.write(' contents')
@@ -129,3 +188,10 @@ class GitFileTests(unittest.TestCase):
             f.abort()
         except (IOError, OSError):
             self.fail()
+
+    def test_abort_close_removed(self):
+        foo = self.path('foo')
+        f = GitFile(foo, 'wb')
+        os.remove(foo+".lock")
+        f.abort()
+        self.assertTrue(f._closed)

+ 31 - 11
dulwich/tests/test_index.py

@@ -1,16 +1,16 @@
 # test_index.py -- Tests for the git index
 # Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -24,8 +24,10 @@ from cStringIO import (
     StringIO,
     )
 import os
+import shutil
 import stat
 import struct
+import tempfile
 from unittest import TestCase
 
 from dulwich.index import (
 from dulwich.index import (
@@ -43,6 +45,7 @@ from dulwich.objects import (
     Blob,
     Blob,
     )
     )
 
 
+
 class IndexTestCase(TestCase):
 
     datadir = os.path.join(os.path.dirname(__file__), 'data/indexes')
@@ -51,7 +54,7 @@ class IndexTestCase(TestCase):
         return Index(os.path.join(self.datadir, name))
 
 
-class SimpleIndexTestcase(IndexTestCase):
+class SimpleIndexTestCase(IndexTestCase):
 
 
     def test_len(self):
         self.assertEquals(1, len(self.get_simple_index("index")))
@@ -60,21 +63,38 @@ class SimpleIndexTestcase(IndexTestCase):
         self.assertEquals(['bla'], list(self.get_simple_index("index")))
 
     def test_getitem(self):
-        self.assertEquals( ((1230680220, 0), (1230680220, 0), 2050, 3761020, 33188, 1000, 1000, 0, 'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)
-            , 
-                self.get_simple_index("index")["bla"])
+        self.assertEquals(((1230680220, 0), (1230680220, 0), 2050, 3761020,
+                           33188, 1000, 1000, 0,
+                           'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0),
+                          self.get_simple_index("index")["bla"])
+
+    def test_empty(self):
+        i = self.get_simple_index("notanindex")
+        self.assertEquals(0, len(i))
+        self.assertFalse(os.path.exists(i._filename))
 
 
 
 
 class SimpleIndexWriterTestCase(IndexTestCase):
 
+    def setUp(self):
+        IndexTestCase.setUp(self)
+        self.tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        IndexTestCase.tearDown(self)
+        shutil.rmtree(self.tempdir)
+
     def test_simple_write(self):
-        entries = [('barbla', (1230680220, 0), (1230680220, 0), 2050, 3761020, 33188, 1000, 1000, 0, 'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)]
-        x = open('test-simple-write-index', 'w+')
+        entries = [('barbla', (1230680220, 0), (1230680220, 0), 2050, 3761020,
+                    33188, 1000, 1000, 0,
+                    'e69de29bb2d1d6434b8b29ae775ad8c2e48c5391', 0)]
+        filename = os.path.join(self.tempdir, 'test-simple-write-index')
+        x = open(filename, 'w+')
         try:
             write_index(x, entries)
         finally:
             x.close()
-        x = open('test-simple-write-index', 'r')
+        x = open(filename, 'r')
         try:
             self.assertEquals(entries, list(read_index(x)))
         finally:
@@ -108,7 +128,7 @@ class CommitTreeTests(TestCase):
         self.assertEquals(dirid, "c1a1deb9788150829579a8b4efa6311e7b638650")
         self.assertEquals((stat.S_IFDIR, dirid), self.store[rootid]["bla"])
         self.assertEquals((stat.S_IFREG, blob.id), self.store[dirid]["bar"])
-        self.assertEquals(set([rootid, dirid, blob.id]), 
+        self.assertEquals(set([rootid, dirid, blob.id]),
                           set(self.store._data.keys()))
 
 

+ 33 - 29
dulwich/tests/test_object_store.py

@@ -1,16 +1,16 @@
 # test_object_store.py -- tests for object_store.py
 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # or (at your option) any later version of the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -20,6 +20,9 @@
 """Tests for the object store interface."""
 """Tests for the object store interface."""
 
 
 
 
+import os
+import shutil
+import tempfile
 from unittest import TestCase
 from unittest import TestCase
 
 
 from dulwich.objects import (
@@ -29,24 +32,12 @@ from dulwich.object_store import (
     DiskObjectStore,
     MemoryObjectStore,
     )
-import os
-import shutil
-
-
-testobject = Blob()
-testobject.data = "yummy data"
-
+from utils import (
+    make_object,
+    )
 
 
-class SpecificDiskObjectStoreTests(TestCase):
-
-    def test_pack_dir(self):
-        o = DiskObjectStore("foo")
-        self.assertEquals(os.path.join("foo", "pack"), o.pack_dir)
-
-    def test_empty_packs(self):
-        o = DiskObjectStore("foo")
-        self.assertEquals([], o.packs)
 
 
+testobject = make_object(Blob, data="yummy data")
 
 
 
 
 class ObjectStoreTests(object):
@@ -55,10 +46,10 @@ class ObjectStoreTests(object):
         self.assertEquals([], list(self.store))
 
     def test_get_nonexistant(self):
-        self.assertRaises(KeyError, self.store.__getitem__, "a" * 40)
+        self.assertRaises(KeyError, lambda: self.store["a" * 40])
 
 
     def test_contains_nonexistant(self):
-        self.assertFalse(self.store.__contains__("a" * 40))
+        self.assertFalse(("a" * 40) in self.store)
 
 
     def test_add_objects_empty(self):
         self.store.add_objects([])
@@ -71,7 +62,7 @@ class ObjectStoreTests(object):
     def test_add_object(self):
         self.store.add_object(testobject)
         self.assertEquals(set([testobject.id]), set(self.store))
-        self.assertTrue(self.store.__contains__(testobject.id))
+        self.assertTrue(testobject.id in self.store)
         r = self.store[testobject.id]
         self.assertEquals(r, testobject)
 
@@ -79,23 +70,36 @@ class ObjectStoreTests(object):
         data = [(testobject, "mypath")]
         self.store.add_objects(data)
         self.assertEquals(set([testobject.id]), set(self.store))
-        self.assertTrue(self.store.__contains__(testobject.id))
+        self.assertTrue(testobject.id in self.store)
         r = self.store[testobject.id]
         r = self.store[testobject.id]
         self.assertEquals(r, testobject)
         self.assertEquals(r, testobject)
 
 
 
 
-class MemoryObjectStoreTests(ObjectStoreTests,TestCase):
+class MemoryObjectStoreTests(ObjectStoreTests, TestCase):
 
 
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
         self.store = MemoryObjectStore()
         self.store = MemoryObjectStore()
 
 
 
 
-class DiskObjectStoreTests(ObjectStoreTests,TestCase):
+class DiskObjectStoreTests(ObjectStoreTests, TestCase):
 
 
     def setUp(self):
     def setUp(self):
         TestCase.setUp(self)
         TestCase.setUp(self)
-        if os.path.exists("foo"):
-            shutil.rmtree("foo")
-        os.makedirs(os.path.join("foo", "pack"))
-        self.store = DiskObjectStore("foo")
+        self.store_dir = tempfile.mkdtemp()
+        self.store = DiskObjectStore.init(self.store_dir)
+
+    def tearDown(self):
+        TestCase.tearDown(self)
+        shutil.rmtree(self.store_dir)
+
+    def test_pack_dir(self):
+        o = DiskObjectStore(self.store_dir)
+        self.assertEquals(os.path.join(self.store_dir, "pack"), o.pack_dir)
+
+    def test_empty_packs(self):
+        o = DiskObjectStore(self.store_dir)
+        self.assertEquals([], o.packs)
+
+
+# TODO: MissingObjectFinderTests

+ 420 - 104
dulwich/tests/test_objects.py

@@ -20,11 +20,18 @@

 """Tests for git base objects."""

+# TODO: Round-trip parse-serialize-parse and serialize-parse-serialize tests.

+
+import datetime
 import os
 import stat
 import unittest

+from dulwich.errors import (
+    ChecksumMismatch,
+    ObjectFormatException,
+    )
 from dulwich.objects import (
     Blob,
     Tree,
@@ -32,7 +39,20 @@ from dulwich.objects import (
     Tag,
     format_timezone,
     hex_to_sha,
+    sha_to_hex,
+    hex_to_filename,
+    check_hexsha,
+    check_identity,
     parse_timezone,
+    parse_tree,
+    _parse_tree_py,
+    )
+from dulwich.tests import (
+    TestSkipped,
+    )
+from utils import (
+    make_commit,
+    make_object,
     )

 a_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
@@ -41,13 +61,57 @@ c_sha = '954a536f7819d40e6f637f849ee187dd10066349'
 tree_sha = '70c190eb48fa8bbb50ddc692a17b44cb781af7f6'
 tag_sha = '71033db03a03c6a36721efcf1968dd8f8e0cf023'

+
+try:
+    from itertools import permutations
+except ImportError:
+    # Implementation of permutations from Python 2.6 documentation:
+    # http://docs.python.org/2.6/library/itertools.html#itertools.permutations
+    # Copyright (c) 2001-2010 Python Software Foundation; All Rights Reserved
+    # Modified syntax slightly to run under Python 2.4.
+    def permutations(iterable, r=None):
+        # permutations('ABCD', 2) --> AB AC AD BA BC BD CA CB CD DA DB DC
+        # permutations(range(3)) --> 012 021 102 120 201 210
+        pool = tuple(iterable)
+        n = len(pool)
+        if r is None:
+            r = n
+        if r > n:
+            return
+        indices = range(n)
+        cycles = range(n, n-r, -1)
+        yield tuple(pool[i] for i in indices[:r])
+        while n:
+            for i in reversed(range(r)):
+                cycles[i] -= 1
+                if cycles[i] == 0:
+                    indices[i:] = indices[i+1:] + indices[i:i+1]
+                    cycles[i] = n - i
+                else:
+                    j = cycles[i]
+                    indices[i], indices[-j] = indices[-j], indices[i]
+                    yield tuple(pool[i] for i in indices[:r])
+                    break
+            else:
+                return
+
+
+class TestHexToSha(unittest.TestCase):
+
+    def test_simple(self):
+        self.assertEquals("\xab\xcd" * 10, hex_to_sha("abcd" * 10))
+
+    def test_reverse(self):
+        self.assertEquals("abcd" * 10, sha_to_hex("\xab\xcd" * 10))
+
+
 class BlobReadTests(unittest.TestCase):
     """Test decompression of blobs"""
-  
-    def get_sha_file(self, obj, base, sha):
-        return obj.from_file(os.path.join(os.path.dirname(__file__),
-                                          'data', base, sha))
-  
+
+    def get_sha_file(self, cls, base, sha):
+        dir = os.path.join(os.path.dirname(__file__), 'data', base)
+        return cls.from_path(hex_to_filename(dir, sha))
+
     def get_blob(self, sha):
         """Return the blob named sha from the test data dir"""
         return self.get_sha_file(Blob, 'blobs', sha)
@@ -82,6 +146,18 @@ class BlobReadTests(unittest.TestCase):
         b = Blob.from_string(string)
         self.assertEqual(b.data, string)
         self.assertEqual(b.sha().hexdigest(), b_sha)
+
+    def test_chunks(self):
+        string = 'test 5\n'
+        b = Blob.from_string(string)
+        self.assertEqual([string], b.chunked)
+
+    def test_set_chunks(self):
+        b = Blob()
+        b.chunked = ['te', 'st', ' 5\n']
+        self.assertEqual('test 5\n', b.data)
+        b.chunked = ['te', 'st', ' 6\n']
+        self.assertEqual('test 6\n', b.as_raw_string())
  
     def test_parse_legacy_blob(self):
         string = 'test 3\n'
@@ -107,12 +183,12 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(t.tag_time, 1231203091)
         self.assertEqual(t.message, 'This is a signed tag\n-----BEGIN PGP SIGNATURE-----\nVersion: GnuPG v1.4.9 (GNU/Linux)\n\niEYEABECAAYFAkliqx8ACgkQqSMmLy9u/kcx5ACfakZ9NnPl02tOyYP6pkBoEkU1\n5EcAn0UFgokaSvS371Ym/4W9iJj6vh3h\n=ql7y\n-----END PGP SIGNATURE-----\n')
  
-  
     def test_read_commit_from_file(self):
         sha = '60dacdc733de308bb77bb76ce0fb0f9b44c9769e'
         c = self.commit(sha)
         self.assertEqual(c.tree, tree_sha)
-        self.assertEqual(c.parents, ['0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
+        self.assertEqual(c.parents,
+            ['0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
         self.assertEqual(c.author,
             'James Westby <jw+debian@jameswestby.net>')
         self.assertEqual(c.committer,
@@ -150,95 +226,185 @@ class BlobReadTests(unittest.TestCase):
         self.assertEqual(c.commit_timezone, 0)
         self.assertEqual(c.author_timezone, 0)
         self.assertEqual(c.message, 'Merge ../b\n')
-  
+
+
+class ShaFileCheckTests(unittest.TestCase):
+
+    def assertCheckFails(self, cls, data):
+        obj = cls()
+        def do_check():
+            obj.set_raw_string(data)
+            obj.check()
+        self.assertRaises(ObjectFormatException, do_check)
+
+    def assertCheckSucceeds(self, cls, data):
+        obj = cls()
+        obj.set_raw_string(data)
+        self.assertEqual(None, obj.check())
 
 
 
 
 class CommitSerializationTests(unittest.TestCase):
 
 
-    def make_base(self):
-        c = Commit()
-        c.tree = 'd80c186a03f423a81b39df39dc87fd269736ca86'
-        c.parents = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6']
-        c.author = 'James Westby <jw+debian@jameswestby.net>'
-        c.committer = 'James Westby <jw+debian@jameswestby.net>'
-        c.commit_time = 1174773719
-        c.author_time = 1174773719
-        c.commit_timezone = 0
-        c.author_timezone = 0
-        c.message =  'Merge ../b\n'
-        return c
+    def make_commit(self, **kwargs):
+        attrs = {'tree': 'd80c186a03f423a81b39df39dc87fd269736ca86',
+                 'parents': ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                             '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                 'author': 'James Westby <jw+debian@jameswestby.net>',
+                 'committer': 'James Westby <jw+debian@jameswestby.net>',
+                 'commit_time': 1174773719,
+                 'author_time': 1174773719,
+                 'commit_timezone': 0,
+                 'author_timezone': 0,
+                 'message':  'Merge ../b\n'}
+        attrs.update(kwargs)
+        return make_commit(**attrs)
 
 
     def test_encoding(self):
-        c = self.make_base()
-        c.encoding = "iso8859-1"
-        self.assertTrue("encoding iso8859-1\n" in c.as_raw_string())        
+        c = self.make_commit(encoding='iso8859-1')
+        self.assertTrue('encoding iso8859-1\n' in c.as_raw_string())

     def test_short_timestamp(self):
-        c = self.make_base()
-        c.commit_time = 30
+        c = self.make_commit(commit_time=30)
         c1 = Commit()
         c1.set_raw_string(c.as_raw_string())
         self.assertEquals(30, c1.commit_time)
 
 
+    def test_raw_length(self):
+        c = self.make_commit()
+        self.assertEquals(len(c.as_raw_string()), c.raw_length())
+
     def test_simple(self):
-        c = self.make_base()
+        c = self.make_commit()
         self.assertEquals(c.id, '5dac377bdded4c9aeb8dff595f0faeebcc8498cc')
         self.assertEquals(
                 'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
                 'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
                 'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
+                'author James Westby <jw+debian@jameswestby.net> '
+                '1174773719 +0000\n'
+                'committer James Westby <jw+debian@jameswestby.net> '
+                '1174773719 +0000\n'
                 '\n'
                 'Merge ../b\n', c.as_raw_string())
 
 
     def test_timezone(self):
-        c = self.make_base()
-        c.commit_timezone = 5 * 60
+        c = self.make_commit(commit_timezone=(5 * 60))
         self.assertTrue(" +0005\n" in c.as_raw_string())

     def test_neg_timezone(self):
-        c = self.make_base()
-        c.commit_timezone = -1 * 3600
+        c = self.make_commit(commit_timezone=(-1 * 3600))
         self.assertTrue(" -0100\n" in c.as_raw_string())
 
 
 
 
-class CommitDeserializationTests(unittest.TestCase):
+default_committer = 'James Westby <jw+debian@jameswestby.net> 1174773719 +0000'
+
+class CommitParseTests(ShaFileCheckTests):
+
+    def make_commit_lines(self,
+                          tree='d80c186a03f423a81b39df39dc87fd269736ca86',
+                          parents=['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                                   '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          author=default_committer,
+                          committer=default_committer,
+                          encoding=None,
+                          message='Merge ../b\n',
+                          extra=None):
+        lines = []
+        if tree is not None:
+            lines.append('tree %s' % tree)
+        if parents is not None:
+            lines.extend('parent %s' % p for p in parents)
+        if author is not None:
+            lines.append('author %s' % author)
+        if committer is not None:
+            lines.append('committer %s' % committer)
+        if encoding is not None:
+            lines.append('encoding %s' % encoding)
+        if extra is not None:
+            for name, value in sorted(extra.iteritems()):
+                lines.append('%s %s' % (name, value))
+        lines.append('')
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_commit_text(self, **kwargs):
+        return '\n'.join(self.make_commit_lines(**kwargs))
 
 
     def test_simple(self):
     def test_simple(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text())
         self.assertEquals('Merge ../b\n', c.message)
         self.assertEquals('Merge ../b\n', c.message)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>', c.author)
         self.assertEquals('James Westby <jw+debian@jameswestby.net>',
         self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.author)
-        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
-            c.committer)
-        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86',
-            c.tree)
+                          c.committer)
+        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86', c.tree)
         self.assertEquals(['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
         self.assertEquals(['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
-                          '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
-            c.parents)
+                           '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+                          c.parents)
+        expected_time = datetime.datetime(2007, 3, 24, 22, 1, 59)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.commit_time))
+        self.assertEquals(0, c.commit_timezone)
+        self.assertEquals(expected_time,
+                          datetime.datetime.utcfromtimestamp(c.author_time))
+        self.assertEquals(0, c.author_timezone)
+        self.assertEquals(None, c.encoding)
 
 
     def test_custom(self):
     def test_custom(self):
-        c = Commit.from_string(
-                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
-                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
-                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
-                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
-                'extra-field data\n'
-                '\n'
-                'Merge ../b\n')
+        c = Commit.from_string(self.make_commit_text(
+          extra={'extra-field': 'data'}))
         self.assertEquals([('extra-field', 'data')], c.extra)
         self.assertEquals([('extra-field', 'data')], c.extra)
 
 
-
-class TreeSerializationTests(unittest.TestCase):
+    def test_encoding(self):
+        c = Commit.from_string(self.make_commit_text(encoding='UTF-8'))
+        self.assertEquals('UTF-8', c.encoding)
+
+    def test_check(self):
+        self.assertCheckSucceeds(Commit, self.make_commit_text())
+        self.assertCheckSucceeds(Commit, self.make_commit_text(parents=None))
+        self.assertCheckSucceeds(Commit,
+                                 self.make_commit_text(encoding='UTF-8'))
+
+        self.assertCheckFails(Commit, self.make_commit_text(tree='xxx'))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          parents=[a_sha, 'xxx']))
+        bad_committer = "some guy without an email address 1174773719 +0000"
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(committer=bad_committer))
+        self.assertCheckFails(Commit,
+                              self.make_commit_text(author=bad_committer))
+        self.assertCheckFails(Commit, self.make_commit_text(author=None))
+        self.assertCheckFails(Commit, self.make_commit_text(committer=None))
+        self.assertCheckFails(Commit, self.make_commit_text(
+          author=None, committer=None))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(5):
+            lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+            lines.insert(i, lines[i])
+            text = '\n'.join(lines)
+            if lines[i].startswith('parent'):
+                # duplicate parents are ok for now
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+    def test_check_order(self):
+        lines = self.make_commit_lines(parents=[a_sha], encoding='UTF-8')
+        headers = lines[:5]
+        rest = lines[5:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Commit, text)
+            else:
+                self.assertCheckFails(Commit, text)
+
+
+class TreeTests(ShaFileCheckTests):
 
 
     def test_simple(self):
     def test_simple(self):
         myhexsha = "d80c186a03f423a81b39df39dc87fd269736ca86"
         myhexsha = "d80c186a03f423a81b39df39dc87fd269736ca86"
@@ -247,6 +413,13 @@ class TreeSerializationTests(unittest.TestCase):
         self.assertEquals('100755 myname\0' + hex_to_sha(myhexsha),
         self.assertEquals('100755 myname\0' + hex_to_sha(myhexsha),
                 x.as_raw_string())
                 x.as_raw_string())
 
 
+    def test_tree_update_id(self):
+        x = Tree()
+        x["a.c"] = (0100755, "d80c186a03f423a81b39df39dc87fd269736ca86")
+        self.assertEquals("0c5c6bc2c081accfbc250331b19e43b904ab9cdd", x.id)
+        x["a.b"] = (stat.S_IFDIR, "d80c186a03f423a81b39df39dc87fd269736ca86")
+        self.assertEquals("07bfcb5f3ada15bbebdfa3bbb8fd858a363925c8", x.id)
+
     def test_tree_dir_sort(self):
     def test_tree_dir_sort(self):
         x = Tree()
         x = Tree()
         x["a.c"] = (0100755, "d80c186a03f423a81b39df39dc87fd269736ca86")
         x["a.c"] = (0100755, "d80c186a03f423a81b39df39dc87fd269736ca86")
@@ -254,35 +427,79 @@ class TreeSerializationTests(unittest.TestCase):
         x["a/c"] = (stat.S_IFDIR, "d80c186a03f423a81b39df39dc87fd269736ca86")
         x["a/c"] = (stat.S_IFDIR, "d80c186a03f423a81b39df39dc87fd269736ca86")
         self.assertEquals(["a.c", "a", "a/c"], [p[0] for p in x.iteritems()])
         self.assertEquals(["a.c", "a", "a/c"], [p[0] for p in x.iteritems()])
 
 
+    def _do_test_parse_tree(self, parse_tree):
+        dir = os.path.join(os.path.dirname(__file__), 'data', 'trees')
+        o = Tree.from_path(hex_to_filename(dir, tree_sha))
+        self.assertEquals([('a', 0100644, a_sha), ('b', 0100644, b_sha)],
+                          list(parse_tree(o.as_raw_string())))
+
+    def test_parse_tree(self):
+        self._do_test_parse_tree(_parse_tree_py)
+
+    def test_parse_tree_extension(self):
+        if parse_tree is _parse_tree_py:
+            raise TestSkipped('parse_tree extension not found')
+        self._do_test_parse_tree(parse_tree)
+
+    def test_check(self):
+        t = Tree
+        sha = hex_to_sha(a_sha)
+
+        # filenames
+        self.assertCheckSucceeds(t, '100644 .a\0%s' % sha)
+        self.assertCheckFails(t, '100644 \0%s' % sha)
+        self.assertCheckFails(t, '100644 .\0%s' % sha)
+        self.assertCheckFails(t, '100644 a/a\0%s' % sha)
+        self.assertCheckFails(t, '100644 ..\0%s' % sha)
+
+        # modes
+        self.assertCheckSucceeds(t, '100644 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '100755 a\0%s' % sha)
+        self.assertCheckSucceeds(t, '160000 a\0%s' % sha)
+        # TODO more whitelisted modes
+        self.assertCheckFails(t, '123456 a\0%s' % sha)
+        self.assertCheckFails(t, '123abc a\0%s' % sha)
+
+        # shas
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 5))
+        self.assertCheckFails(t, '100644 a\0%s' % ('x' * 18 + '\0'))
+        self.assertCheckFails(t, '100644 a\0%s\n100644 b\0%s' % ('x' * 21, sha))
+
+        # ordering
+        sha2 = hex_to_sha(b_sha)
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha))
+        self.assertCheckSucceeds(t, '100644 a\0%s\n100644 b\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 a\0%s\n100755 a\0%s' % (sha, sha2))
+        self.assertCheckFails(t, '100644 b\0%s\n100644 a\0%s' % (sha2, sha))
+
+    def test_iter(self):
+        t = Tree()
+        t["foo"] = (0100644, a_sha)
+        self.assertEquals(set(["foo"]), set(t))
+
 
 
 class TagSerializeTests(unittest.TestCase):
 class TagSerializeTests(unittest.TestCase):
 
 
     def test_serialize_simple(self):
     def test_serialize_simple(self):
-        x = Tag()
-        x.tagger = "Jelmer Vernooij <jelmer@samba.org>"
-        x.name = "0.1"
-        x.message = "Tag 0.1"
-        x.object = (3, "d80c186a03f423a81b39df39dc87fd269736ca86")
-        x.tag_time = 423423423
-        x.tag_timezone = 0
-        self.assertEquals("""object d80c186a03f423a81b39df39dc87fd269736ca86
-type blob
-tag 0.1
-tagger Jelmer Vernooij <jelmer@samba.org> 423423423 +0000
-
-Tag 0.1""", x.as_raw_string())
-
-
-class TagParseTests(unittest.TestCase):
-
-    def test_parse_ctime(self):
-        x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-tagger Linus Torvalds <torvalds@woody.linux-foundation.org> Sun Jul 1 12:54:34 2007 -0700
-
-Linux 2.6.22-rc7
+        x = make_object(Tag,
+                        tagger='Jelmer Vernooij <jelmer@samba.org>',
+                        name='0.1',
+                        message='Tag 0.1',
+                        object=(Blob, 'd80c186a03f423a81b39df39dc87fd269736ca86'),
+                        tag_time=423423423,
+                        tag_timezone=0)
+        self.assertEquals(('object d80c186a03f423a81b39df39dc87fd269736ca86\n'
+                           'type blob\n'
+                           'tag 0.1\n'
+                           'tagger Jelmer Vernooij <jelmer@samba.org> '
+                           '423423423 +0000\n'
+                           '\n'
+                           'Tag 0.1'), x.as_raw_string())
+
+
+default_tagger = ('Linus Torvalds <torvalds@woody.linux-foundation.org> '
+                  '1183319674 -0700')
+default_message = """Linux 2.6.22-rc7
 -----BEGIN PGP SIGNATURE-----
 -----BEGIN PGP SIGNATURE-----
 Version: GnuPG v1.4.7 (GNU/Linux)
 Version: GnuPG v1.4.7 (GNU/Linux)
 
 
@@ -290,39 +507,136 @@ iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
 OK2XeQOiEeXtT76rV4t2WR4=
 OK2XeQOiEeXtT76rV4t2WR4=
 =ivrA
 =ivrA
 -----END PGP SIGNATURE-----
 -----END PGP SIGNATURE-----
-""")
-        self.assertEquals("Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
+"""
+
+
+class TagParseTests(ShaFileCheckTests):
+    def make_tag_lines(self,
+                       object_sha="a38d6181ff27824c79fc7df825164a212eff6a3f",
+                       object_type_name="commit",
+                       name="v2.6.22-rc7",
+                       tagger=default_tagger,
+                       message=default_message):
+        lines = []
+        if object_sha is not None:
+            lines.append("object %s" % object_sha)
+        if object_type_name is not None:
+            lines.append("type %s" % object_type_name)
+        if name is not None:
+            lines.append("tag %s" % name)
+        if tagger is not None:
+            lines.append("tagger %s" % tagger)
+        lines.append("")
+        if message is not None:
+            lines.append(message)
+        return lines
+
+    def make_tag_text(self, **kwargs):
+        return "\n".join(self.make_tag_lines(**kwargs))
+
+    def test_parse(self):
+        x = Tag()
+        x.set_raw_string(self.make_tag_text())
+        self.assertEquals(
+            "Linus Torvalds <torvalds@woody.linux-foundation.org>", x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
         self.assertEquals("v2.6.22-rc7", x.name)
+        object_type, object_sha = x.object
+        self.assertEquals("a38d6181ff27824c79fc7df825164a212eff6a3f",
+                          object_sha)
+        self.assertEquals(Commit, object_type)
+        self.assertEquals(datetime.datetime.utcfromtimestamp(x.tag_time),
+                          datetime.datetime(2007, 7, 1, 19, 54, 34))
+        self.assertEquals(-25200, x.tag_timezone)
 
 
     def test_parse_no_tagger(self):
     def test_parse_no_tagger(self):
         x = Tag()
         x = Tag()
-        x.set_raw_string("""object a38d6181ff27824c79fc7df825164a212eff6a3f
-type commit
-tag v2.6.22-rc7
-
-Linux 2.6.22-rc7
------BEGIN PGP SIGNATURE-----
-Version: GnuPG v1.4.7 (GNU/Linux)
-
-iD8DBQBGiAaAF3YsRnbiHLsRAitMAKCiLboJkQECM/jpYsY3WPfvUgLXkACgg3ql
-OK2XeQOiEeXtT76rV4t2WR4=
-=ivrA
------END PGP SIGNATURE-----
-""")
+        x.set_raw_string(self.make_tag_text(tagger=None))
         self.assertEquals(None, x.tagger)
         self.assertEquals(None, x.tagger)
         self.assertEquals("v2.6.22-rc7", x.name)
         self.assertEquals("v2.6.22-rc7", x.name)
 
 
+    def test_check(self):
+        self.assertCheckSucceeds(Tag, self.make_tag_text())
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha=None))
+        self.assertCheckFails(Tag, self.make_tag_text(object_type_name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=None))
+        self.assertCheckFails(Tag, self.make_tag_text(name=''))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          object_type_name="foobar"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger="some guy without an email address 1183319674 -0700"))
+        self.assertCheckFails(Tag, self.make_tag_text(
+          tagger=("Linus Torvalds <torvalds@woody.linux-foundation.org> "
+                  "Sun 7 Jul 2007 12:54:34 +0700")))
+        self.assertCheckFails(Tag, self.make_tag_text(object_sha="xxx"))
+
+    def test_check_duplicates(self):
+        # duplicate each of the header fields
+        for i in xrange(4):
+            lines = self.make_tag_lines()
+            lines.insert(i, lines[i])
+            self.assertCheckFails(Tag, '\n'.join(lines))
+
+    def test_check_order(self):
+        lines = self.make_tag_lines()
+        headers = lines[:4]
+        rest = lines[4:]
+        # of all possible permutations, ensure only the original succeeds
+        for perm in permutations(headers):
+            perm = list(perm)
+            text = '\n'.join(perm + rest)
+            if perm == headers:
+                self.assertCheckSucceeds(Tag, text)
+            else:
+                self.assertCheckFails(Tag, text)
+
+
+class CheckTests(unittest.TestCase):
+
+    def test_check_hexsha(self):
+        check_hexsha(a_sha, "failed to check good sha")
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 39,
+                          'sha too short')
+        self.assertRaises(ObjectFormatException, check_hexsha, '1' * 41,
+                          'sha too long')
+        self.assertRaises(ObjectFormatException, check_hexsha, 'x' * 40,
+                          'invalid characters')
+
+    def test_check_identity(self):
+        check_identity("Dave Borowitz <dborowitz@google.com>",
+                       "failed to check good identity")
+        check_identity("<dborowitz@google.com>",
+                       "failed to check good identity")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz", "no email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "dborowitz@google.com>", "incomplete email")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <<dborowitz@google.com>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>>", "typo")
+        self.assertRaises(ObjectFormatException, check_identity,
+                          "Dave Borowitz <dborowitz@google.com>xxx",
+                          "trailing characters")
+
 
 
 class TimezoneTests(unittest.TestCase):
 class TimezoneTests(unittest.TestCase):
 
 
     def test_parse_timezone_utc(self):
     def test_parse_timezone_utc(self):
-        self.assertEquals(0, parse_timezone("+0000"))
+        self.assertEquals((0, False), parse_timezone("+0000"))
+
+    def test_parse_timezone_utc_negative(self):
+        self.assertEquals((0, True), parse_timezone("-0000"))
 
 
     def test_generate_timezone_utc(self):
     def test_generate_timezone_utc(self):
         self.assertEquals("+0000", format_timezone(0))
         self.assertEquals("+0000", format_timezone(0))
 
 
+    def test_generate_timezone_utc_negative(self):
+        self.assertEquals("-0000", format_timezone(0, True))
+
     def test_parse_timezone_cet(self):
     def test_parse_timezone_cet(self):
-        self.assertEquals(60 * 60, parse_timezone("+0100"))
+        self.assertEquals((60 * 60, False), parse_timezone("+0100"))
 
 
     def test_format_timezone_cet(self):
     def test_format_timezone_cet(self):
         self.assertEquals("+0100", format_timezone(60 * 60))
         self.assertEquals("+0100", format_timezone(60 * 60))
@@ -331,10 +645,12 @@ class TimezoneTests(unittest.TestCase):
         self.assertEquals("-0400", format_timezone(-4 * 60 * 60))
         self.assertEquals("-0400", format_timezone(-4 * 60 * 60))
 
 
     def test_parse_timezone_pdt(self):
     def test_parse_timezone_pdt(self):
-        self.assertEquals(-4 * 60 * 60, parse_timezone("-0400"))
+        self.assertEquals((-4 * 60 * 60, False), parse_timezone("-0400"))
 
 
     def test_format_timezone_pdt_half(self):
     def test_format_timezone_pdt_half(self):
-        self.assertEquals("-0440", format_timezone(int(((-4 * 60) - 40) * 60)))
+        self.assertEquals("-0440",
+            format_timezone(int(((-4 * 60) - 40) * 60)))
 
 
     def test_parse_timezone_pdt_half(self):
     def test_parse_timezone_pdt_half(self):
-        self.assertEquals(((-4 * 60) - 40) * 60, parse_timezone("-0440"))
+        self.assertEquals((((-4 * 60) - 40) * 60, False),
+            parse_timezone("-0440"))

+ 212 - 86
dulwich/tests/test_pack.py

@@ -1,17 +1,17 @@
 # test_pack.py -- Tests for the handling of git packs.
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # of the License, or (at your option) any later version of the license.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -23,9 +23,17 @@
 
 
 from cStringIO import StringIO
 import os
+import shutil
+import tempfile
 import unittest
+import zlib

+from dulwich.errors import (
+    ChecksumMismatch,
+    )
 from dulwich.objects import (
+    hex_to_sha,
+    sha_to_hex,
     Tree,
     )
 from dulwich.pack import (
@@ -35,7 +43,7 @@ from dulwich.pack import (
     create_delta,
     load_pack_index,
     hex_to_sha,
-    read_zlib,
+    read_zlib_chunks,
     sha_to_hex,
     write_pack_index_v1,
     write_pack_index_v2,
@@ -48,26 +56,39 @@ a_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
 tree_sha = 'b2a2766a2879c209ab1176e7e778b81ae422eeaa'
 commit_sha = 'f18faa16531ac570a3fdc8c7ca16682548dafd12'

+
 class PackTests(unittest.TestCase):
     """Base class for testing packs"""
-  
+
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
+
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
     datadir = os.path.join(os.path.dirname(__file__), 'data/packs')
-  
+
     def get_pack_index(self, sha):
         """Returns a PackIndex from the datadir with the given sha"""
         return load_pack_index(os.path.join(self.datadir, 'pack-%s.idx' % sha))
-  
+
     def get_pack_data(self, sha):
         """Returns a PackData object from the datadir with the given sha"""
         return PackData(os.path.join(self.datadir, 'pack-%s.pack' % sha))
-  
+
     def get_pack(self, sha):
         return Pack(os.path.join(self.datadir, 'pack-%s' % sha))

+    def assertSucceeds(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except ChecksumMismatch, e:
+            self.fail(e)
+

 class PackIndexTests(PackTests):
     """Class that tests the index of packfiles"""
-  
+
     def test_object_index(self):
         """Tests that the correct object offset is returned from the index."""
         p = self.get_pack_index(pack1_sha)
@@ -75,88 +96,118 @@ class PackIndexTests(PackTests):
         self.assertEqual(p.object_index(a_sha), 178)
         self.assertEqual(p.object_index(a_sha), 178)
         self.assertEqual(p.object_index(tree_sha), 138)
         self.assertEqual(p.object_index(tree_sha), 138)
         self.assertEqual(p.object_index(commit_sha), 12)
         self.assertEqual(p.object_index(commit_sha), 12)
-  
+
     def test_index_len(self):
     def test_index_len(self):
         p = self.get_pack_index(pack1_sha)
         p = self.get_pack_index(pack1_sha)
         self.assertEquals(3, len(p))
         self.assertEquals(3, len(p))
-  
+
     def test_get_stored_checksum(self):
     def test_get_stored_checksum(self):
         p = self.get_pack_index(pack1_sha)
         p = self.get_pack_index(pack1_sha)
-        self.assertEquals("\xf2\x84\x8e*\xd1o2\x9a\xe1\xc9.;\x95\xe9\x18\x88\xda\xa5\xbd\x01", str(p.get_stored_checksum()))
-        self.assertEquals( 'r\x19\x80\xe8f\xaf\x9a_\x93\xadgAD\xe1E\x9b\x8b\xa3\xe7\xb7' , str(p.get_pack_checksum()))
-  
+        self.assertEquals('f2848e2ad16f329ae1c92e3b95e91888daa5bd01',
+                          sha_to_hex(p.get_stored_checksum()))
+        self.assertEquals('721980e866af9a5f93ad674144e1459b8ba3e7b7',
+                          sha_to_hex(p.get_pack_checksum()))
+
     def test_index_check(self):
     def test_index_check(self):
         p = self.get_pack_index(pack1_sha)
         p = self.get_pack_index(pack1_sha)
-        self.assertEquals(True, p.check())
-  
+        self.assertSucceeds(p.check)
+
     def test_iterentries(self):
     def test_iterentries(self):
         p = self.get_pack_index(pack1_sha)
         p = self.get_pack_index(pack1_sha)
-        self.assertEquals([('og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 178, None), ('\xb2\xa2vj(y\xc2\t\xab\x11v\xe7\xe7x\xb8\x1a\xe4"\xee\xaa', 138, None), ('\xf1\x8f\xaa\x16S\x1a\xc5p\xa3\xfd\xc8\xc7\xca\x16h%H\xda\xfd\x12', 12, None)], list(p.iterentries()))
-  
+        entries = [(sha_to_hex(s), o, c) for s, o, c in p.iterentries()]
+        self.assertEquals([
+          ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, None),
+          ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, None),
+          ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, None)
+          ], entries)
+
     def test_iter(self):
     def test_iter(self):
         p = self.get_pack_index(pack1_sha)
         p = self.get_pack_index(pack1_sha)
         self.assertEquals(set([tree_sha, commit_sha, a_sha]), set(p))
         self.assertEquals(set([tree_sha, commit_sha, a_sha]), set(p))
-  
+
 
 
 class TestPackDeltas(unittest.TestCase):
 class TestPackDeltas(unittest.TestCase):
-  
-    test_string1 = "The answer was flailing in the wind"
-    test_string2 = "The answer was falling down the pipe"
-    test_string3 = "zzzzz"
-  
-    test_string_empty = ""
-    test_string_big = "Z" * 8192
-  
+
+    test_string1 = 'The answer was flailing in the wind'
+    test_string2 = 'The answer was falling down the pipe'
+    test_string3 = 'zzzzz'
+
+    test_string_empty = ''
+    test_string_big = 'Z' * 8192
+
     def _test_roundtrip(self, base, target):
     def _test_roundtrip(self, base, target):
         self.assertEquals(target,
         self.assertEquals(target,
-            apply_delta(base, create_delta(base, target)))
-  
+                          ''.join(apply_delta(base, create_delta(base, target))))
+
     def test_nochange(self):
     def test_nochange(self):
         self._test_roundtrip(self.test_string1, self.test_string1)
         self._test_roundtrip(self.test_string1, self.test_string1)
-  
+
     def test_change(self):
     def test_change(self):
         self._test_roundtrip(self.test_string1, self.test_string2)
         self._test_roundtrip(self.test_string1, self.test_string2)
-  
+
     def test_rewrite(self):
     def test_rewrite(self):
         self._test_roundtrip(self.test_string1, self.test_string3)
         self._test_roundtrip(self.test_string1, self.test_string3)
-  
+
     def test_overflow(self):
     def test_overflow(self):
         self._test_roundtrip(self.test_string_empty, self.test_string_big)
         self._test_roundtrip(self.test_string_empty, self.test_string_big)
 
 
 
 
 class TestPackData(PackTests):
 class TestPackData(PackTests):
     """Tests getting the data from the packfile."""
     """Tests getting the data from the packfile."""
-  
+
     def test_create_pack(self):
     def test_create_pack(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
-  
+
     def test_pack_len(self):
     def test_pack_len(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
         self.assertEquals(3, len(p))
         self.assertEquals(3, len(p))
-  
+
     def test_index_check(self):
     def test_index_check(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
-        self.assertEquals(True, p.check())
-  
+        self.assertSucceeds(p.check)
+
     def test_iterobjects(self):
     def test_iterobjects(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
-        self.assertEquals([(12, 1, 'tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\nauthor James Westby <jw+debian@jameswestby.net> 1174945067 +0100\ncommitter James Westby <jw+debian@jameswestby.net> 1174945067 +0100\n\nTest commit\n', 3775879613L), (138, 2, '100644 a\x00og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 912998690L), (178, 3, 'test 1\n', 1373561701L)], list(p.iterobjects()))
-  
+        commit_data = ('tree b2a2766a2879c209ab1176e7e778b81ae422eeaa\n'
+                       'author James Westby <jw+debian@jameswestby.net> '
+                       '1174945067 +0100\n'
+                       'committer James Westby <jw+debian@jameswestby.net> '
+                       '1174945067 +0100\n'
+                       '\n'
+                       'Test commit\n')
+        blob_sha = '6f670c0fb53f9463760b7295fbb814e965fb20c8'
+        tree_data = '100644 a\0%s' % hex_to_sha(blob_sha)
+        actual = []
+        for offset, type_num, chunks, crc32 in p.iterobjects():
+            actual.append((offset, type_num, ''.join(chunks), crc32))
+        self.assertEquals([
+          (12, 1, commit_data, 3775879613L),
+          (138, 2, tree_data, 912998690L),
+          (178, 3, 'test 1\n', 1373561701L)
+          ], actual)
+
     def test_iterentries(self):
     def test_iterentries(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
-        self.assertEquals(set([('og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 178, 1373561701L), ('\xb2\xa2vj(y\xc2\t\xab\x11v\xe7\xe7x\xb8\x1a\xe4"\xee\xaa', 138, 912998690L), ('\xf1\x8f\xaa\x16S\x1a\xc5p\xa3\xfd\xc8\xc7\xca\x16h%H\xda\xfd\x12', 12, 3775879613L)]), set(p.iterentries()))
-  
+        entries = set((sha_to_hex(s), o, c) for s, o, c in p.iterentries())
+        self.assertEquals(set([
+          ('6f670c0fb53f9463760b7295fbb814e965fb20c8', 178, 1373561701L),
+          ('b2a2766a2879c209ab1176e7e778b81ae422eeaa', 138, 912998690L),
+          ('f18faa16531ac570a3fdc8c7ca16682548dafd12', 12, 3775879613L),
+          ]), entries)
+
     def test_create_index_v1(self):
     def test_create_index_v1(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
-        p.create_index_v1("v1test.idx")
-        idx1 = load_pack_index("v1test.idx")
+        filename = os.path.join(self.tempdir, 'v1test.idx')
+        p.create_index_v1(filename)
+        idx1 = load_pack_index(filename)
         idx2 = self.get_pack_index(pack1_sha)
         idx2 = self.get_pack_index(pack1_sha)
         self.assertEquals(idx1, idx2)
         self.assertEquals(idx1, idx2)
-  
+
     def test_create_index_v2(self):
     def test_create_index_v2(self):
         p = self.get_pack_data(pack1_sha)
         p = self.get_pack_data(pack1_sha)
-        p.create_index_v2("v2test.idx")
-        idx1 = load_pack_index("v2test.idx")
+        filename = os.path.join(self.tempdir, 'v2test.idx')
+        p.create_index_v2(filename)
+        idx1 = load_pack_index(filename)
         idx2 = self.get_pack_index(pack1_sha)
         idx2 = self.get_pack_index(pack1_sha)
         self.assertEquals(idx1, idx2)
         self.assertEquals(idx1, idx2)
 
 
@@ -183,36 +234,38 @@ class TestPack(PackTests):
         """Tests random access for non-delta objects"""
         """Tests random access for non-delta objects"""
         p = self.get_pack(pack1_sha)
         p = self.get_pack(pack1_sha)
         obj = p[a_sha]
         obj = p[a_sha]
-        self.assertEqual(obj._type, 'blob')
+        self.assertEqual(obj.type_name, 'blob')
         self.assertEqual(obj.sha().hexdigest(), a_sha)
         self.assertEqual(obj.sha().hexdigest(), a_sha)
         obj = p[tree_sha]
         obj = p[tree_sha]
-        self.assertEqual(obj._type, 'tree')
+        self.assertEqual(obj.type_name, 'tree')
         self.assertEqual(obj.sha().hexdigest(), tree_sha)
         self.assertEqual(obj.sha().hexdigest(), tree_sha)
         obj = p[commit_sha]
         obj = p[commit_sha]
-        self.assertEqual(obj._type, 'commit')
+        self.assertEqual(obj.type_name, 'commit')
         self.assertEqual(obj.sha().hexdigest(), commit_sha)
         self.assertEqual(obj.sha().hexdigest(), commit_sha)
 
 
     def test_copy(self):
     def test_copy(self):
         origpack = self.get_pack(pack1_sha)
         origpack = self.get_pack(pack1_sha)
-        self.assertEquals(True, origpack.index.check())
-        write_pack("Elch", [(x, "") for x in origpack.iterobjects()], 
-            len(origpack))
-        newpack = Pack("Elch")
+        self.assertSucceeds(origpack.index.check)
+        basename = os.path.join(self.tempdir, 'Elch')
+        write_pack(basename, [(x, '') for x in origpack.iterobjects()],
+                   len(origpack))
+        newpack = Pack(basename)
         self.assertEquals(origpack, newpack)
         self.assertEquals(origpack, newpack)
-        self.assertEquals(True, newpack.index.check())
+        self.assertSucceeds(newpack.index.check)
         self.assertEquals(origpack.name(), newpack.name())
         self.assertEquals(origpack.name(), newpack.name())
-        self.assertEquals(origpack.index.get_pack_checksum(), 
+        self.assertEquals(origpack.index.get_pack_checksum(),
                           newpack.index.get_pack_checksum())
                           newpack.index.get_pack_checksum())
-        
-        self.assertTrue(
-                (origpack.index.version != newpack.index.version) or
-                (origpack.index.get_stored_checksum() == newpack.index.get_stored_checksum()))
+
+        wrong_version = origpack.index.version != newpack.index.version
+        orig_checksum = origpack.index.get_stored_checksum()
+        new_checksum = newpack.index.get_stored_checksum()
+        self.assertTrue(wrong_version or orig_checksum == new_checksum)
 
 
     def test_commit_obj(self):
     def test_commit_obj(self):
         p = self.get_pack(pack1_sha)
         p = self.get_pack(pack1_sha)
         commit = p[commit_sha]
         commit = p[commit_sha]
-        self.assertEquals("James Westby <jw+debian@jameswestby.net>",
-            commit.author)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
+                          commit.author)
         self.assertEquals([], commit.parents)
         self.assertEquals([], commit.parents)
 
 
     def test_name(self):
     def test_name(self):
@@ -220,69 +273,142 @@ class TestPack(PackTests):
         self.assertEquals(pack1_sha, p.name())
         self.assertEquals(pack1_sha, p.name())
 
 
 
 
-class TestHexToSha(unittest.TestCase):
+pack_checksum = hex_to_sha('721980e866af9a5f93ad674144e1459b8ba3e7b7')
 
 
-    def test_simple(self):
-        self.assertEquals('\xab\xcd' * 10, hex_to_sha("abcd" * 10))
 
 
-    def test_reverse(self):
-        self.assertEquals("abcd" * 10, sha_to_hex('\xab\xcd' * 10))
+class BaseTestPackIndexWriting(object):
 
 
+    def setUp(self):
+        self.tempdir = tempfile.mkdtemp()
 
 
-class BaseTestPackIndexWriting(object):
+    def tearDown(self):
+        shutil.rmtree(self.tempdir)
+
+    def assertSucceeds(self, func, *args, **kwargs):
+        try:
+            func(*args, **kwargs)
+        except ChecksumMismatch, e:
+            self.fail(e)
 
 
     def test_empty(self):
     def test_empty(self):
-        pack_checksum = 'r\x19\x80\xe8f\xaf\x9a_\x93\xadgAD\xe1E\x9b\x8b\xa3\xe7\xb7'
-        self._write_fn("empty.idx", [], pack_checksum)
-        idx = load_pack_index("empty.idx")
-        self.assertTrue(idx.check())
+        filename = os.path.join(self.tempdir, 'empty.idx')
+        self._write_fn(filename, [], pack_checksum)
+        idx = load_pack_index(filename)
+        self.assertSucceeds(idx.check)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(0, len(idx))
         self.assertEquals(0, len(idx))
 
 
     def test_single(self):
     def test_single(self):
-        pack_checksum = 'r\x19\x80\xe8f\xaf\x9a_\x93\xadgAD\xe1E\x9b\x8b\xa3\xe7\xb7'
-        my_entries = [('og\x0c\x0f\xb5?\x94cv\x0br\x95\xfb\xb8\x14\xe9e\xfb \xc8', 178, 42)]
-        my_entries.sort()
-        self._write_fn("single.idx", my_entries, pack_checksum)
-        idx = load_pack_index("single.idx")
+        entry_sha = hex_to_sha('6f670c0fb53f9463760b7295fbb814e965fb20c8')
+        my_entries = [(entry_sha, 178, 42)]
+        filename = os.path.join(self.tempdir, 'single.idx')
+        self._write_fn(filename, my_entries, pack_checksum)
+        idx = load_pack_index(filename)
         self.assertEquals(idx.version, self._expected_version)
         self.assertEquals(idx.version, self._expected_version)
-        self.assertTrue(idx.check())
+        self.assertSucceeds(idx.check)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(idx.get_pack_checksum(), pack_checksum)
         self.assertEquals(1, len(idx))
         self.assertEquals(1, len(idx))
         actual_entries = list(idx.iterentries())
         actual_entries = list(idx.iterentries())
         self.assertEquals(len(my_entries), len(actual_entries))
         self.assertEquals(len(my_entries), len(actual_entries))
-        for a, b in zip(my_entries, actual_entries):
-            self.assertEquals(a[0], b[0])
-            self.assertEquals(a[1], b[1])
+        for mine, actual in zip(my_entries, actual_entries):
+            my_sha, my_offset, my_crc = mine
+            actual_sha, actual_offset, actual_crc = actual
+            self.assertEquals(my_sha, actual_sha)
+            self.assertEquals(my_offset, actual_offset)
             if self._has_crc32_checksum:
             if self._has_crc32_checksum:
-                self.assertEquals(a[2], b[2])
+                self.assertEquals(my_crc, actual_crc)
             else:
             else:
-                self.assertTrue(b[2] is None)
+                self.assertTrue(actual_crc is None)
 
 
 
 
 class TestPackIndexWritingv1(unittest.TestCase, BaseTestPackIndexWriting):
 class TestPackIndexWritingv1(unittest.TestCase, BaseTestPackIndexWriting):
 
 
     def setUp(self):
     def setUp(self):
         unittest.TestCase.setUp(self)
         unittest.TestCase.setUp(self)
+        BaseTestPackIndexWriting.setUp(self)
         self._has_crc32_checksum = False
         self._has_crc32_checksum = False
         self._expected_version = 1
         self._expected_version = 1
         self._write_fn = write_pack_index_v1
         self._write_fn = write_pack_index_v1
 
 
+    def tearDown(self):
+        unittest.TestCase.tearDown(self)
+        BaseTestPackIndexWriting.tearDown(self)
+
 
 
 class TestPackIndexWritingv2(unittest.TestCase, BaseTestPackIndexWriting):
 class TestPackIndexWritingv2(unittest.TestCase, BaseTestPackIndexWriting):
 
 
     def setUp(self):
     def setUp(self):
         unittest.TestCase.setUp(self)
         unittest.TestCase.setUp(self)
+        BaseTestPackIndexWriting.setUp(self)
         self._has_crc32_checksum = True
         self._has_crc32_checksum = True
         self._expected_version = 2
         self._expected_version = 2
         self._write_fn = write_pack_index_v2
         self._write_fn = write_pack_index_v2
 
 
-TEST_COMP1 = """\x78\x9c\x9d\x8e\xc1\x0a\xc2\x30\x10\x44\xef\xf9\x8a\xbd\xa9\x08\x92\x86\xb4\x26\x20\xe2\xd9\x83\x78\xf2\xbe\x49\x37\xb5\xa5\x69\xca\x36\xf5\xfb\x4d\xfd\x04\x67\x6e\x33\xcc\xf0\x32\x13\x81\xc6\x16\x8d\xa9\xbd\xad\x6c\xe3\x8a\x03\x4a\x73\xd6\xda\xd5\xa6\x51\x2e\x58\x65\x6c\x13\xbc\x94\x4a\xcc\xc8\x34\x65\x78\xa4\x89\x04\xae\xf9\x9d\x18\xee\x34\x46\x62\x78\x11\x4f\x29\xf5\x03\x5c\x86\x5f\x70\x5b\x30\x3a\x3c\x25\xee\xae\x50\xa9\xf2\x60\xa4\xaa\x34\x1c\x65\x91\xf0\x29\xc6\x3e\x67\xfa\x6f\x2d\x9e\x9c\x3e\x7d\x4b\xc0\x34\x8f\xe8\x29\x6e\x48\xa1\xa0\xc4\x88\xf3\xfe\xb0\x5b\x20\x85\xb0\x50\x06\xe4\x6e\xdd\xca\xd3\x17\x26\xfa\x49\x23"""
+    def tearDown(self):
+        unittest.TestCase.tearDown(self)
+        BaseTestPackIndexWriting.tearDown(self)
+
 
 
+class ReadZlibTests(unittest.TestCase):
 
 
-class ZlibTests(unittest.TestCase):
+    decomp = (
+      'tree 4ada885c9196b6b6fa08744b5862bf92896fc002\n'
+      'parent None\n'
+      'author Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n'
+      'committer Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n'
+      '\n'
+      "Provide replacement for mmap()'s offset argument.")
+    comp = zlib.compress(decomp)
+    extra = 'nextobject'
+
+    def setUp(self):
+        self.read = StringIO(self.comp + self.extra).read
+
+    def test_decompress_size(self):
+        good_decomp_len = len(self.decomp)
+        self.assertRaises(ValueError, read_zlib_chunks, self.read, -1)
+        self.assertRaises(zlib.error, read_zlib_chunks, self.read,
+                          good_decomp_len - 1)
+        self.assertRaises(zlib.error, read_zlib_chunks, self.read,
+                          good_decomp_len + 1)
+
+    def test_decompress_truncated(self):
+        read = StringIO(self.comp[:10]).read
+        self.assertRaises(zlib.error, read_zlib_chunks, read, len(self.decomp))
+
+        read = StringIO(self.comp).read
+        self.assertRaises(zlib.error, read_zlib_chunks, read, len(self.decomp))
+
+    def test_decompress_empty(self):
+        comp = zlib.compress('')
+        read = StringIO(comp + self.extra).read
+        decomp, comp_len, unused_data = read_zlib_chunks(read, 0)
+        self.assertEqual('', ''.join(decomp))
+        self.assertEqual(len(comp), comp_len)
+        self.assertNotEquals('', unused_data)
+        self.assertEquals(self.extra, unused_data + read())
+
+    def _do_decompress_test(self, buffer_size):
+        decomp, comp_len, unused_data = read_zlib_chunks(
+          self.read, len(self.decomp), buffer_size=buffer_size)
+        self.assertEquals(self.decomp, ''.join(decomp))
+        self.assertEquals(len(self.comp), comp_len)
+        self.assertNotEquals('', unused_data)
+        self.assertEquals(self.extra, unused_data + self.read())
 
 
     def test_simple_decompress(self):
     def test_simple_decompress(self):
-        self.assertEquals(("tree 4ada885c9196b6b6fa08744b5862bf92896fc002\nparent None\nauthor Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\ncommitter Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n\nProvide replacement for mmap()'s offset argument.", 158, 'Z'), 
-        read_zlib(StringIO(TEST_COMP1).read, 229))
+        self._do_decompress_test(4096)
+
+    # These buffer sizes are not intended to be realistic, but rather simulate
+    # larger buffer sizes that may end at various places.
+    def test_decompress_buffer_size_1(self):
+        self._do_decompress_test(1)
+
+    def test_decompress_buffer_size_2(self):
+        self._do_decompress_test(2)
+
+    def test_decompress_buffer_size_3(self):
+        self._do_decompress_test(3)
 
 
+    def test_decompress_buffer_size_4(self):
+        self._do_decompress_test(4)
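
The ReadZlibTests added above pin down read_zlib_chunks(): it pulls exactly one zlib stream out of a read callable and returns the decompressed chunks, the number of compressed bytes consumed, and any bytes that were read past the end of the stream. A minimal sketch of that contract, assuming the function is imported from dulwich.pack (the module these tests target) and behaves as the assertions above describe:

from cStringIO import StringIO
import zlib

from dulwich.pack import read_zlib_chunks

payload = 'example object body'
stream = StringIO(zlib.compress(payload) + 'trailing pack data')

# The expected decompressed size must be known up front; a truncated stream
# or a wrong size raises zlib.error.
chunks, comp_len, unused = read_zlib_chunks(stream.read, len(payload))
assert ''.join(chunks) == payload
# comp_len counts only the compressed bytes; whatever was read beyond them
# comes back in `unused` so the caller can keep parsing the enclosing pack.
assert unused + stream.read() == 'trailing pack data'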

+ 88 - 0
dulwich/tests/test_patch.py

@@ -0,0 +1,88 @@
+# test_patch.py -- tests for patch.py
+# Copyright (C) 2010 Jelmer Vernooij <jelmer@samba.org>
+# 
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) a later version.
+# 
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+# 
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Tests for patch.py."""
+
+from cStringIO import StringIO
+from unittest import TestCase
+
+from dulwich.objects import (
+    Commit,
+    Tree,
+    )
+from dulwich.patch import (
+    git_am_patch_split,
+    write_commit_patch,
+    )
+
+
+class WriteCommitPatchTests(TestCase):
+
+    def test_simple(self):
+        f = StringIO()
+        c = Commit()
+        c.committer = c.author = "Jelmer <jelmer@samba.org>"
+        c.commit_time = c.author_time = 1271350201
+        c.commit_timezone = c.author_timezone = 0
+        c.message = "This is the first line\nAnd this is the second line.\n"
+        c.tree = Tree().id
+        write_commit_patch(f, c, "CONTENTS", (1, 1), version="custom")
+        f.seek(0)
+        lines = f.readlines()
+        self.assertTrue(lines[0].startswith("From 0b0d34d1b5b596c928adc9a727a4b9e03d025298"))
+        self.assertEquals(lines[1], "From: Jelmer <jelmer@samba.org>\n")
+        self.assertTrue(lines[2].startswith("Date: "))
+        self.assertEquals([
+            "Subject: [PATCH 1/1] This is the first line\n",
+            "And this is the second line.\n",
+            "\n",
+            "\n",
+            "---\n"], lines[3:8])
+        self.assertEquals([
+            "CONTENTS-- \n",
+            "custom\n"], lines[-2:])
+        if len(lines) >= 12:
+            # diffstat may not be present
+            self.assertEquals(lines[8], " 0 files changed\n")
+
+
+class ReadGitAmPatch(TestCase):
+
+    def test_extract(self):
+        text = """From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
+From: Jelmer Vernooij <jelmer@samba.org>
+Date: Thu, 15 Apr 2010 15:40:28 +0200
+Subject: [PATCH 1/2] Remove executable bit from prey.ico (triggers a lintian warning).
+
+---
+ pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+ 1 files changed, 0 insertions(+), 0 deletions(-)
+ mode change 100755 => 100644 pixmaps/prey.ico
+
+-- 
+1.7.0.4
+"""
+        c, diff, version = git_am_patch_split(StringIO(text))
+        self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.committer)
+        self.assertEquals("Jelmer Vernooij <jelmer@samba.org>", c.author)
+        self.assertEquals(""" pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
+ 1 files changed, 0 insertions(+), 0 deletions(-)
+ mode change 100755 => 100644 pixmaps/prey.ico
+
+""", diff)
+        self.assertEquals("1.7.0.4", version)
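
The new test_patch.py covers both directions of dulwich.patch: write_commit_patch() renders a Commit as a git-am style message, and git_am_patch_split() splits such a message back into the commit metadata, the raw diff text, and the git version that produced it. A short sketch of the parsing side, reusing the same message layout as the test above (the message content itself is only illustrative):

from cStringIO import StringIO

from dulwich.patch import git_am_patch_split

# A message in `git format-patch` layout: headers, a '---' separator, the
# diff body, then a '-- ' line (note the trailing space) and the git version.
text = """From ff643aae102d8870cac88e8f007e70f58f3a7363 Mon Sep 17 00:00:00 2001
From: Jelmer Vernooij <jelmer@samba.org>
Date: Thu, 15 Apr 2010 15:40:28 +0200
Subject: [PATCH 1/2] Remove executable bit from prey.ico.

---
 pixmaps/prey.ico |  Bin 9662 -> 9662 bytes
 1 files changed, 0 insertions(+), 0 deletions(-)

-- 
1.7.0.4
"""
commit, diff, version = git_am_patch_split(StringIO(text))
print commit.author   # Jelmer Vernooij <jelmer@samba.org>
print version         # 1.7.0.4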

+ 90 - 7
dulwich/tests/test_protocol.py

@@ -20,11 +20,12 @@
 """Tests for the smart protocol utility functions."""
 """Tests for the smart protocol utility functions."""
 
 
 
 
-from cStringIO import StringIO
+from StringIO import StringIO
 from unittest import TestCase
 from unittest import TestCase
 
 
 from dulwich.protocol import (
 from dulwich.protocol import (
     Protocol,
     Protocol,
+    ReceivableProtocol,
     extract_capabilities,
     extract_capabilities,
     extract_want_line_capabilities,
     extract_want_line_capabilities,
     ack_type,
     ack_type,
@@ -33,12 +34,7 @@ from dulwich.protocol import (
     MULTI_ACK_DETAILED,
     MULTI_ACK_DETAILED,
     )
     )
 
 
-class ProtocolTests(TestCase):
-
-    def setUp(self):
-        self.rout = StringIO()
-        self.rin = StringIO()
-        self.proto = Protocol(self.rin.read, self.rout.write)
+class BaseProtocolTests(object):
 
 
     def test_write_pkt_line_none(self):
     def test_write_pkt_line_none(self):
         self.proto.write_pkt_line(None)
         self.proto.write_pkt_line(None)
@@ -82,6 +78,93 @@ class ProtocolTests(TestCase):
         self.assertRaises(AssertionError, self.proto.read_cmd)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
 
 
+class ProtocolTests(BaseProtocolTests, TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.rout = StringIO()
+        self.rin = StringIO()
+        self.proto = Protocol(self.rin.read, self.rout.write)
+
+
+class ReceivableStringIO(StringIO):
+    """StringIO with socket-like recv semantics for testing."""
+
+    def recv(self, size):
+        # fail fast if no bytes are available; in a real socket, this would
+        # block forever
+        if self.tell() == len(self.getvalue()):
+            raise AssertionError("Blocking read past end of socket")
+        if size == 1:
+            return self.read(1)
+        # calls shouldn't return quite as much as asked for
+        return self.read(size - 1)
+
+
+class ReceivableProtocolTests(BaseProtocolTests, TestCase):
+
+    def setUp(self):
+        TestCase.setUp(self)
+        self.rout = StringIO()
+        self.rin = ReceivableStringIO()
+        self.proto = ReceivableProtocol(self.rin.recv, self.rout.write)
+        self.proto._rbufsize = 8
+
+    def test_recv(self):
+        all_data = "1234567" * 10  # not a multiple of bufsize
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        data = ""
+        # We ask for 8 bytes each time and actually read 7, so it should take
+        # exactly 10 iterations.
+        for _ in xrange(10):
+            data += self.proto.recv(10)
+        # any more reads would block
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+        self.assertEquals(all_data, data)
+
+    def test_recv_read(self):
+        all_data = "1234567"  # recv exactly in one call
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        self.assertEquals("1234", self.proto.recv(4))
+        self.assertEquals("567", self.proto.read(3))
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+
+    def test_read_recv(self):
+        all_data = "12345678abcdefg"
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        self.assertEquals("1234", self.proto.read(4))
+        self.assertEquals("5678abc", self.proto.recv(8))
+        self.assertEquals("defg", self.proto.read(4))
+        self.assertRaises(AssertionError, self.proto.recv, 10)
+
+    def test_mixed(self):
+        # arbitrary non-repeating string
+        all_data = ",".join(str(i) for i in xrange(100))
+        self.rin.write(all_data)
+        self.rin.seek(0)
+        data = ""
+
+        for i in xrange(1, 100):
+            data += self.proto.recv(i)
+            # if we get to the end, do a non-blocking read instead of blocking
+            if len(data) + i > len(all_data):
+                data += self.proto.recv(i)
+                # ReceivableStringIO leaves off the last byte unless we ask
+                # nicely
+                data += self.proto.recv(1)
+                break
+            else:
+                data += self.proto.read(i)
+        else:
+            # didn't break, something must have gone wrong
+            self.fail()
+
+        self.assertEquals(all_data, data)
+
+
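
Both Protocol and ReceivableProtocol frame traffic as git pkt-lines: a four-digit hex length (which counts the four length bytes themselves) followed by the payload, with '0000' acting as a flush packet; ReceivableProtocol differs only in reading from a socket-style recv() with internal buffering, which is what the tests above exercise. A short framing sketch using StringIO in place of a socket (the byte strings in the comments are the expected pkt-line encoding, stated as an assumption rather than copied from the tests):

from StringIO import StringIO

from dulwich.protocol import Protocol

# Protocol only needs a read and a write callable, so StringIO can stand in
# for the transport.
out = StringIO()
proto = Protocol(StringIO().read, out.write)
proto.write_pkt_line('hello\n')   # framed as '000ahello\n'
proto.write_pkt_line(None)        # flush-pkt, written as '0000'

# Reading back: read_pkt_line() returns one payload, or None for a flush-pkt.
reader = Protocol(StringIO(out.getvalue()).read, StringIO().write)
print repr(reader.read_pkt_line())   # 'hello\n'
print reader.read_pkt_line()         # None
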
 class CapabilitiesTestCase(TestCase):
 class CapabilitiesTestCase(TestCase):
 
 
     def test_plain(self):
     def test_plain(self):

+ 477 - 141
dulwich/tests/test_repository.py

@@ -1,17 +1,17 @@
 # test_repository.py -- tests for repository.py
 # test_repository.py -- tests for repository.py
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
-# 
+#
 # This program is free software; you can redistribute it and/or
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
 # as published by the Free Software Foundation; version 2
-# of the License or (at your option) any later version of 
+# of the License or (at your option) any later version of
 # the License.
 # the License.
-# 
+#
 # This program is distributed in the hope that it will be useful,
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
 # GNU General Public License for more details.
-# 
+#
 # You should have received a copy of the GNU General Public License
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
@@ -25,44 +25,30 @@ import os
 import shutil
 import shutil
 import tempfile
 import tempfile
 import unittest
 import unittest
+import warnings
 
 
 from dulwich import errors
 from dulwich import errors
+from dulwich.object_store import (
+    tree_lookup_path,
+    )
+from dulwich import objects
 from dulwich.repo import (
 from dulwich.repo import (
     check_ref_format,
     check_ref_format,
+    DictRefsContainer,
     Repo,
     Repo,
     read_packed_refs,
     read_packed_refs,
     read_packed_refs_with_peeled,
     read_packed_refs_with_peeled,
     write_packed_refs,
     write_packed_refs,
     _split_ref_line,
     _split_ref_line,
     )
     )
+from dulwich.tests.utils import (
+    open_repo,
+    tear_down_repo,
+    )
 
 
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 
 
 
 
-def open_repo(name):
-    """Open a copy of a repo in a temporary directory.
-
-    Use this function for accessing repos in dulwich/tests/data/repos to avoid
-    accidentally or intentionally modifying those repos in place. Use
-    tear_down_repo to delete any temp files created.
-
-    :param name: The name of the repository, relative to
-        dulwich/tests/data/repos
-    :returns: An initialized Repo object that lives in a temporary directory.
-    """
-    temp_dir = tempfile.mkdtemp()
-    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
-    temp_repo_dir = os.path.join(temp_dir, name)
-    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
-    return Repo(temp_repo_dir)
-
-def tear_down_repo(repo):
-    """Tear down a test repository."""
-    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
-    shutil.rmtree(temp_dir)
-
-
-
 class CreateRepositoryTests(unittest.TestCase):
 class CreateRepositoryTests(unittest.TestCase):
 
 
     def test_create(self):
     def test_create(self):
@@ -86,73 +72,158 @@ class RepositoryTests(unittest.TestCase):
     def test_simple_props(self):
     def test_simple_props(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.controldir(), r.path)
         self.assertEqual(r.controldir(), r.path)
-  
+
     def test_ref(self):
     def test_ref(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.ref('refs/heads/master'),
         self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
-  
+
+    def test_setitem(self):
+        r = self._repo = open_repo('a.git')
+        r["refs/tags/foo"] = 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
+        self.assertEquals('a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+                          r["refs/tags/foo"].id)
+
     def test_get_refs(self):
     def test_get_refs(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         self.assertEqual({
         self.assertEqual({
-            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
-            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
+            'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097',
+            'refs/tags/mytag': '28237f4dc30d0d462658d6b937b08a0f0b6ef55a',
+            'refs/tags/mytag-packed': 'b0931cadc54336e78a1d980420e3268903b57a50',
             }, r.get_refs())
             }, r.get_refs())
-  
+
     def test_head(self):
     def test_head(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
-  
+
     def test_get_object(self):
     def test_get_object(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
         obj = r.get_object(r.head())
-        self.assertEqual(obj._type, 'commit')
-  
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_get_object_non_existant(self):
     def test_get_object_non_existant(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
         self.assertRaises(KeyError, r.get_object, missing_sha)
-  
+
+    def test_contains_object(self):
+        r = self._repo = open_repo('a.git')
+        self.assertTrue(r.head() in r)
+
+    def test_contains_ref(self):
+        r = self._repo = open_repo('a.git')
+        self.assertTrue("HEAD" in r)
+
+    def test_contains_missing(self):
+        r = self._repo = open_repo('a.git')
+        self.assertFalse("bar" in r)
+
     def test_commit(self):
     def test_commit(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
-        obj = r.commit(r.head())
-        self.assertEqual(obj._type, 'commit')
-  
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            obj = r.commit(r.head())
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(obj.type_name, 'commit')
+
     def test_commit_not_commit(self):
     def test_commit_not_commit(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
-        self.assertRaises(errors.NotCommitError,
-                          r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
-  
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotCommitError,
+                r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
+        finally:
+            warnings.resetwarnings()
+
     def test_tree(self):
     def test_tree(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
-        commit = r.commit(r.head())
-        tree = r.tree(commit.tree)
-        self.assertEqual(tree._type, 'tree')
+        commit = r[r.head()]
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            tree = r.tree(commit.tree)
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(tree.type_name, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
-  
+
     def test_tree_not_tree(self):
     def test_tree_not_tree(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
-        self.assertRaises(errors.NotTreeError, r.tree, r.head())
-  
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotTreeError, r.tree, r.head())
+        finally:
+            warnings.resetwarnings()
+
+    def test_tag(self):
+        r = self._repo = open_repo('a.git')
+        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            tag = r.tag(tag_sha)
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(tag.type_name, 'tag')
+        self.assertEqual(tag.sha().hexdigest(), tag_sha)
+        obj_class, obj_sha = tag.object
+        self.assertEqual(obj_class, objects.Commit)
+        self.assertEqual(obj_sha, r.head())
+
+    def test_tag_not_tag(self):
+        r = self._repo = open_repo('a.git')
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotTagError, r.tag, r.head())
+        finally:
+            warnings.resetwarnings()
+
+    def test_get_peeled(self):
+        # unpacked ref
+        r = self._repo = open_repo('a.git')
+        tag_sha = '28237f4dc30d0d462658d6b937b08a0f0b6ef55a'
+        self.assertNotEqual(r[tag_sha].sha().hexdigest(), r.head())
+        self.assertEqual(r.get_peeled('refs/tags/mytag'), r.head())
+
+        # packed ref with cached peeled value
+        packed_tag_sha = 'b0931cadc54336e78a1d980420e3268903b57a50'
+        parent_sha = r[r.head()].parents[0]
+        self.assertNotEqual(r[packed_tag_sha].sha().hexdigest(), parent_sha)
+        self.assertEqual(r.get_peeled('refs/tags/mytag-packed'), parent_sha)
+
+        # TODO: add more corner cases to test repo
+
+    def test_get_peeled_not_tag(self):
+        r = self._repo = open_repo('a.git')
+        self.assertEqual(r.get_peeled('HEAD'), r.head())
+
     def test_get_blob(self):
     def test_get_blob(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
-        commit = r.commit(r.head())
-        tree = r.tree(commit.tree)
+        commit = r[r.head()]
+        tree = r[commit.tree]
         blob_sha = tree.entries()[0][2]
         blob_sha = tree.entries()[0][2]
-        blob = r.get_blob(blob_sha)
-        self.assertEqual(blob._type, 'blob')
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            blob = r.get_blob(blob_sha)
+        finally:
+            warnings.resetwarnings()
+        self.assertEqual(blob.type_name, 'blob')
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
-  
+
     def test_get_blob_notblob(self):
     def test_get_blob_notblob(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
-        self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
-    
+        warnings.simplefilter("ignore", DeprecationWarning)
+        try:
+            self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
+        finally:
+            warnings.resetwarnings()
+
     def test_linear_history(self):
     def test_linear_history(self):
         r = self._repo = open_repo('a.git')
         r = self._repo = open_repo('a.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
-  
+
     def test_merge_history(self):
     def test_merge_history(self):
         r = self._repo = open_repo('simple_merge.git')
         r = self._repo = open_repo('simple_merge.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
@@ -162,12 +233,12 @@ class RepositoryTests(unittest.TestCase):
                                 '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                 '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6',
                                 '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                 '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
-  
+
     def test_revision_history_missing_commit(self):
     def test_revision_history_missing_commit(self):
         r = self._repo = open_repo('simple_merge.git')
         r = self._repo = open_repo('simple_merge.git')
         self.assertRaises(errors.MissingCommitError, r.revision_history,
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
                           missing_sha)
-  
+
     def test_out_of_order_merge(self):
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
         """Test that revision history is ordered by date, not parent order."""
         r = self._repo = open_repo('ooo_merge.git')
         r = self._repo = open_repo('ooo_merge.git')
@@ -177,7 +248,7 @@ class RepositoryTests(unittest.TestCase):
                                 'f507291b64138b875c28e03469025b1ea20bc614',
                                 'f507291b64138b875c28e03469025b1ea20bc614',
                                 'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                 'fb5b0425c7ce46959bec94d54b9a157645e114f5',
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
-  
+
     def test_get_tags_empty(self):
     def test_get_tags_empty(self):
         r = self._repo = open_repo('ooo_merge.git')
         r = self._repo = open_repo('ooo_merge.git')
         self.assertEqual({}, r.refs.as_dict('refs/tags'))
         self.assertEqual({}, r.refs.as_dict('refs/tags'))
@@ -186,6 +257,158 @@ class RepositoryTests(unittest.TestCase):
         r = self._repo = open_repo('ooo_merge.git')
         r = self._repo = open_repo('ooo_merge.git')
         self.assertEquals({}, r.get_config())
         self.assertEquals({}, r.get_config())
 
 
+    def test_common_revisions(self):
+        """
+        This test demonstrates that ``find_common_revisions()`` actually returns
+        common heads, not revisions; dulwich already uses
+        ``find_common_revisions()`` in such a manner (see
+        ``Repo.fetch_objects()``).
+        """
+
+        expected_shas = set(['60dacdc733de308bb77bb76ce0fb0f9b44c9769e'])
+
+        # Source for objects.
+        r_base = open_repo('simple_merge.git')
+
+        # Re-create each-side of the merge in simple_merge.git.
+        #
+        # Since the trees and blobs are missing, the repository created is
+        # corrupted, but we're only checking for commits for the purpose of this
+        # test, so it's immaterial.
+        r1_dir = tempfile.mkdtemp()
+        r1_commits = ['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd', # HEAD
+                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
+                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
+
+        r2_dir = tempfile.mkdtemp()
+        r2_commits = ['4cffe90e0a41ad3f5190079d7c8f036bde29cbe6', # HEAD
+                      '60dacdc733de308bb77bb76ce0fb0f9b44c9769e',
+                      '0d89f20333fbb1d2f3a94da77f4981373d8f4310']
+
+        try:
+            r1 = Repo.init_bare(r1_dir)
+            map(lambda c: r1.object_store.add_object(r_base.get_object(c)), \
+                r1_commits)
+            r1.refs['HEAD'] = r1_commits[0]
+
+            r2 = Repo.init_bare(r2_dir)
+            map(lambda c: r2.object_store.add_object(r_base.get_object(c)), \
+                r2_commits)
+            r2.refs['HEAD'] = r2_commits[0]
+
+            # Finally, the 'real' testing!
+            shas = r2.object_store.find_common_revisions(r1.get_graph_walker())
+            self.assertEqual(set(shas), expected_shas)
+
+            shas = r1.object_store.find_common_revisions(r2.get_graph_walker())
+            self.assertEqual(set(shas), expected_shas)
+        finally:
+            shutil.rmtree(r1_dir)
+            shutil.rmtree(r2_dir)
+
+
+class BuildRepoTests(unittest.TestCase):
+    """Tests that build on-disk repos from scratch.
+
+    Repos live in a temp dir and are torn down after each test. They start with
+    a single commit in master having single file named 'a'.
+    """
+
+    def setUp(self):
+        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
+        os.makedirs(repo_dir)
+        r = self._repo = Repo.init(repo_dir)
+        self.assertFalse(r.bare)
+        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
+        self.assertRaises(KeyError, lambda: r.refs['refs/heads/master'])
+
+        f = open(os.path.join(r.path, 'a'), 'wb')
+        try:
+            f.write('file contents')
+        finally:
+            f.close()
+        r.stage(['a'])
+        commit_sha = r.do_commit('msg',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12345, commit_timezone=0,
+                                 author_timestamp=12345, author_timezone=0)
+        self.assertEqual([], r[commit_sha].parents)
+        self._root_commit = commit_sha
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_build_repo(self):
+        r = self._repo
+        self.assertEqual('ref: refs/heads/master', r.refs.read_ref('HEAD'))
+        self.assertEqual(self._root_commit, r.refs['refs/heads/master'])
+        expected_blob = objects.Blob.from_string('file contents')
+        self.assertEqual(expected_blob.data, r[expected_blob.id].data)
+        actual_commit = r[self._root_commit]
+        self.assertEqual('msg', actual_commit.message)
+
+    def test_commit_modified(self):
+        r = self._repo
+        f = open(os.path.join(r.path, 'a'), 'wb')
+        try:
+            f.write('new contents')
+        finally:
+            f.close()
+        r.stage(['a'])
+        commit_sha = r.do_commit('modified a',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12395, commit_timezone=0,
+                                 author_timestamp=12395, author_timezone=0)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
+        _, blob_id = tree_lookup_path(r.get_object, r[commit_sha].tree, 'a')
+        self.assertEqual('new contents', r[blob_id].data)
+
+    def test_commit_deleted(self):
+        r = self._repo
+        os.remove(os.path.join(r.path, 'a'))
+        r.stage(['a'])
+        commit_sha = r.do_commit('deleted a',
+                                 committer='Test Committer <test@nodomain.com>',
+                                 author='Test Author <test@nodomain.com>',
+                                 commit_timestamp=12395, commit_timezone=0,
+                                 author_timestamp=12395, author_timezone=0)
+        self.assertEqual([self._root_commit], r[commit_sha].parents)
+        self.assertEqual([], list(r.open_index()))
+        tree = r[r[commit_sha].tree]
+        self.assertEqual([], tree.iteritems())
+
+    def test_commit_fail_ref(self):
+        r = self._repo
+
+        def set_if_equals(name, old_ref, new_ref):
+            return False
+        r.refs.set_if_equals = set_if_equals
+
+        def add_if_new(name, new_ref):
+            self.fail('Unexpected call to add_if_new')
+        r.refs.add_if_new = add_if_new
+
+        old_shas = set(r.object_store)
+        self.assertRaises(errors.CommitError, r.do_commit, 'failed commit',
+                          committer='Test Committer <test@nodomain.com>',
+                          author='Test Author <test@nodomain.com>',
+                          commit_timestamp=12345, commit_timezone=0,
+                          author_timestamp=12345, author_timezone=0)
+        new_shas = set(r.object_store) - old_shas
+        self.assertEqual(1, len(new_shas))
+        # Check that the new commit (now garbage) was added.
+        new_commit = r[new_shas.pop()]
+        self.assertEqual(r[self._root_commit].tree, new_commit.tree)
+        self.assertEqual('failed commit', new_commit.message)
+
+    def test_stage_deleted(self):
+        r = self._repo
+        os.remove(os.path.join(r.path, 'a'))
+        r.stage(['a'])
+        r.stage(['a'])  # double-stage a deleted path
+
 
 
 class CheckRefFormatTests(unittest.TestCase):
 class CheckRefFormatTests(unittest.TestCase):
     """Tests for the check_ref_format function.
     """Tests for the check_ref_format function.
@@ -239,12 +462,12 @@ class PackedRefsFileTests(unittest.TestCase):
 
 
     def test_read_with_peeled(self):
     def test_read_with_peeled(self):
         f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
         f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
-            ONES, TWOS, THREES, FOURS))
+          ONES, TWOS, THREES, FOURS))
         self.assertEqual([
         self.assertEqual([
-            (ONES, 'ref/1', None),
-            (TWOS, 'ref/2', THREES),
-            (FOURS, 'ref/4', None),
-            ], list(read_packed_refs_with_peeled(f)))
+          (ONES, 'ref/1', None),
+          (TWOS, 'ref/2', THREES),
+          (FOURS, 'ref/4', None),
+          ], list(read_packed_refs_with_peeled(f)))
 
 
     def test_read_with_peeled_errors(self):
     def test_read_with_peeled_errors(self):
         f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
         f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
@@ -258,8 +481,8 @@ class PackedRefsFileTests(unittest.TestCase):
         write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
         write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
                           {'ref/1': THREES})
                           {'ref/1': THREES})
         self.assertEqual(
         self.assertEqual(
-            "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
-            ONES, THREES, TWOS), f.getvalue())
+          "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
+          ONES, THREES, TWOS), f.getvalue())
 
 
     def test_write_without_peeled(self):
     def test_write_without_peeled(self):
         f = StringIO()
         f = StringIO()
@@ -267,62 +490,39 @@ class PackedRefsFileTests(unittest.TestCase):
         self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
         self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
 
 
 
 
-class RefsContainerTests(unittest.TestCase):
+# Dict of refs that we expect all RefsContainerTests subclasses to define.
+_TEST_REFS = {
+  'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+  'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+  'refs/tags/refs-0.2': '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
+  }
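
DictRefsContainer, added to the imports above, is the purely in-memory refs container that the shared RefsContainerTests below run against. A brief sketch of the container API those tests rely on (the SHAs here are placeholder strings, not real objects):

from dulwich.repo import DictRefsContainer

refs = DictRefsContainer({
    'HEAD': '4' * 40,
    'refs/heads/master': '4' * 40,
    'refs/tags/v0.1': 'd' * 40,
    })

refs['refs/heads/feature'] = '9' * 40     # create or update a ref
print sorted(refs.keys('refs/heads'))     # ['feature', 'master']
print 'refs/tags/v0.1' in refs            # True
print refs.as_dict()['refs/heads/feature']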
 
 
-    def setUp(self):
-        self._repo = open_repo('refs.git')
-        self._refs = self._repo.refs
 
 
-    def tearDown(self):
-        tear_down_repo(self._repo)
-
-    def test_get_packed_refs(self):
-        self.assertEqual(
-            {'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c'},
-            self._refs.get_packed_refs())
+class RefsContainerTests(object):
 
 
     def test_keys(self):
     def test_keys(self):
-        self.assertEqual([
-            'HEAD',
-            'refs/heads/loop',
-            'refs/heads/master',
-            'refs/tags/refs-0.1',
-            ], sorted(list(self._refs.keys())))
-        self.assertEqual(['loop', 'master'],
-                         sorted(self._refs.keys('refs/heads')))
-        self.assertEqual(['refs-0.1'], list(self._refs.keys('refs/tags')))
+        actual_keys = set(self._refs.keys())
+        self.assertEqual(set(self._refs.allkeys()), actual_keys)
+        # ignore the symref loop if it exists
+        actual_keys.discard('refs/heads/loop')
+        self.assertEqual(set(_TEST_REFS.iterkeys()), actual_keys)
+
+        actual_keys = self._refs.keys('refs/heads')
+        actual_keys.discard('loop')
+        self.assertEqual(['master', 'packed'], sorted(actual_keys))
+        self.assertEqual(['refs-0.1', 'refs-0.2'],
+                         sorted(self._refs.keys('refs/tags')))
 
 
     def test_as_dict(self):
     def test_as_dict(self):
-        # refs/heads/loop does not show up
-        self.assertEqual({
-            'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
-            'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
-            'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
-            }, self._refs.as_dict())
+        # refs/heads/loop does not show up even if it exists
+        self.assertEqual(_TEST_REFS, self._refs.as_dict())
 
 
     def test_setitem(self):
     def test_setitem(self):
         self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
         self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
                          self._refs['refs/some/ref'])
                          self._refs['refs/some/ref'])
-        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                          f.read()[:40])
-        f.close()
-
-    def test_setitem_symbolic(self):
-        ones = '1' * 40
-        self._refs['HEAD'] = ones
-        self.assertEqual(ones, self._refs['HEAD'])
-
-        # ensure HEAD was not modified
-        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
-        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
-        f.close()
-
-        # ensure the symbolic link was written through
-        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
-        self.assertEqual(ones, f.read()[:40])
-        f.close()
 
 
     def test_set_if_equals(self):
     def test_set_if_equals(self):
         nines = '9' * 40
         nines = '9' * 40
@@ -331,17 +531,13 @@ class RefsContainerTests(unittest.TestCase):
                          self._refs['HEAD'])
                          self._refs['HEAD'])
 
 
         self.assertTrue(self._refs.set_if_equals(
         self.assertTrue(self._refs.set_if_equals(
-            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
+          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
         self.assertEqual(nines, self._refs['HEAD'])
         self.assertEqual(nines, self._refs['HEAD'])
 
 
-        # ensure symref was followed
+        self.assertTrue(self._refs.set_if_equals('refs/heads/master', None,
+                                                 nines))
         self.assertEqual(nines, self._refs['refs/heads/master'])
         self.assertEqual(nines, self._refs['refs/heads/master'])
 
 
-        self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
-        self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'HEAD.lock')))
-
     def test_add_if_new(self):
     def test_add_if_new(self):
         nines = '9' * 40
         nines = '9' * 40
         self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
         self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
@@ -351,10 +547,23 @@ class RefsContainerTests(unittest.TestCase):
         self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
         self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
         self.assertEqual(nines, self._refs['refs/some/ref'])
         self.assertEqual(nines, self._refs['refs/some/ref'])
 
 
-        # don't overwrite packed ref
-        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
-        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
-                         self._refs['refs/tags/refs-0.1'])
+    def test_set_symbolic_ref(self):
+        self._refs.set_symbolic_ref('refs/heads/symbolic', 'refs/heads/master')
+        self.assertEqual('ref: refs/heads/master',
+                         self._refs.read_loose_ref('refs/heads/symbolic'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/symbolic'])
+
+    def test_set_symbolic_ref_overwrite(self):
+        nines = '9' * 40
+        self.assertFalse('refs/heads/symbolic' in self._refs)
+        self._refs['refs/heads/symbolic'] = nines
+        self.assertEqual(nines, self._refs.read_loose_ref('refs/heads/symbolic'))
+        self._refs.set_symbolic_ref('refs/heads/symbolic', 'refs/heads/master')
+        self.assertEqual('ref: refs/heads/master',
+                         self._refs.read_loose_ref('refs/heads/symbolic'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/symbolic'])
 
 
     def test_check_refname(self):
     def test_check_refname(self):
         try:
         try:
@@ -370,21 +579,131 @@ class RefsContainerTests(unittest.TestCase):
         self.assertRaises(KeyError, self._refs._check_refname, 'refs')
         self.assertRaises(KeyError, self._refs._check_refname, 'refs')
         self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
         self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
 
 
+    def test_contains(self):
+        self.assertTrue('refs/heads/master' in self._refs)
+        self.assertFalse('refs/heads/bar' in self._refs)
+
+    def test_delitem(self):
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          self._refs['refs/heads/master'])
+        del self._refs['refs/heads/master']
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+
+    def test_remove_if_equals(self):
+        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+        self.assertTrue(self._refs.remove_if_equals(
+          'refs/tags/refs-0.2', '3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
+        self.assertFalse('refs/tags/refs-0.2' in self._refs)
+
+
+class DictRefsContainerTests(RefsContainerTests, unittest.TestCase):
+
+    def setUp(self):
+        self._refs = DictRefsContainer(dict(_TEST_REFS))
+
+
+class DiskRefsContainerTests(RefsContainerTests, unittest.TestCase):
+
+    def setUp(self):
+        self._repo = open_repo('refs.git')
+        self._refs = self._repo.refs
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_get_packed_refs(self):
+        self.assertEqual({
+          'refs/heads/packed': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+          'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+          }, self._refs.get_packed_refs())
+
+    def test_get_peeled_not_packed(self):
+        # not packed
+        self.assertEqual(None, self._refs.get_peeled('refs/tags/refs-0.2'))
+        self.assertEqual('3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8',
+                         self._refs['refs/tags/refs-0.2'])
+
+        # packed, known not peelable
+        self.assertEqual(self._refs['refs/heads/packed'],
+                         self._refs.get_peeled('refs/heads/packed'))
+
+        # packed, peeled
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs.get_peeled('refs/tags/refs-0.1'))
+
+    def test_setitem(self):
+        RefsContainerTests.test_setitem(self)
+        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          f.read()[:40])
+        f.close()
+
+    def test_setitem_symbolic(self):
+        ones = '1' * 40
+        self._refs['HEAD'] = ones
+        self.assertEqual(ones, self._refs['HEAD'])
+
+        # ensure HEAD was not modified
+        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
+        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
+        f.close()
+
+        # ensure the symbolic link was written through
+        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
+        self.assertEqual(ones, f.read()[:40])
+        f.close()
+
+    def test_set_if_equals(self):
+        RefsContainerTests.test_set_if_equals(self)
+
+        # ensure symref was followed
+        self.assertEqual('9' * 40, self._refs['refs/heads/master'])
+
+        # ensure lockfile was deleted
+        self.assertFalse(os.path.exists(
+          os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+          os.path.join(self._refs.path, 'HEAD.lock')))
+
+    def test_add_if_new_packed(self):
+        # don't overwrite packed ref
+        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', '9' * 40))
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+
+    def test_add_if_new_symbolic(self):
+        # Use an empty repo instead of the default.
+        tear_down_repo(self._repo)
+        repo_dir = os.path.join(tempfile.mkdtemp(), 'test')
+        os.makedirs(repo_dir)
+        self._repo = Repo.init(repo_dir)
+        refs = self._repo.refs
+
+        nines = '9' * 40
+        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
+        self.assertFalse('refs/heads/master' in refs)
+        self.assertTrue(refs.add_if_new('HEAD', nines))
+        self.assertEqual('ref: refs/heads/master', refs.read_ref('HEAD'))
+        self.assertEqual(nines, refs['HEAD'])
+        self.assertEqual(nines, refs['refs/heads/master'])
+        self.assertFalse(refs.add_if_new('HEAD', '1' * 40))
+        self.assertEqual(nines, refs['HEAD'])
+        self.assertEqual(nines, refs['refs/heads/master'])
+
     def test_follow(self):
     def test_follow(self):
         self.assertEquals(
         self.assertEquals(
-            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
-            self._refs._follow('HEAD'))
+          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+          self._refs._follow('HEAD'))
         self.assertEquals(
         self.assertEquals(
-            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
-            self._refs._follow('refs/heads/master'))
+          ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+          self._refs._follow('refs/heads/master'))
         self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
         self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
         self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
         self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
 
 
     def test_delitem(self):
     def test_delitem(self):
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                          self._refs['refs/heads/master'])
-        del self._refs['refs/heads/master']
-        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+        RefsContainerTests.test_delitem(self)
         ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
         ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
         self.assertFalse(os.path.exists(ref_file))
         self.assertFalse(os.path.exists(ref_file))
         self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
         self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
@@ -398,17 +717,12 @@ class RefsContainerTests(unittest.TestCase):
                          self._refs['refs/heads/master'])
                          self._refs['refs/heads/master'])
         self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
         self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
 
 
-    def test_remove_if_equals(self):
-        nines = '9' * 40
-        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
-        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
-                         self._refs['HEAD'])
-
+    def test_remove_if_equals_symref(self):
         # HEAD is a symref, so shouldn't equal its dereferenced value
         # HEAD is a symref, so shouldn't equal its dereferenced value
         self.assertFalse(self._refs.remove_if_equals(
         self.assertFalse(self._refs.remove_if_equals(
-            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+          'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
         self.assertTrue(self._refs.remove_if_equals(
         self.assertTrue(self._refs.remove_if_equals(
-            'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+          'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
         self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
         self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
 
 
         # HEAD is now a broken symref
         # HEAD is now a broken symref
@@ -421,10 +735,32 @@ class RefsContainerTests(unittest.TestCase):
         self.assertFalse(os.path.exists(
         self.assertFalse(os.path.exists(
             os.path.join(self._refs.path, 'HEAD.lock')))
             os.path.join(self._refs.path, 'HEAD.lock')))
 
 
+    def test_remove_packed_without_peeled(self):
+        refs_file = os.path.join(self._repo.path, 'packed-refs')
+        f = open(refs_file)
+        refs_data = f.read()
+        f.close()
+        f = open(refs_file, 'w')
+        f.write('\n'.join(l for l in refs_data.split('\n')
+                          if not l or l[0] not in '#^'))
+        f.close()
+        self._repo = Repo(self._repo.path)
+        refs = self._repo.refs
+        self.assertTrue(refs.remove_if_equals(
+          'refs/heads/packed', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+
+    def test_remove_if_equals_packed(self):
         # test removing ref that is only packed
         # test removing ref that is only packed
         self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
         self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
                          self._refs['refs/tags/refs-0.1'])
                          self._refs['refs/tags/refs-0.1'])
         self.assertTrue(
         self.assertTrue(
-            self._refs.remove_if_equals('refs/tags/refs-0.1',
-            'df6800012397fb85c56e7418dd4eb9405dee075c'))
+          self._refs.remove_if_equals('refs/tags/refs-0.1',
+          'df6800012397fb85c56e7418dd4eb9405dee075c'))
         self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])
         self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])
+
+    def test_read_ref(self):
+        self.assertEqual('ref: refs/heads/master', self._refs.read_ref("HEAD"))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+            self._refs.read_ref("refs/heads/packed"))
+        self.assertEqual(None,
+            self._refs.read_ref("nonexistant"))
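
BuildRepoTests above demonstrates the basic write path through a non-bare repository: Repo.init(), stage(), then do_commit() with explicit author and committer details, after which the new commit is reachable through the refs container and the object store. The same flow condensed into a standalone sketch (paths, identities and the message are illustrative):

import os
import tempfile

from dulwich.repo import Repo

repo_dir = os.path.join(tempfile.mkdtemp(), 'example')
os.makedirs(repo_dir)
repo = Repo.init(repo_dir)

# Write and stage a single file, as BuildRepoTests.setUp() does.
f = open(os.path.join(repo_dir, 'a'), 'wb')
try:
    f.write('file contents')
finally:
    f.close()
repo.stage(['a'])

commit_sha = repo.do_commit('initial commit',
                            committer='Example Committer <c@example.com>',
                            author='Example Author <a@example.com>',
                            commit_timestamp=12345, commit_timezone=0,
                            author_timestamp=12345, author_timezone=0)

# The branch now points at the new commit and the commit is in the store.
assert repo.refs['refs/heads/master'] == commit_sha
assert repo[commit_sha].message == 'initial commit'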

+ 199 - 78
dulwich/tests/test_server.py

@@ -20,32 +20,34 @@
 """Tests for the smart protocol server."""
 """Tests for the smart protocol server."""
 
 
 
 
-from cStringIO import StringIO
 from unittest import TestCase
 from unittest import TestCase
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     GitProtocolError,
     GitProtocolError,
     )
     )
 from dulwich.server import (
 from dulwich.server import (
-    UploadPackHandler,
-    ProtocolGraphWalker,
-    SingleAckGraphWalkerImpl,
+    Backend,
+    DictBackend,
+    BackendRepo,
+    Handler,
     MultiAckGraphWalkerImpl,
     MultiAckGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
     MultiAckDetailedGraphWalkerImpl,
+    ProtocolGraphWalker,
+    SingleAckGraphWalkerImpl,
+    UploadPackHandler,
     )
     )
 
 
-from dulwich.protocol import (
-    SINGLE_ACK,
-    MULTI_ACK,
-    )
 
 
 ONE = '1' * 40
 ONE = '1' * 40
 TWO = '2' * 40
 TWO = '2' * 40
 THREE = '3' * 40
 THREE = '3' * 40
 FOUR = '4' * 40
 FOUR = '4' * 40
 FIVE = '5' * 40
 FIVE = '5' * 40
+SIX = '6' * 40
+
 
 
 class TestProto(object):
 class TestProto(object):
+
     def __init__(self):
     def __init__(self):
         self._output = []
         self._output = []
         self._received = {0: [], 1: [], 2: [], 3: []}
         self._received = {0: [], 1: [], 2: [], 3: []}
@@ -75,76 +77,158 @@ class TestProto(object):
             return None
             return None
 
 
 
 
-class UploadPackHandlerTestCase(TestCase):
+class HandlerTestCase(TestCase):
+
     def setUp(self):
     def setUp(self):
-        self._handler = UploadPackHandler(None, None, None)
+        self._handler = Handler(Backend(), None)
+        self._handler.capabilities = lambda: ('cap1', 'cap2', 'cap3')
+        self._handler.required_capabilities = lambda: ('cap2',)
 
 
-    def test_set_client_capabilities(self):
+    def assertSucceeds(self, func, *args, **kwargs):
         try:
         try:
-            self._handler.set_client_capabilities([])
-        except GitProtocolError:
-            self.fail()
+            func(*args, **kwargs)
+        except GitProtocolError, e:
+            self.fail(e)
 
 
-        try:
-            self._handler.set_client_capabilities([
-                'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
-        except GitProtocolError:
-            self.fail()
-
-    def test_set_client_capabilities_error(self):
-        self.assertRaises(GitProtocolError,
-                          self._handler.set_client_capabilities,
-                          ['weird_ack_level', 'ofs-delta'])
-        try:
-            self._handler.set_client_capabilities(['include-tag'])
-        except GitProtocolError:
-            self.fail()
+    def test_capability_line(self):
+        self.assertEquals('cap1 cap2 cap3', self._handler.capability_line())
+
+    def test_set_client_capabilities(self):
+        set_caps = self._handler.set_client_capabilities
+        self.assertSucceeds(set_caps, ['cap2'])
+        self.assertSucceeds(set_caps, ['cap1', 'cap2'])
+
+        # different order
+        self.assertSucceeds(set_caps, ['cap3', 'cap1', 'cap2'])
+
+        # error cases
+        self.assertRaises(GitProtocolError, set_caps, ['capxxx', 'cap2'])
+        self.assertRaises(GitProtocolError, set_caps, ['cap1', 'cap3'])
+
+        # ignore innocuous but unknown capabilities
+        self.assertRaises(GitProtocolError, set_caps, ['cap2', 'ignoreme'])
+        self.assertFalse('ignoreme' in self._handler.capabilities())
+        self._handler.innocuous_capabilities = lambda: ('ignoreme',)
+        self.assertSucceeds(set_caps, ['cap2', 'ignoreme'])
+
+    def test_has_capability(self):
+        self.assertRaises(GitProtocolError, self._handler.has_capability, 'cap')
+        caps = self._handler.capabilities()
+        self._handler.set_client_capabilities(caps)
+        for cap in caps:
+            self.assertTrue(self._handler.has_capability(cap))
+        self.assertFalse(self._handler.has_capability('capxxx'))
+
+
+class UploadPackHandlerTestCase(TestCase):
+
+    def setUp(self):
+        self._backend = DictBackend({"/": BackendRepo()})
+        self._handler = UploadPackHandler(self._backend,
+                ["/", "host=lolcathost"], None, None)
+        self._handler.proto = TestProto()
+
+    def test_progress(self):
+        caps = self._handler.required_capabilities()
+        self._handler.set_client_capabilities(caps)
+        self._handler.progress('first message')
+        self._handler.progress('second message')
+        self.assertEqual('first message',
+                         self._handler.proto.get_received_line(2))
+        self.assertEqual('second message',
+                         self._handler.proto.get_received_line(2))
+        self.assertEqual(None, self._handler.proto.get_received_line(2))
+
+    def test_no_progress(self):
+        caps = list(self._handler.required_capabilities()) + ['no-progress']
+        self._handler.set_client_capabilities(caps)
+        self._handler.progress('first message')
+        self._handler.progress('second message')
+        self.assertEqual(None, self._handler.proto.get_received_line(2))
+
+    def test_get_tagged(self):
+        refs = {
+            'refs/tags/tag1': ONE,
+            'refs/tags/tag2': TWO,
+            'refs/heads/master': FOUR,  # not a tag, no peeled value
+            }
+        peeled = {
+            'refs/tags/tag1': '1234',
+            'refs/tags/tag2': '5678',
+            }
+
+        class TestRepo(object):
+            def get_peeled(self, ref):
+                return peeled.get(ref, refs[ref])
+
+        caps = list(self._handler.required_capabilities()) + ['include-tag']
+        self._handler.set_client_capabilities(caps)
+        self.assertEquals({'1234': ONE, '5678': TWO},
+                          self._handler.get_tagged(refs, repo=TestRepo()))
+
+        # non-include-tag case
+        caps = self._handler.required_capabilities()
+        self._handler.set_client_capabilities(caps)
+        self.assertEquals({}, self._handler.get_tagged(refs, repo=TestRepo()))
 
 
 
 
 class TestCommit(object):
 class TestCommit(object):
+
     def __init__(self, sha, parents, commit_time):
     def __init__(self, sha, parents, commit_time):
         self.id = sha
         self.id = sha
-        self._parents = parents
+        self.parents = parents
         self.commit_time = commit_time
         self.commit_time = commit_time
-
-    def get_parents(self):
-        return self._parents
+        self.type_name = "commit"
 
 
     def __repr__(self):
     def __repr__(self):
         return '%s(%s)' % (self.__class__.__name__, self._sha)
         return '%s(%s)' % (self.__class__.__name__, self._sha)
 
 
 
 
+class TestRepo(object):
+    def __init__(self):
+        self.peeled = {}
+
+    def get_peeled(self, name):
+        return self.peeled[name]
+
+
 class TestBackend(object):
 class TestBackend(object):
-    def __init__(self, objects):
+
+    def __init__(self, repo, objects):
+        self.repo = repo
         self.object_store = objects
         self.object_store = objects
 
 
 
 
-class TestHandler(object):
+class TestUploadPackHandler(Handler):
+
     def __init__(self, objects, proto):
-        self.backend = TestBackend(objects)
+        self.backend = TestBackend(TestRepo(), objects)
         self.proto = proto
         self.stateless_rpc = False
         self.advertise_refs = False

     def capabilities(self):
-        return 'multi_ack'
+        return ('multi_ack',)


 class ProtocolGraphWalkerTestCase(TestCase):
+
     def setUp(self):
         # Create the following commit tree:
         #   3---5
         #  /
         # 1---2---4
         self._objects = {
-            ONE: TestCommit(ONE, [], 111),
-            TWO: TestCommit(TWO, [ONE], 222),
-            THREE: TestCommit(THREE, [ONE], 333),
-            FOUR: TestCommit(FOUR, [TWO], 444),
-            FIVE: TestCommit(FIVE, [THREE], 555),
-            }
+          ONE: TestCommit(ONE, [], 111),
+          TWO: TestCommit(TWO, [ONE], 222),
+          THREE: TestCommit(THREE, [ONE], 333),
+          FOUR: TestCommit(FOUR, [TWO], 444),
+          FIVE: TestCommit(FIVE, [THREE], 555),
+          }
+
         self._walker = ProtocolGraphWalker(
-            TestHandler(self._objects, TestProto()))
+            TestUploadPackHandler(self._objects, TestProto()),
+            self._objects, None)

     def test_is_satisfied_no_haves(self):
         self.assertFalse(self._walker._is_satisfied([], ONE, 0))
@@ -173,13 +257,13 @@ class ProtocolGraphWalkerTestCase(TestCase):

     def test_read_proto_line(self):
         self._walker.proto.set_output([
-            'want %s' % ONE,
-            'want %s' % TWO,
-            'have %s' % THREE,
-            'foo %s' % FOUR,
-            'bar',
-            'done',
-            ])
+          'want %s' % ONE,
+          'want %s' % TWO,
+          'have %s' % THREE,
+          'foo %s' % FOUR,
+          'bar',
+          'done',
+          ])
         self.assertEquals(('want', ONE), self._walker.read_proto_line())
         self.assertEquals(('want', TWO), self._walker.read_proto_line())
         self.assertEquals(('have', THREE), self._walker.read_proto_line())
@@ -192,10 +276,11 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self.assertRaises(GitProtocolError, self._walker.determine_wants, {})

         self._walker.proto.set_output([
-            'want %s multi_ack' % ONE,
-            'want %s' % TWO,
-            ])
+          'want %s multi_ack' % ONE,
+          'want %s' % TWO,
+          ])
         heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
+        self._walker.get_peeled = heads.get
         self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))

         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
@@ -210,10 +295,40 @@ class ProtocolGraphWalkerTestCase(TestCase):
         self._walker.proto.set_output(['want %s multi_ack' % FOUR])
         self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)

+    def test_determine_wants_advertisement(self):
+        self._walker.proto.set_output([])
+        # advertise branch tips plus tag
+        heads = {'ref4': FOUR, 'ref5': FIVE, 'tag6': SIX}
+        peeled = {'ref4': FOUR, 'ref5': FIVE, 'tag6': FIVE}
+        self._walker.get_peeled = peeled.get
+        self._walker.determine_wants(heads)
+        lines = []
+        while True:
+            line = self._walker.proto.get_received_line()
+            if line == 'None':
+                break
+            # strip capabilities list if present
+            if '\x00' in line:
+                line = line[:line.index('\x00')]
+            lines.append(line.rstrip())
+
+        self.assertEquals([
+          '%s ref4' % FOUR,
+          '%s ref5' % FIVE,
+          '%s tag6^{}' % FIVE,
+          '%s tag6' % SIX,
+          ], sorted(lines))
+
+        # ensure peeled tag was advertised immediately following tag
+        for i, line in enumerate(lines):
+            if line.endswith(' tag6'):
+                self.assertEquals('%s tag6^{}' % FIVE, lines[i+1])
+
     # TODO: test commit time cutoff


 class TestProtocolGraphWalker(object):
+
     def __init__(self):
         self.acks = []
         self.lines = []
@@ -241,14 +356,15 @@ class TestProtocolGraphWalker(object):

 class AckGraphWalkerImplTestCase(TestCase):
     """Base setup and asserts for AckGraphWalker tests."""
+
     def setUp(self):
         self._walker = TestProtocolGraphWalker()
         self._walker.lines = [
-            ('have', TWO),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self._impl = self.impl_cls(self._walker)

     def assertNoAck(self):
@@ -270,6 +386,7 @@ class AckGraphWalkerImplTestCase(TestCase):


 class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+
     impl_cls = SingleAckGraphWalkerImpl

     def test_single_ack(self):
@@ -335,7 +452,9 @@ class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
         self.assertNextEquals(None)
         self.assertNak()

+
 class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+
     impl_cls = MultiAckGraphWalkerImpl

     def test_multi_ack(self):
@@ -371,17 +490,17 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):

     def test_multi_ack_flush(self):
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()

         self.assertNextEquals(ONE)
-        self.assertNak() # nak the flush-pkt
+        self.assertNak()  # nak the flush-pkt

         self._walker.done = True
         self._impl.ack(ONE)
@@ -407,7 +526,9 @@ class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
         self.assertNextEquals(None)
         self.assertNak()

+
 class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+
     impl_cls = MultiAckDetailedGraphWalkerImpl

     def test_multi_ack(self):
@@ -444,17 +565,17 @@ class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
     def test_multi_ack_flush(self):
         # same as ack test but contains a flush-pkt in the middle
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()

         self.assertNextEquals(ONE)
-        self.assertNak() # nak the flush-pkt
+        self.assertNak()  # nak the flush-pkt

         self._walker.done = True
         self._impl.ack(ONE)
@@ -483,12 +604,12 @@ class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
     def test_multi_ack_nak_flush(self):
         # same as nak test but contains a flush-pkt in the middle
         self._walker.lines = [
-            ('have', TWO),
-            (None, None),
-            ('have', ONE),
-            ('have', THREE),
-            ('done', None),
-            ]
+          ('have', TWO),
+          (None, None),
+          ('have', ONE),
+          ('have', THREE),
+          ('done', None),
+          ]
         self.assertNextEquals(TWO)
         self.assertNoAck()


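The flush-pkt these tests nak is part of git's pkt-line framing: every payload is prefixed with its total length as four hexadecimal digits, and a bare '0000' is the flush packet. The '001e# service=git-upload-pack\n' string asserted in the web tests further down is that same encoding. A small sketch of the framing, for illustration only (dulwich's own Protocol class handles this):

def pkt_line(data):
    # The length prefix counts the four prefix bytes themselves; None -> flush-pkt.
    if data is None:
        return '0000'
    return '%04x%s' % (len(data) + 4, data)

assert pkt_line('# service=git-upload-pack\n') == '001e# service=git-upload-pack\n'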
+ 69 - 55
dulwich/tests/test_web.py

@@ -23,8 +23,6 @@ import re
 from unittest import TestCase

 from dulwich.objects import (
-    type_map,
-    Tag,
     Blob,
     )
 from dulwich.web import (
@@ -42,9 +40,11 @@ from dulwich.web import (

 class WebTestCase(TestCase):
     """Base TestCase that sets up some useful instance vars."""
+
     def setUp(self):
         self._environ = {}
-        self._req = HTTPGitRequest(self._environ, self._start_response)
+        self._req = HTTPGitRequest(self._environ, self._start_response,
+                                   handlers=self._handlers())
         self._status = None
         self._headers = []

@@ -52,6 +52,9 @@ class WebTestCase(TestCase):
         self._status = status
         self._headers = list(headers)

+    def _handlers(self):
+        return None
+

 class DumbHandlersTestCase(WebTestCase):

@@ -97,15 +100,11 @@ class DumbHandlersTestCase(WebTestCase):
         self._environ['QUERY_STRING'] = ''

         class TestTag(object):
-            type = Tag().type
-
-            def __init__(self, sha, obj_type, obj_sha):
+            def __init__(self, sha, obj_class, obj_sha):
                 self.sha = lambda: sha
-                self.object = (obj_type, obj_sha)
+                self.object = (obj_class, obj_sha)

         class TestBlob(object):
-            type = Blob().type
-
             def __init__(self, sha):
                 self.sha = lambda: sha

@@ -113,13 +112,19 @@ class DumbHandlersTestCase(WebTestCase):
         blob2 = TestBlob('222')
         blob3 = TestBlob('333')

-        tag1 = TestTag('aaa', TestTag.type, 'bbb')
-        tag2 = TestTag('bbb', TestBlob.type, '222')
+        tag1 = TestTag('aaa', Blob, '222')

-        class TestBackend(object):
-            def __init__(self):
-                objects = [blob1, blob2, blob3, tag1, tag2]
-                self.repo = dict((o.sha(), o) for o in objects)
+        class TestRepo(object):
+
+            def __init__(self, objects, peeled):
+                self._objects = dict((o.sha(), o) for o in objects)
+                self._peeled = peeled
+
+            def get_peeled(self, sha):
+                return self._peeled[sha]
+
+            def __getitem__(self, sha):
+                return self._objects[sha]

             def get_refs(self):
                 return {
@@ -129,43 +134,55 @@ class DumbHandlersTestCase(WebTestCase):
                     'refs/tags/blob-tag': blob3.sha(),
                     }

+        class TestBackend(object):
+            def __init__(self):
+                objects = [blob1, blob2, blob3, tag1]
+                self.repo = TestRepo(objects, {
+                  'HEAD': '000',
+                  'refs/heads/master': blob1.sha(),
+                  'refs/tags/tag-tag': blob2.sha(),
+                  'refs/tags/blob-tag': blob3.sha(),
+                  })
+
+            def open_repository(self, path):
+                assert path == '/'
+                return self.repo
+
+            def get_refs(self):
+                return {
+                  'HEAD': '000',
+                  'refs/heads/master': blob1.sha(),
+                  'refs/tags/tag-tag': tag1.sha(),
+                  'refs/tags/blob-tag': blob3.sha(),
+                  }
+
+        mat = re.search('.*', '//info/refs')
         self.assertEquals(['111\trefs/heads/master\n',
                            '333\trefs/tags/blob-tag\n',
                            'aaa\trefs/tags/tag-tag\n',
                            '222\trefs/tags/tag-tag^{}\n'],
-                          list(get_info_refs(self._req, TestBackend(), None)))
+                          list(get_info_refs(self._req, TestBackend(), mat)))


 class SmartHandlersTestCase(WebTestCase):

-    class TestProtocol(object):
-        def __init__(self, handler):
-            self._handler = handler
-
-        def write_pkt_line(self, line):
-            if line is None:
-                self._handler.write('flush-pkt\n')
-            else:
-                self._handler.write('pkt-line: %s' % line)
-
     class _TestUploadPackHandler(object):
-        def __init__(self, backend, read, write, stateless_rpc=False,
+        def __init__(self, backend, args, proto, stateless_rpc=False,
                      advertise_refs=False):
-            self.read = read
-            self.write = write
-            self.proto = SmartHandlersTestCase.TestProtocol(self)
+            self.args = args
+            self.proto = proto
             self.stateless_rpc = stateless_rpc
             self.advertise_refs = advertise_refs

         def handle(self):
-            self.write('handled input: %s' % self.read())
+            self.proto.write('handled input: %s' % self.proto.recv(1024))

-    def _MakeHandler(self, *args, **kwargs):
+    def _make_handler(self, *args, **kwargs):
         self._handler = self._TestUploadPackHandler(*args, **kwargs)
         return self._handler

-    def services(self):
-        return {'git-upload-pack': self._MakeHandler}
+    def _handlers(self):
+        return {'git-upload-pack': self._make_handler}

     def test_handle_service_request_unknown(self):
         mat = re.search('.*', '/git-evil-handler')
@@ -175,8 +192,7 @@ class SmartHandlersTestCase(WebTestCase):
     def test_handle_service_request(self):
         self._environ['wsgi.input'] = StringIO('foo')
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat,
-                                                services=self.services()))
+        output = ''.join(handle_service_request(self._req, 'backend', mat))
         self.assertEqual('handled input: foo', output)
         response_type = 'application/x-git-upload-pack-response'
         self.assertTrue(('Content-Type', response_type) in self._headers)
@@ -187,26 +203,24 @@ class SmartHandlersTestCase(WebTestCase):
         self._environ['wsgi.input'] = StringIO('foobar')
         self._environ['CONTENT_LENGTH'] = 3
         mat = re.search('.*', '/git-upload-pack')
-        output = ''.join(handle_service_request(self._req, 'backend', mat,
-                                                services=self.services()))
+        output = ''.join(handle_service_request(self._req, 'backend', mat))
         self.assertEqual('handled input: foo', output)
         response_type = 'application/x-git-upload-pack-response'
         self.assertTrue(('Content-Type', response_type) in self._headers)

     def test_get_info_refs_unknown(self):
         self._environ['QUERY_STRING'] = 'service=git-evil-handler'
-        list(get_info_refs(self._req, 'backend', None,
-                           services=self.services()))
+        list(get_info_refs(self._req, 'backend', None))
         self.assertEquals(HTTP_FORBIDDEN, self._status)

     def test_get_info_refs(self):
         self._environ['wsgi.input'] = StringIO('foo')
         self._environ['QUERY_STRING'] = 'service=git-upload-pack'

-        output = ''.join(get_info_refs(self._req, 'backend', None,
-                                       services=self.services()))
-        self.assertEquals(('pkt-line: # service=git-upload-pack\n'
-                           'flush-pkt\n'
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(get_info_refs(self._req, 'backend', mat))
+        self.assertEquals(('001e# service=git-upload-pack\n'
+                           '0000'
                            # input is ignored by the handler
                            'handled input: '), output)
         self.assertTrue(self._handler.advertise_refs)
@@ -257,13 +271,13 @@ class HTTPGitRequestTestCase(WebTestCase):
         self._req.respond(status=402, content_type='some/type',
                           headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
         self.assertEquals(set([
-            ('X-Foo', 'foo'),
-            ('X-Bar', 'bar'),
-            ('Content-Type', 'some/type'),
-            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
-            ('Pragma', 'no-cache'),
-            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
-            ]), set(self._headers))
+          ('X-Foo', 'foo'),
+          ('X-Bar', 'bar'),
+          ('Content-Type', 'some/type'),
+          ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+          ('Pragma', 'no-cache'),
+          ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+          ]), set(self._headers))
         self.assertEquals(402, self._status)


@@ -280,10 +294,10 @@ class HTTPGitApplicationTestCase(TestCase):
             return 'output'

         self._app.services = {
-            ('GET', re.compile('/foo$')): test_handler,
+          ('GET', re.compile('/foo$')): test_handler,
         }
         environ = {
-            'PATH_INFO': '/foo',
-            'REQUEST_METHOD': 'GET',
-            }
+          'PATH_INFO': '/foo',
+          'REQUEST_METHOD': 'GET',
+          }
         self.assertEquals('output', self._app(environ, None))

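The dumb-protocol branch of get_info_refs (later in this diff) produces exactly the lines asserted above: one '<sha>\t<name>\n' entry per ref, with annotated tags followed by a '<name>^{}' line carrying the peeled SHA. Condensed into a standalone sketch of that behaviour:

def dumb_info_refs(refs, get_peeled):
    # refs: {name: sha}; get_peeled(name) returns the fully peeled SHA.
    for name in sorted(refs):
        if name == 'HEAD':  # returned by get_refs() but never advertised here
            continue
        sha = refs[name]
        yield '%s\t%s\n' % (sha, name)
        peeled = get_peeled(name)
        if peeled != sha:
            yield '%s\t%s^{}\n' % (peeled, name)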
+ 86 - 0
dulwich/tests/utils.py

@@ -0,0 +1,86 @@
+# utils.py -- Test utilities for Dulwich.
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) any later version of
+# the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Utility functions common to Dulwich tests."""
+
+
+import datetime
+import os
+import shutil
+import tempfile
+import time
+
+from dulwich.objects import Commit
+from dulwich.repo import Repo
+
+
+def open_repo(name):
+    """Open a copy of a repo in a temporary directory.
+
+    Use this function for accessing repos in dulwich/tests/data/repos to avoid
+    accidentally or intentionally modifying those repos in place. Use
+    tear_down_repo to delete any temp files created.
+
+    :param name: The name of the repository, relative to
+        dulwich/tests/data/repos
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    temp_dir = tempfile.mkdtemp()
+    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
+    return Repo(temp_repo_dir)
+
+
+def tear_down_repo(repo):
+    """Tear down a test repository."""
+    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
+    shutil.rmtree(temp_dir)
+
+
+def make_object(cls, **attrs):
+    """Make an object for testing and assign some members.
+
+    :param attrs: dict of attributes to set on the new object.
+    :return: A newly initialized object of type cls.
+    """
+    obj = cls()
+    for name, value in attrs.iteritems():
+        setattr(obj, name, value)
+    return obj
+
+
+def make_commit(**attrs):
+    """Make a Commit object with a default set of members.
+
+    :param attrs: dict of attributes to overwrite from the default values.
+    :return: A newly initialized Commit object.
+    """
+    default_time = int(time.mktime(datetime.datetime(2010, 1, 1).timetuple()))
+    all_attrs = {'author': 'Test Author <test@nodomain.com>',
+                 'author_time': default_time,
+                 'author_timezone': 0,
+                 'committer': 'Test Committer <test@nodomain.com>',
+                 'commit_time': default_time,
+                 'commit_timezone': 0,
+                 'message': 'Test message.',
+                 'parents': [],
+                 'tree': '0' * 40}
+    all_attrs.update(attrs)
+    return make_object(Commit, **all_attrs)

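These helpers are intended to be combined in tests: open_repo gives a disposable copy of a fixture repository, tear_down_repo removes it, and make_object/make_commit build objects without hashing real content. A usage sketch (the fixture name 'a.git' and the head() call are illustrative assumptions, not part of this change):

from dulwich.tests.utils import make_commit, open_repo, tear_down_repo

commit = make_commit(message='An example commit.')  # other fields use defaults

repo = open_repo('a.git')       # temporary copy of dulwich/tests/data/repos/a.git
try:
    head = repo.head()          # read something from the throw-away copy
finally:
    tear_down_repo(repo)        # delete the temporary directory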
+ 90 - 75
dulwich/web.py

@@ -19,30 +19,30 @@
 """HTTP server for dulwich that implements the git smart HTTP protocol."""
 """HTTP server for dulwich that implements the git smart HTTP protocol."""
 
 
 from cStringIO import StringIO
 from cStringIO import StringIO
-import cgi
-import os
 import re
 import re
 import time
 import time
 
 
-from dulwich.objects import (
-    Tag,
-    num_type_map,
-    )
-from dulwich.repo import (
-    Repo,
+try:
+    from urlparse import parse_qs
+except ImportError:
+    from dulwich.misc import parse_qs
+from dulwich.protocol import (
+    ReceivableProtocol,
     )
     )
 from dulwich.server import (
 from dulwich.server import (
-    GitBackend,
     ReceivePackHandler,
     ReceivePackHandler,
     UploadPackHandler,
     UploadPackHandler,
+    DEFAULT_HANDLERS,
     )
     )
 
 
+
+# HTTP error strings
 HTTP_OK = '200 OK'
 HTTP_OK = '200 OK'
 HTTP_NOT_FOUND = '404 Not Found'
 HTTP_NOT_FOUND = '404 Not Found'
 HTTP_FORBIDDEN = '403 Forbidden'
 HTTP_FORBIDDEN = '403 Forbidden'
 
 
 
 
-def date_time_string(self, timestamp=None):
+def date_time_string(timestamp=None):
     # Based on BaseHTTPServer.py in python2.5
     # Based on BaseHTTPServer.py in python2.5
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
     months = [None,
     months = [None,
@@ -55,6 +55,22 @@ def date_time_string(self, timestamp=None):
             weekdays[wd], day, months[month], year, hh, mm, ss)


+def url_prefix(mat):
+    """Extract the URL prefix from a regex match.
+
+    :param mat: A regex match object.
+    :returns: The URL prefix, defined as the text before the match in the
+        original string. Normalized to start with one leading slash and end with
+        zero.
+    """
+    return '/' + mat.string[:mat.start()].strip('/')
+
+
+def get_repo(backend, mat):
+    """Get a Repo instance for the given backend and URL regex match."""
+    return backend.open_repository(url_prefix(mat))
+
+
 def send_file(req, f, content_type):
     """Send a file-like object to the request output.

@@ -67,28 +83,30 @@ def send_file(req, f, content_type):
         yield req.not_found('File not found')
         return
     try:
-        try:
-            req.respond(HTTP_OK, content_type)
-            while True:
-                data = f.read(10240)
-                if not data:
-                    break
-                yield data
-        except IOError:
-            yield req.not_found('Error reading file')
-    finally:
+        req.respond(HTTP_OK, content_type)
+        while True:
+            data = f.read(10240)
+            if not data:
+                break
+            yield data
+        f.close()
+    except IOError:
         f.close()
+        yield req.not_found('Error reading file')
+    except:
+        f.close()
+        raise


 def get_text_file(req, backend, mat):
     req.nocache()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
                      'text/plain')


 def get_loose_object(req, backend, mat):
     sha = mat.group(1) + mat.group(2)
-    object_store = backend.object_store
+    object_store = get_repo(backend, mat).object_store
     if not object_store.contains_loose(sha):
         yield req.not_found('Object not found')
         return
@@ -103,33 +121,29 @@ def get_loose_object(req, backend, mat):
 def get_pack_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
-                     'application/x-git-packed-objects', False)
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
+                     'application/x-git-packed-objects')


 def get_idx_file(req, backend, mat):
     req.cache_forever()
-    return send_file(req, backend.repo.get_named_file(mat.group()),
-                     'application/x-git-packed-objects-toc', False)
+    return send_file(req, get_repo(backend, mat).get_named_file(mat.group()),
+                     'application/x-git-packed-objects-toc')


-services = {'git-upload-pack': UploadPackHandler,
-            'git-receive-pack': ReceivePackHandler}
-def get_info_refs(req, backend, mat, services=None):
-    if services is None:
-        services = services
-    params = cgi.parse_qs(req.environ['QUERY_STRING'])
+def get_info_refs(req, backend, mat):
+    params = parse_qs(req.environ['QUERY_STRING'])
     service = params.get('service', [None])[0]
-    if service:
-        handler_cls = services.get(service, None)
+    if service and not req.dumb:
+        handler_cls = req.handlers.get(service, None)
         if handler_cls is None:
             yield req.forbidden('Unsupported service %s' % service)
             return
         req.nocache()
         req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
         output = StringIO()
-        dummy_input = StringIO()  # GET request, handler doesn't need to read
-        handler = handler_cls(backend, dummy_input.read, output.write,
+        proto = ReceivableProtocol(StringIO().read, output.write)
+        handler = handler_cls(backend, [url_prefix(mat)], proto,
                               stateless_rpc=True, advertise_refs=True)
         handler.proto.write_pkt_line('# service=%s\n' % service)
         handler.proto.write_pkt_line(None)
@@ -140,32 +154,27 @@ def get_info_refs(req, backend, mat, services=None):
         # TODO: select_getanyfile() (see http-backend.c)
         req.nocache()
         req.respond(HTTP_OK, 'text/plain')
-        refs = backend.get_refs()
+        repo = get_repo(backend, mat)
+        refs = repo.get_refs()
         for name in sorted(refs.iterkeys()):
             # get_refs() includes HEAD as a special case, but we don't want to
             # advertise it
             if name == 'HEAD':
                 continue
             sha = refs[name]
-            o = backend.repo[sha]
+            o = repo[sha]
             if not o:
                 continue
             yield '%s\t%s\n' % (sha, name)
-            obj_type = num_type_map[o.type]
-            if obj_type == Tag:
-                while obj_type == Tag:
-                    num_type, sha = o.object
-                    obj_type = num_type_map[num_type]
-                    o = backend.repo[sha]
-                if not o:
-                    continue
-                yield '%s\t%s^{}\n' % (o.sha(), name)
+            peeled_sha = repo.get_peeled(name)
+            if peeled_sha != sha:
+                yield '%s\t%s^{}\n' % (peeled_sha, name)


 def get_info_packs(req, backend, mat):
     req.nocache()
     req.respond(HTTP_OK, 'text/plain')
-    for pack in backend.object_store.packs:
+    for pack in get_repo(backend, mat).object_store.packs:
         yield 'P pack-%s.pack\n' % pack.name()


@@ -176,6 +185,7 @@ class _LengthLimitedFile(object):
     Content-Length bytes are read. This behavior is required by the WSGI spec
     but not implemented in wsgiref as of 2.5.
     """
+
     def __init__(self, input, max_bytes):
         self._input = input
         self._bytes_avail = max_bytes
@@ -190,11 +200,10 @@ class _LengthLimitedFile(object):

     # TODO: support more methods as necessary

-def handle_service_request(req, backend, mat, services=services):
-    if services is None:
-        services = services
+
+def handle_service_request(req, backend, mat):
     service = mat.group().lstrip('/')
-    handler_cls = services.get(service, None)
+    handler_cls = req.handlers.get(service, None)
     if handler_cls is None:
         yield req.forbidden('Unsupported service %s' % service)
         return
@@ -209,7 +218,8 @@ def handle_service_request(req, backend, mat, services=services):
     # content-length
     if 'CONTENT_LENGTH' in req.environ:
         input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
-    handler = handler_cls(backend, input.read, output.write, stateless_rpc=True)
+    proto = ReceivableProtocol(input.read, output.write)
+    handler = handler_cls(backend, [url_prefix(mat)], proto, stateless_rpc=True)
     handler.handle()
     yield output.getvalue()

@@ -220,8 +230,10 @@ class HTTPGitRequest(object):
     :ivar environ: the WSGI environment for the request.
     """

-    def __init__(self, environ, start_response):
+    def __init__(self, environ, start_response, dumb=False, handlers=None):
         self.environ = environ
+        self.dumb = dumb
+        self.handlers = handlers and handlers or DEFAULT_HANDLERS
         self._start_response = start_response
         self._cache_headers = []
         self._headers = []
@@ -255,19 +267,19 @@ class HTTPGitRequest(object):
     def nocache(self):
         """Set the response to never be cached by the client."""
         self._cache_headers = [
-            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
-            ('Pragma', 'no-cache'),
-            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
-            ]
+          ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+          ('Pragma', 'no-cache'),
+          ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+          ]

     def cache_forever(self):
         """Set the response to be cached forever by the client."""
         now = time.time()
         self._cache_headers = [
-            ('Date', date_time_string(now)),
-            ('Expires', date_time_string(now + 31536000)),
-            ('Cache-Control', 'public, max-age=31536000'),
-            ]
+          ('Date', date_time_string(now)),
+          ('Expires', date_time_string(now + 31536000)),
+          ('Cache-Control', 'public, max-age=31536000'),
+          ]


 class HTTPGitApplication(object):
@@ -277,26 +289,29 @@ class HTTPGitApplication(object):
     """
     """
 
 
     services = {
     services = {
-        ('GET', re.compile('/HEAD$')): get_text_file,
-        ('GET', re.compile('/info/refs$')): get_info_refs,
-        ('GET', re.compile('/objects/info/alternates$')): get_text_file,
-        ('GET', re.compile('/objects/info/http-alternates$')): get_text_file,
-        ('GET', re.compile('/objects/info/packs$')): get_info_packs,
-        ('GET', re.compile('/objects/([0-9a-f]{2})/([0-9a-f]{38})$')): get_loose_object,
-        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.pack$')): get_pack_file,
-        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.idx$')): get_idx_file,
-
-        ('POST', re.compile('/git-upload-pack$')): handle_service_request,
-        ('POST', re.compile('/git-receive-pack$')): handle_service_request,
+      ('GET', re.compile('/HEAD$')): get_text_file,
+      ('GET', re.compile('/info/refs$')): get_info_refs,
+      ('GET', re.compile('/objects/info/alternates$')): get_text_file,
+      ('GET', re.compile('/objects/info/http-alternates$')): get_text_file,
+      ('GET', re.compile('/objects/info/packs$')): get_info_packs,
+      ('GET', re.compile('/objects/([0-9a-f]{2})/([0-9a-f]{38})$')): get_loose_object,
+      ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.pack$')): get_pack_file,
+      ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.idx$')): get_idx_file,
+
+      ('POST', re.compile('/git-upload-pack$')): handle_service_request,
+      ('POST', re.compile('/git-receive-pack$')): handle_service_request,
     }

-    def __init__(self, backend):
+    def __init__(self, backend, dumb=False, handlers=None):
         self.backend = backend
+        self.dumb = dumb
+        self.handlers = handlers

     def __call__(self, environ, start_response):
         path = environ['PATH_INFO']
         method = environ['REQUEST_METHOD']
-        req = HTTPGitRequest(environ, start_response)
+        req = HTTPGitRequest(environ, start_response, dumb=self.dumb,
+                             handlers=self.handlers)
         # environ['QUERY_STRING'] has qs args
         handler = None
         for smethod, spath in self.services.iterkeys():

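With the backend/handler plumbing above, the WSGI application can be served directly. A rough wiring sketch; the DictBackend import location and the repository path are assumptions for illustration, not taken from this change:

from wsgiref.simple_server import make_server

from dulwich.repo import Repo
from dulwich.server import DictBackend
from dulwich.web import HTTPGitApplication

# One repository exported under the URL prefix '/'; request handlers fall
# back to DEFAULT_HANDLERS (git-upload-pack and git-receive-pack).
backend = DictBackend({'/': Repo('/path/to/repo.git')})
app = HTTPGitApplication(backend)
make_server('localhost', 8000, app).serve_forever()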
+ 1 - 1
setup.py

@@ -8,7 +8,7 @@ except ImportError:
     from distutils.core import setup, Extension
 from distutils.core import Distribution

-dulwich_version_string = '0.5.0'
+dulwich_version_string = '0.6.0'

 include_dirs = []
 # Windows MSVC support

Some files were not shown because too many files changed in this diff