Jelmer Vernooij 15 years ago
parent
commit
27c851a5ad
34 changed files with 3202 additions and 506 deletions
  1. NEWS (+29, -0)
  2. bin/dul-daemon (+3, -2)
  3. bin/dul-receive-pack (+3, -2)
  4. bin/dul-upload-pack (+3, -2)
  5. bin/dul-web (+37, -0)
  6. debian/changelog (+3, -2)
  7. dulwich/__init__.py (+1, -1)
  8. dulwich/_pack.c (+3, -0)
  9. dulwich/errors.py (+12, -0)
  10. dulwich/file.py (+138, -0)
  11. dulwich/index.py (+27, -5)
  12. dulwich/object_store.py (+133, -81)
  13. dulwich/objects.py (+23, -10)
  14. dulwich/pack.py (+104, -88)
  15. dulwich/protocol.py (+33, -5)
  16. dulwich/repo.py (+565, -166)
  17. dulwich/server.py (+414, -99)
  18. dulwich/tests/data/repos/refs.git/HEAD (+1, -0)
  19. dulwich/tests/data/repos/refs.git/objects/3b/9e5457140e738c2dcd39bf6d7acf88379b90d1 (BIN)
  20. dulwich/tests/data/repos/refs.git/objects/42/d06bd4b77fed026b154d16493e5deab78f02ec (BIN)
  21. dulwich/tests/data/repos/refs.git/objects/a1/8114c31713746a33a2e70d9914d1ef3e781425 (BIN)
  22. dulwich/tests/data/repos/refs.git/objects/df/6800012397fb85c56e7418dd4eb9405dee075c (BIN)
  23. dulwich/tests/data/repos/refs.git/packed-refs (+3, -0)
  24. dulwich/tests/data/repos/refs.git/refs/heads/loop (+1, -0)
  25. dulwich/tests/data/repos/refs.git/refs/heads/master (+1, -0)
  26. dulwich/tests/test_file.py (+131, -0)
  27. dulwich/tests/test_objects.py (+35, -0)
  28. dulwich/tests/test_pack.py (+3, -2)
  29. dulwich/tests/test_protocol.py (+28, -3)
  30. dulwich/tests/test_repository.py (+324, -31)
  31. dulwich/tests/test_server.py (+519, -0)
  32. dulwich/tests/test_web.py (+289, -0)
  33. dulwich/web.py (+311, -0)
  34. setup.py (+25, -7)

+ 29 - 0
NEWS

@@ -1,3 +1,32 @@
+0.5.0	2010-03-03
+
+ BUG FIXES
+
+  * Support custom fields in commits.
+
+  * Improved ref handling. (Dave Borowitz)
+
+  * Rework server protocol to be smarter and interoperate with cgit client.
+    (Dave Borowitz)
+
+  * Add a GitFile class that uses the same locking protocol for writes as 
+    cgit. (Dave Borowitz)
+
+  * Cope with forward slashes correctly in the index on Windows.
+    (Jelmer Vernooij, #526793)
+
+ FEATURES
+
+  * --pure option to setup.py to allow building/installing without the C 
+    extensions. (Hal Wine, Anatoly Techtonik, Jelmer Vernooij, #434326)
+
+  * Implement Repo.get_config(). (Jelmer Vernooij, Augie Fackler)
+
+  * HTTP dumb and smart server. (Dave Borowitz)
+
+  * Add abstract baseclass for Repo that does not require file system 
+    operations. (Dave Borowitz)
+
 0.4.1	2010-01-03
 
  FEATURES

+ 3 - 2
bin/dul-daemon

@@ -5,7 +5,7 @@
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License.
+# or (at your option) a later version of the License.
 # 
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -18,6 +18,7 @@
 # MA  02110-1301, USA.
 
 import sys
+from dulwich.repo import Repo
 from dulwich.server import GitBackend, TCPGitServer
 
 if __name__ == "__main__":
@@ -25,6 +26,6 @@ if __name__ == "__main__":
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
 
-    backend = GitBackend(gitdir)
+    backend = GitBackend(Repo(gitdir))
     server = TCPGitServer(backend, 'localhost')
     server.serve_forever()

+ 3 - 2
bin/dul-receive-pack

@@ -5,7 +5,7 @@
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License.
+# or (at your option) a later version of the License.
 # 
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -18,6 +18,7 @@
 # MA  02110-1301, USA.
 
 import sys
+from dulwich.repo import Repo
 from dulwich.server import GitBackend, ReceivePackHandler
 
 def send_fn(data):
@@ -29,6 +30,6 @@ if __name__ == "__main__":
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
 
-    backend = GitBackend(gitdir)
+    backend = GitBackend(Repo(gitdir))
     handler = ReceivePackHandler(backend, sys.stdin.read, send_fn)
     handler.handle()

+ 3 - 2
bin/dul-upload-pack

@@ -5,7 +5,7 @@
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License.
+# or (at your option) a later version of the License.
 # 
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
@@ -18,6 +18,7 @@
 # MA  02110-1301, USA.
 
 import sys
+from dulwich.repo import Repo
 from dulwich.server import GitBackend, UploadPackHandler
 
 def send_fn(data):
@@ -29,6 +30,6 @@ if __name__ == "__main__":
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
 
-    backend = GitBackend(gitdir)
+    backend = GitBackend(Repo(gitdir))
     handler = UploadPackHandler(backend, sys.stdin.read, send_fn)
     handler.handle()

+ 37 - 0
bin/dul-web

@@ -0,0 +1,37 @@
+#!/usr/bin/python
+# dul-web - HTTP-based git server
+# Copyright (C) 2010 David Borowitz <dborowitz@google.com>
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) a later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+import os
+import sys
+from dulwich.repo import Repo
+from dulwich.server import GitBackend
+from dulwich.web import HTTPGitApplication
+from wsgiref.simple_server import make_server
+
+if __name__ == "__main__":
+    if len(sys.argv) > 1:
+        gitdir = sys.argv[1]
+    else:
+        gitdir = os.getcwd()
+
+    backend = GitBackend(Repo(gitdir))
+    app = HTTPGitApplication(backend)
+    # TODO: allow serving on other ports via command-line flag
+    server = make_server('', 8000, app)
+    server.serve_forever()
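
The dul-web script above hard-codes port 8000 (see its TODO). A hedged sketch of the same wiring as a small helper with the port as a parameter; the serve() function and its defaults are illustrative, not part of this commit:

    import os
    import sys
    from wsgiref.simple_server import make_server

    from dulwich.repo import Repo
    from dulwich.server import GitBackend
    from dulwich.web import HTTPGitApplication

    def serve(gitdir, port=8000):
        # Wrap the repository in the WSGI application added by this commit
        # and serve it with the stdlib reference WSGI server.
        app = HTTPGitApplication(GitBackend(Repo(gitdir)))
        make_server('', port, app).serve_forever()

    if __name__ == '__main__':
        serve(sys.argv[1] if len(sys.argv) > 1 else os.getcwd())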

+ 3 - 2
debian/changelog

@@ -1,9 +1,10 @@
-dulwich (0.4.1-2) UNRELEASED; urgency=low
+dulwich (0.5.0-1) unstable; urgency=low
 
+  * New upstream release.
   * Switch to dpkg-source 3.0 (quilt) format.
   * Bump standards version to 3.8.4.
 
- -- Jelmer Vernooij <jelmer@debian.org>  Thu, 04 Feb 2010 12:29:43 +0100
+ -- Jelmer Vernooij <jelmer@debian.org>  Wed, 03 Mar 2010 16:43:41 +0100
 
 dulwich (0.4.1-1) unstable; urgency=low
 

+ 1 - 1
dulwich/__init__.py

@@ -27,4 +27,4 @@ import protocol
 import repo
 import server
 
-__version__ = (0, 4, 1)
+__version__ = (0, 5, 0)

+ 3 - 0
dulwich/_pack.c

@@ -110,17 +110,20 @@ static PyObject *py_apply_delta(PyObject *self, PyObject *args)
             index += cmd;
 		} else {
 			PyErr_SetString(PyExc_ValueError, "Invalid opcode 0");
+			Py_DECREF(ret);
 			return NULL;
 		}
 	}
     
     if (index != delta_len) {
 		PyErr_SetString(PyExc_ValueError, "delta not empty");
+		Py_DECREF(ret);
 		return NULL;
 	}
 
 	if (dest_size != outindex) {
         PyErr_SetString(PyExc_ValueError, "dest size incorrect");
+		Py_DECREF(ret);
 		return NULL;
 	}
 

+ 12 - 0
dulwich/errors.py

@@ -108,3 +108,15 @@ class HangupException(GitProtocolError):
     def __init__(self):
         Exception.__init__(self,
             "The remote server unexpectedly closed the connection.")
+
+
+class FileFormatException(Exception):
+    """Base class for exceptions relating to reading git file formats."""
+
+
+class PackedRefsException(FileFormatException):
+    """Indicates an error parsing a packed-refs file."""
+
+
+class NoIndexPresent(Exception):
+    """No index is present."""

+ 138 - 0
dulwich/file.py

@@ -0,0 +1,138 @@
+# file.py -- Safe access to git files
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) a later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+"""Safe access to git files."""
+
+
+import errno
+import os
+
+def ensure_dir_exists(dirname):
+    """Ensure a directory exists, creating if necessary."""
+    try:
+        os.makedirs(dirname)
+    except OSError, e:
+        if e.errno != errno.EEXIST:
+            raise
+
+def GitFile(filename, mode='r', bufsize=-1):
+    """Create a file object that obeys the git file locking protocol.
+
+    See _GitFile for a description of the file locking protocol.
+
+    Only read-only and write-only (binary) modes are supported; r+, w+, and a
+    are not.  To read and write from the same file, you can take advantage of
+    the fact that opening a file for write does not actually open the file you
+    request:
+
+    >>> write_file = GitFile('filename', 'wb')
+    >>> read_file = GitFile('filename', 'rb')
+    >>> read_file.readlines()
+    ['contents\n', 'of\n', 'the\n', 'file\n']
+    >>> write_file.write('foo')
+    >>> read_file.close()
+    >>> write_file.close()
+    >>> new_file = GitFile('filename', 'rb')
+    >>> new_file.read()
+    'foo'
+    >>> new_file.close()
+    >>> other_file = GitFile('filename', 'wb')
+    Traceback (most recent call last):
+        ...
+    OSError: [Errno 17] File exists: 'filename.lock'
+
+    :return: a builtin file object or a _GitFile object
+    """
+    if 'a' in mode:
+        raise IOError('append mode not supported for Git files')
+    if '+' in mode:
+        raise IOError('read/write mode not supported for Git files')
+    if 'b' not in mode:
+        raise IOError('text mode not supported for Git files')
+    if 'w' in mode:
+        return _GitFile(filename, mode, bufsize)
+    else:
+        return file(filename, mode, bufsize)
+
+
+class _GitFile(object):
+    """File that follows the git locking protocol for writes.
+
+    All writes to a file foo will be written into foo.lock in the same
+    directory, and the lockfile will be renamed to overwrite the original file
+    on close.
+
+    :note: You *must* call close() or abort() on a _GitFile for the lock to be
+        released. Typically this will happen in a finally block.
+    """
+
+    PROXY_PROPERTIES = set(['closed', 'encoding', 'errors', 'mode', 'name',
+                            'newlines', 'softspace'])
+    PROXY_METHODS = ('__iter__', 'flush', 'fileno', 'isatty', 'next', 'read',
+                     'readline', 'readlines', 'xreadlines', 'seek', 'tell',
+                     'truncate', 'write', 'writelines')
+    def __init__(self, filename, mode, bufsize):
+        self._filename = filename
+        self._lockfilename = '%s.lock' % self._filename
+        fd = os.open(self._lockfilename, os.O_RDWR | os.O_CREAT | os.O_EXCL)
+        self._file = os.fdopen(fd, mode, bufsize)
+        self._closed = False
+
+        for method in self.PROXY_METHODS:
+            setattr(self, method, getattr(self._file, method))
+
+    def abort(self):
+        """Close and discard the lockfile without overwriting the target.
+
+        If the file is already closed, this is a no-op.
+        """
+        if self._closed:
+            return
+        self._file.close()
+        try:
+            os.remove(self._lockfilename)
+            self._closed = True
+        except OSError, e:
+            # The file may have been removed already, which is ok.
+            if e.errno != errno.ENOENT:
+                raise
+
+    def close(self):
+        """Close this file, saving the lockfile over the original.
+
+        :note: If this method fails, it will attempt to delete the lockfile.
+            However, it is not guaranteed to do so (e.g. if a filesystem becomes
+            suddenly read-only), which will prevent future writes to this file
+            until the lockfile is removed manually.
+        :raises OSError: if the original file could not be overwritten. The lock
+            file is still closed, so further attempts to write to the same file
+            object will raise ValueError.
+        """
+        if self._closed:
+            return
+        self._file.close()
+        try:
+            os.rename(self._lockfilename, self._filename)
+        finally:
+            self.abort()
+
+    def __getattr__(self, name):
+        """Proxy property calls to the underlying file."""
+        if name in self.PROXY_PROPERTIES:
+            return getattr(self._file, name)
+        raise AttributeError(name)
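
A minimal usage sketch of the locking protocol implemented by GitFile above; the filename and contents are only examples. Writes land in "<name>.lock" and replace the original only on close(), while abort() drops the lock and leaves the original untouched:

    from dulwich.file import GitFile

    f = GitFile('description', 'wb')       # creates description.lock (O_CREAT | O_EXCL)
    try:
        f.write('Unnamed repository\n')    # written to description.lock, not description
    except Exception:
        f.abort()                          # remove description.lock, original unchanged
        raise
    else:
        f.close()                          # rename description.lock over description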

+ 27 - 5
dulwich/index.py

@@ -22,6 +22,7 @@ import os
 import stat
 import struct
 
+from dulwich.file import GitFile
 from dulwich.objects import (
     S_IFGITLINK,
     S_ISGITLINK,
@@ -35,6 +36,27 @@ from dulwich.pack import (
     )
 
 
+def pathsplit(path):
+    """Split a /-delimited path into a directory part and a basename.
+
+    :param path: The path to split.
+    :return: Tuple with directory name and basename
+    """
+    try:
+        (dirname, basename) = path.rsplit("/", 1)
+    except ValueError:
+        return ("", path)
+    else:
+        return (dirname, basename)
+
+
+def pathjoin(*args):
+    """Join a /-delimited path.
+
+    """
+    return "/".join([p for p in args if p])
+
+
 def read_cache_time(f):
     """Read a cache time.
     
@@ -173,7 +195,7 @@ class Index(object):
 
     def write(self):
         """Write current contents of index to disk."""
-        f = open(self._filename, 'wb')
+        f = GitFile(self._filename, 'wb')
         try:
             f = SHA1Writer(f)
             write_index_dict(f, self._byname)
@@ -182,7 +204,7 @@
 
     def read(self):
         """Read current contents of index from disk."""
-        f = open(self._filename, 'rb')
+        f = GitFile(self._filename, 'rb')
         try:
             f = SHA1Reader(f)
             for x in read_index(f):
@@ -273,7 +295,7 @@ def commit_tree(object_store, blobs):
     def add_tree(path):
         if path in trees:
             return trees[path]
-        dirname, basename = os.path.split(path)
+        dirname, basename = pathsplit(path)
         t = add_tree(dirname)
         assert isinstance(basename, str)
         newtree = {}
@@ -282,7 +304,7 @@ def commit_tree(object_store, blobs):
         return newtree
 
     for path, sha, mode in blobs:
-        tree_path, basename = os.path.split(path)
+        tree_path, basename = pathsplit(path)
         tree = add_tree(tree_path)
         tree[basename] = (mode, sha)
 
@@ -291,7 +313,7 @@ def commit_tree(object_store, blobs):
         for basename, entry in trees[path].iteritems():
             if type(entry) == dict:
                 mode = stat.S_IFDIR
-                sha = build_tree(os.path.join(path, basename))
+                sha = build_tree(pathjoin(path, basename))
             else:
                 (mode, sha) = entry
             tree.add(mode, basename, sha)
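
A short sketch of the pathsplit/pathjoin helpers introduced above; unlike os.path.split and os.path.join they always use "/", so index paths behave the same on Windows (#526793):

    from dulwich.index import pathsplit, pathjoin

    assert pathsplit("foo/bar/baz") == ("foo/bar", "baz")   # split on the last "/"
    assert pathsplit("toplevel") == ("", "toplevel")        # no "/" -> empty dirname
    assert pathjoin("foo", "", "bar") == "foo/bar"          # empty components dropped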

+ 133 - 81
dulwich/object_store.py

@@ -20,6 +20,7 @@
 """Git object store interfaces and implementation."""
 """Git object store interfaces and implementation."""
 
 
 
 
+import errno
 import itertools
 import itertools
 import os
 import os
 import stat
 import stat
@@ -29,6 +30,7 @@ import urllib2
 from dulwich.errors import (
 from dulwich.errors import (
     NotTreeError,
     NotTreeError,
     )
     )
+from dulwich.file import GitFile
 from dulwich.objects import (
 from dulwich.objects import (
     Commit,
     Commit,
     ShaFile,
     ShaFile,
@@ -65,9 +67,25 @@ class BaseObjectStore(object):
         """
         """
         return ObjectStoreIterator(self, shas)
         return ObjectStoreIterator(self, shas)
 
 
+    def contains_loose(self, sha):
+        """Check if a particular object is present by SHA1 and is loose."""
+        raise NotImplementedError(self.contains_loose)
+
+    def contains_packed(self, sha):
+        """Check if a particular object is present by SHA1 and is packed."""
+        raise NotImplementedError(self.contains_packed)
+
     def __contains__(self, sha):
     def __contains__(self, sha):
-        """Check if a particular object is present by SHA1."""
-        raise NotImplementedError(self.__contains__)
+        """Check if a particular object is present by SHA1.
+
+        This method makes no distinction between loose and packed objects.
+        """
+        return self.contains_packed(sha) or self.contains_loose(sha)
+
+    @property
+    def packs(self):
+        """Iterable of pack objects."""
+        raise NotImplementedError
 
 
     def get_raw(self, name):
     def get_raw(self, name):
         """Obtain the raw text for an object.
         """Obtain the raw text for an object.
@@ -220,85 +238,49 @@ class BaseObjectStore(object):
         return self.iter_shas(self.find_missing_objects(have, want))
         return self.iter_shas(self.find_missing_objects(have, want))
 
 
 
 
-class DiskObjectStore(BaseObjectStore):
-    """Git-style object store that exists on disk."""
-
-    def __init__(self, path):
-        """Open an object store.
+class PackBasedObjectStore(BaseObjectStore):
 
 
-        :param path: Path of the object store.
-        """
-        self.path = path
+    def __init__(self):
         self._pack_cache = None
         self._pack_cache = None
-        self.pack_dir = os.path.join(self.path, PACKDIR)
 
 
-    def __contains__(self, sha):
-        """Check if a particular object is present by SHA1."""
+    def contains_packed(self, sha):
+        """Check if a particular object is present by SHA1 and is packed."""
         for pack in self.packs:
         for pack in self.packs:
             if sha in pack:
             if sha in pack:
                 return True
                 return True
-        ret = self._get_shafile(sha)
-        if ret is not None:
-            return True
         return False
         return False
 
 
-    def __iter__(self):
-        """Iterate over the SHAs that are present in this store."""
-        iterables = self.packs + [self._iter_shafile_shas()]
-        return itertools.chain(*iterables)
-
-    @property
-    def packs(self):
-        """List with pack objects."""
-        if self._pack_cache is None:
-            self._pack_cache = list(self._load_packs())
-        return self._pack_cache
-
     def _load_packs(self):
     def _load_packs(self):
-        if not os.path.exists(self.pack_dir):
-            return
-        for name in os.listdir(self.pack_dir):
-            if name.startswith("pack-") and name.endswith(".pack"):
-                yield Pack(os.path.join(self.pack_dir, name[:-len(".pack")]))
+        raise NotImplementedError(self._load_packs)
 
 
-    def _add_known_pack(self, path):
+    def _add_known_pack(self, pack):
         """Add a newly appeared pack to the cache by path.
         """Add a newly appeared pack to the cache by path.
 
 
         """
         """
         if self._pack_cache is not None:
         if self._pack_cache is not None:
-            self._pack_cache.append(Pack(path))
+            self._pack_cache.append(pack)
 
 
-    def _get_shafile_path(self, sha):
-        dir = sha[:2]
-        file = sha[2:]
-        # Check from object dir
-        return os.path.join(self.path, dir, file)
+    @property
+    def packs(self):
+        """List with pack objects."""
+        if self._pack_cache is None:
+            self._pack_cache = self._load_packs()
+        return self._pack_cache
 
 
-    def _iter_shafile_shas(self):
-        for base in os.listdir(self.path):
-            if len(base) != 2:
-                continue
-            for rest in os.listdir(os.path.join(self.path, base)):
-                yield base+rest
+    def _iter_loose_objects(self):
+        raise NotImplementedError(self._iter_loose_objects)
 
 
-    def _get_shafile(self, sha):
-        path = self._get_shafile_path(sha)
-        if os.path.exists(path):
-          return ShaFile.from_file(path)
-        return None
+    def _get_loose_object(self, sha):
+        raise NotImplementedError(self._get_loose_object)
 
 
-    def _add_shafile(self, sha, o):
-        dir = os.path.join(self.path, sha[:2])
-        if not os.path.isdir(dir):
-            os.mkdir(dir)
-        path = os.path.join(dir, sha[2:])
-        if os.path.exists(path):
-            return # Already there, no need to write again
-        f = open(path, 'w+')
-        try:
-            f.write(o.as_legacy_object())
-        finally:
-            f.close()
+    def __iter__(self):
+        """Iterate over the SHAs that are present in this store."""
+        iterables = self.packs + [self._iter_loose_objects()]
+        return itertools.chain(*iterables)
+
+    def contains_loose(self, sha):
+        """Check if a particular object is present by SHA1 and is loose."""
+        return self._get_loose_object(sha) is not None
 
 
     def get_raw(self, name):
     def get_raw(self, name):
         """Obtain the raw text for an object.
         """Obtain the raw text for an object.
@@ -321,11 +303,74 @@ class DiskObjectStore(BaseObjectStore):
                 pass
                 pass
         if hexsha is None: 
         if hexsha is None: 
             hexsha = sha_to_hex(name)
             hexsha = sha_to_hex(name)
-        ret = self._get_shafile(hexsha)
+        ret = self._get_loose_object(hexsha)
         if ret is not None:
         if ret is not None:
             return ret.type, ret.as_raw_string()
             return ret.type, ret.as_raw_string()
         raise KeyError(hexsha)
         raise KeyError(hexsha)
 
 
+    def add_objects(self, objects):
+        """Add a set of objects to this object store.
+
+        :param objects: Iterable over objects, should support __len__.
+        """
+        if len(objects) == 0:
+            # Don't bother writing an empty pack file
+            return
+        f, commit = self.add_pack()
+        write_pack_data(f, objects, len(objects))
+        commit()
+
+
+class DiskObjectStore(PackBasedObjectStore):
+    """Git-style object store that exists on disk."""
+
+    def __init__(self, path):
+        """Open an object store.
+
+        :param path: Path of the object store.
+        """
+        super(DiskObjectStore, self).__init__()
+        self.path = path
+        self.pack_dir = os.path.join(self.path, PACKDIR)
+
+    def _load_packs(self):
+        pack_files = []
+        try:
+            for name in os.listdir(self.pack_dir):
+                # TODO: verify that idx exists first
+                if name.startswith("pack-") and name.endswith(".pack"):
+                    filename = os.path.join(self.pack_dir, name)
+                    pack_files.append((os.stat(filename).st_mtime, filename))
+        except OSError, e:
+            if e.errno == errno.ENOENT:
+                return []
+            raise
+        pack_files.sort(reverse=True)
+        suffix_len = len(".pack")
+        return [Pack(f[:-suffix_len]) for _, f in pack_files]
+
+    def _get_shafile_path(self, sha):
+        dir = sha[:2]
+        file = sha[2:]
+        # Check from object dir
+        return os.path.join(self.path, dir, file)
+
+    def _iter_loose_objects(self):
+        for base in os.listdir(self.path):
+            if len(base) != 2:
+                continue
+            for rest in os.listdir(os.path.join(self.path, base)):
+                yield base+rest
+
+    def _get_loose_object(self, sha):
+        path = self._get_shafile_path(sha)
+        try:
+            return ShaFile.from_file(path)
+        except OSError, e:
+            if e.errno == errno.ENOENT:
+                return None
+            raise
+
     def move_in_thin_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
 
@@ -351,7 +396,7 @@ class DiskObjectStore(BaseObjectStore):
         newbasename = os.path.join(self.pack_dir, "pack-%s" % pack_sha)
         os.rename(temppath+".pack", newbasename+".pack")
         os.rename(temppath+".idx", newbasename+".idx")
-        self._add_known_pack(newbasename)
+        self._add_known_pack(Pack(newbasename))
 
     def move_in_pack(self, path):
         """Move a specific file containing a pack into the pack directory.
@@ -368,7 +413,7 @@ class DiskObjectStore(BaseObjectStore):
         write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum())
         p.close()
         os.rename(path, basename + ".pack")
-        self._add_known_pack(basename)
+        self._add_known_pack(Pack(basename))
 
     def add_thin_pack(self):
         """Add a new thin pack to this object store.
@@ -405,19 +450,17 @@
 
         :param obj: Object to add
         """
-        self._add_shafile(obj.id, obj)
-
-    def add_objects(self, objects):
-        """Add a set of objects to this object store.
-
-        :param objects: Iterable over objects, should support __len__.
-        """
-        if len(objects) == 0:
-            # Don't bother writing an empty pack file
-            return
-        f, commit = self.add_pack()
-        write_pack_data(f, objects, len(objects))
-        commit()
+        dir = os.path.join(self.path, obj.id[:2])
+        if not os.path.isdir(dir):
+            os.mkdir(dir)
+        path = os.path.join(dir, obj.id[2:])
+        if os.path.exists(path):
+            return # Already there, no need to write again
+        f = GitFile(path, 'wb')
+        try:
+            f.write(obj.as_legacy_object())
+        finally:
+            f.close()
 
 
 class MemoryObjectStore(BaseObjectStore):
@@ -427,14 +470,23 @@ class MemoryObjectStore(BaseObjectStore):
         super(MemoryObjectStore, self).__init__()
         self._data = {}
 
-    def __contains__(self, sha):
-        """Check if the object with a particular SHA is present."""
+    def contains_loose(self, sha):
+        """Check if a particular object is present by SHA1 and is loose."""
         return sha in self._data
 
+    def contains_packed(self, sha):
+        """Check if a particular object is present by SHA1 and is packed."""
+        return False
+
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
         return self._data.iterkeys()
 
+    @property
+    def packs(self):
+        """List with pack objects."""
+        return []
+
     def get_raw(self, name):
         """Obtain the raw text for an object.
         

+ 23 - 10
dulwich/objects.py

@@ -36,6 +36,7 @@ from dulwich.errors import (
     NotCommitError,
     NotTreeError,
     )
+from dulwich.file import GitFile
 from dulwich.misc import (
     make_sha,
     )
@@ -184,7 +185,7 @@ class ShaFile(object):
     def from_file(cls, filename):
         """Get the contents of a SHA file on disk"""
         size = os.path.getsize(filename)
-        f = open(filename, 'rb')
+        f = GitFile(filename, 'rb')
         try:
             map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
             shafile = cls._parse_file(map)
@@ -205,6 +206,13 @@ class ShaFile(object):
         obj.set_raw_string(string)
         return obj
 
+    @classmethod
+    def from_string(cls, string):
+        """Create a blob from a string."""
+        shafile = cls()
+        shafile.set_raw_string(string)
+        return shafile
+
     def _header(self):
         return "%s %lu\0" % (self._type, len(self.as_raw_string()))
 
@@ -267,13 +275,6 @@ class Blob(ShaFile):
             raise NotBlobError(filename)
         return blob
 
-    @classmethod
-    def from_string(cls, string):
-        """Create a blob from a string."""
-        shafile = cls()
-        shafile.set_raw_string(string)
-        return shafile
-
 
 class Tag(ShaFile):
     """A Git Tag object."""
@@ -512,6 +513,7 @@ class Commit(ShaFile):
         self._encoding = None
         self._needs_parsing = False
         self._needs_serialization = True
+        self._extra = {}
 
     @classmethod
     def from_file(cls, filename):
@@ -522,6 +524,7 @@ class Commit(ShaFile):
 
     def _parse_text(self):
         self._parents = []
+        self._extra = []
         self._author = None
         f = StringIO(self._text)
         for l in f:
@@ -545,7 +548,7 @@ class Commit(ShaFile):
             elif field == ENCODING_ID:
                 self._encoding = value
             else:
-                raise AssertionError("Unknown field %s" % field)
+                self._extra.append((field, value))
         self._message = f.read()
         self._needs_parsing = False
 
@@ -558,6 +561,10 @@ class Commit(ShaFile):
         f.write("%s %s %s %s\n" % (COMMITTER_ID, self._committer, str(self._commit_time), format_timezone(self._commit_timezone)))
         if self.encoding:
             f.write("%s %s\n" % (ENCODING_ID, self.encoding))
+        for k, v in self.extra:
+            if "\n" in k or "\n" in v:
+                raise AssertionError("newline in extra data: %r -> %r" % (k, v))
+            f.write("%s %s\n" % (k, v))
         f.write("\n") # There must be a new line after the headers
         f.write(self._message)
         self._text = f.getvalue()
@@ -578,6 +585,13 @@ class Commit(ShaFile):
 
     parents = property(get_parents, set_parents)
 
+    def get_extra(self):
+        """Return extra settings of this commit."""
+        self._ensure_parsed()
+        return self._extra
+
+    extra = property(get_extra)
+
     author = serializable_property("author",
         "The name of the author of the commit")
 
@@ -624,4 +638,3 @@ try:
     from dulwich._objects import parse_tree
 except ImportError:
     pass
-
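
A sketch of the new Commit.extra handling above: unrecognized header fields are now collected as (field, value) pairs instead of raising AssertionError, and are written back out on serialization. The commit text and the unknown "extra-field" header below are fabricated for illustration, and the exact representation of the stored value is an assumption:

    from dulwich.objects import Commit

    text = ("tree d80c186a03f423a81b39df39dc87fd269736ca86\n"
            "author A U Thor <author@example.com> 1267473891 +0100\n"
            "committer A U Thor <author@example.com> 1267473891 +0100\n"
            "extra-field some opaque data\n"        # unknown header, kept as-is
            "\n"
            "Example commit message\n")
    c = Commit.from_string(text)
    # Expect something like [('extra-field', 'some opaque data')]
    print c.extra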

+ 104 - 88
dulwich/pack.py

@@ -55,6 +55,7 @@ from dulwich.errors import (
     ApplyDeltaError,
     ChecksumMismatch,
     )
+from dulwich.file import GitFile
 from dulwich.lru_cache import (
     LRUSizeCache,
     )
@@ -71,53 +72,50 @@ supports_mmap_offset = (sys.version_info[0] >= 3 or
         (sys.version_info[0] == 2 and sys.version_info[1] >= 6))
 
 
-def take_msb_bytes(map, offset):
+def take_msb_bytes(read):
     """Read bytes marked with most significant bit.
     
-    :param map: The buffer.
-    :param offset: Offset in the buffer at which to start reading.
+    :param read: Read function
     """
     ret = []
     while len(ret) == 0 or ret[-1] & 0x80:
-        ret.append(ord(map[offset]))
-        offset += 1
+        ret.append(ord(read(1)))
     return ret
 
 
-def read_zlib_chunks(data, offset):
+def read_zlib_chunks(read, buffer_size=4096):
     """Read chunks of zlib data from a buffer.
     
-    :param data: Buffer to read from
-    :param offset: Offset at which to start reading
-    :return: Tuple with list of chunks and length of 
-        compressed data length
+    :param read: Read function
+    :return: Tuple with list of chunks, length of 
+        compressed data length and unused read data
     """
     obj = zlib.decompressobj()
     ret = []
     fed = 0
     while obj.unused_data == "":
-        base = offset+fed
-        add = data[base:base+1024]
-        if len(add) < 1024:
+        add = read(buffer_size)
+        if len(add) < buffer_size:
             add += "Z"
         fed += len(add)
         ret.append(obj.decompress(add))
     comp_len = fed-len(obj.unused_data)
-    return ret, comp_len
+    return ret, comp_len, obj.unused_data
 
 
-def read_zlib(data, offset, dec_size):
+def read_zlib(read, dec_size):
     """Read zlib-compressed data from a buffer.
     
-    :param data: Buffer
-    :param offset: Offset in the buffer at which to read
+    :param read: Read function
     :param dec_size: Size of the decompressed buffer
-    :return: Uncompressed buffer and compressed buffer length.
+    :return: Uncompressed buffer, compressed buffer length and unused read
+        data.
     """
-    ret, comp_len = read_zlib_chunks(data, offset)
+    ret, comp_len, unused = read_zlib_chunks(read)
     x = "".join(ret)
     assert len(x) == dec_size
-    return x, comp_len
+    return x, comp_len, unused
+

 
 def iter_sha1(iter):
@@ -132,35 +130,31 @@ def iter_sha1(iter):
     return sha1.hexdigest()
 
 
-def simple_mmap(f, offset, size, access=mmap.ACCESS_READ):
-    """Simple wrapper for mmap() which always supports the offset parameter.
+def load_pack_index(path):
+    """Load an index file by path.
 
-    :param f: File object.
-    :param offset: Offset in the file, from the beginning of the file.
-    :param size: Size of the mmap'ed area
-    :param access: Access mechanism.
-    :return: MMAP'd area.
+    :param filename: Path to the index file
     """
-    mem = mmap.mmap(f.fileno(), size+offset, access=access)
-    return mem, offset
+    f = GitFile(path, 'rb')
+    return load_pack_index_file(path, f)
 
 
-def load_pack_index(filename):
-    """Load an index file by path.
+def load_pack_index_file(path, f):
+    """Load an index file from a file-like object.
 
-    :param filename: Path to the index file
+    :param path: Path for the index file
+    :param f: File-like object
     """
-    f = open(filename, 'rb')
     if f.read(4) == '\377tOc':
         version = struct.unpack(">L", f.read(4))[0]
         if version == 2:
             f.seek(0)
-            return PackIndex2(filename, file=f)
+            return PackIndex2(path, file=f)
         else:
             raise KeyError("Unknown pack index format %d" % version)
     else:
         f.seek(0)
-        return PackIndex1(filename, file=f)
+        return PackIndex1(path, file=f)
 
 
 def bisect_find_sha(start, end, sha, unpack_name):
@@ -200,7 +194,7 @@ class PackIndex(object):
     the start and end offset and then bisect in to find if the value is present.
     """
   
-    def __init__(self, filename, file=None):
+    def __init__(self, filename, file=None, size=None):
         """Create a pack index object.
     
         Provide it with the name of the index file to consider, and it will map
@@ -209,13 +203,23 @@
         self._filename = filename
         # Take the size now, so it can be checked each time we map the file to
         # ensure that it hasn't changed.
-        self._size = os.path.getsize(filename)
         if file is None:
-            self._file = open(filename, 'rb')
+            self._file = GitFile(filename, 'rb')
         else:
             self._file = file
-        self._contents, map_offset = simple_mmap(self._file, 0, self._size)
-        assert map_offset == 0
+        fileno = getattr(self._file, 'fileno', None)
+        if fileno is not None:
+            fd = self._file.fileno()
+            if size is None:
+                self._size = os.fstat(fd).st_size
+            else:
+                self._size = size
+            self._contents = mmap.mmap(fd, self._size,
+                access=mmap.ACCESS_READ)
+        else:
+            self._file.seek(0)
+            self._contents = self._file.read()
+            self._size = len(self._contents)
   
     def __eq__(self, other):
         if not isinstance(other, PackIndex):
@@ -346,8 +350,8 @@ class PackIndex(object):
 class PackIndex1(PackIndex):
     """Version 1 Pack Index."""
 
-    def __init__(self, filename, file=None):
-        PackIndex.__init__(self, filename, file)
+    def __init__(self, filename, file=None, size=None):
+        PackIndex.__init__(self, filename, file, size)
         self.version = 1
         self._fan_out_table = self._read_fan_out_table(0)
 
@@ -372,8 +376,8 @@ class PackIndex1(PackIndex):
 class PackIndex2(PackIndex):
     """Version 2 Pack Index."""
 
-    def __init__(self, filename, file=None):
-        PackIndex.__init__(self, filename, file)
+    def __init__(self, filename, file=None, size=None):
+        PackIndex.__init__(self, filename, file, size)
         assert self._contents[:4] == '\377tOc', "Not a v2 pack index file"
         (self.version, ) = unpack_from(">L", self._contents, 4)
         assert self.version == 2, "Version was %d" % self.version
@@ -413,38 +417,40 @@ def read_pack_header(f):
     return (version, num_objects)
 
 
-def unpack_object(map, offset=0):
+def unpack_object(read):
     """Unpack a Git object.
 
-    :return: tuple with type, uncompressed data and compressed size
+    :return: tuple with type, uncompressed data, compressed size and 
+        tail data
     """
-    bytes = take_msb_bytes(map, offset)
+    bytes = take_msb_bytes(read)
     type = (bytes[0] >> 4) & 0x07
     size = bytes[0] & 0x0f
     for i, byte in enumerate(bytes[1:]):
         size += (byte & 0x7f) << ((i * 7) + 4)
     raw_base = len(bytes)
     if type == 6: # offset delta
-        bytes = take_msb_bytes(map, raw_base + offset)
+        bytes = take_msb_bytes(read)
+        raw_base += len(bytes)
         assert not (bytes[-1] & 0x80)
         delta_base_offset = bytes[0] & 0x7f
         for byte in bytes[1:]:
             delta_base_offset += 1
             delta_base_offset <<= 7
             delta_base_offset += (byte & 0x7f)
-        raw_base+=len(bytes)
-        uncomp, comp_len = read_zlib(map, offset + raw_base, size)
+        uncomp, comp_len, unused = read_zlib(read, size)
         assert size == len(uncomp)
-        return type, (delta_base_offset, uncomp), comp_len+raw_base
+        return type, (delta_base_offset, uncomp), comp_len+raw_base, unused
     elif type == 7: # ref delta
-        basename = map[offset+raw_base:offset+raw_base+20]
-        uncomp, comp_len = read_zlib(map, offset+raw_base+20, size)
+        basename = read(20)
+        raw_base += 20
+        uncomp, comp_len, unused = read_zlib(read, size)
         assert size == len(uncomp)
-        return type, (basename, uncomp), comp_len+raw_base+20
+        return type, (basename, uncomp), comp_len+raw_base, unused
     else:
-        uncomp, comp_len = read_zlib(map, offset+raw_base, size)
+        uncomp, comp_len, unused = read_zlib(read, size)
         assert len(uncomp) == size
-        return type, uncomp, comp_len+raw_base
+        return type, uncomp, comp_len+raw_base, unused
 
 
 def _compute_object_size((num, obj)):
@@ -483,7 +489,7 @@ class PackData(object):
     It will all just throw a zlib or KeyError.
     """
  
-    def __init__(self, filename):
+    def __init__(self, filename, file=None, size=None):
         """Create a PackData object that represents the pack in the given filename.
     
         The file must exist and stay readable until the object is disposed of. It
@@ -493,22 +499,33 @@ class PackData(object):
         mmap implementation is flawed.
         """
         self._filename = filename
-        assert os.path.exists(filename), "%s is not a packfile" % filename
-        self._size = os.path.getsize(filename)
+        self._size = size
         self._header_size = 12
-        assert self._size >= self._header_size, "%s is too small for a packfile (%d < %d)" % (filename, self._size, self._header_size)
-        self._file = open(self._filename, 'rb')
-        self._read_header()
+        if file is None:
+            self._file = GitFile(self._filename, 'rb')
+        else:
+            self._file = file
+        (version, self._num_objects) = read_pack_header(self._file)
         self._offset_cache = LRUSizeCache(1024*1024*20, 
             compute_size=_compute_object_size)
 
+    @classmethod
+    def from_file(cls, file, size):
+        return cls(str(file), file=file, size=size)
+
+    @classmethod
+    def from_path(cls, path):
+        return cls(filename=path)
+
     def close(self):
         self._file.close()
-  
-    def _read_header(self):
-        (version, self._num_objects) = read_pack_header(self._file)
-        self._file.seek(self._size-20)
-        self._stored_checksum = self._file.read(20)
+
+    def _get_size(self):
+        if self._size is not None:
+            return self._size
+        self._size = os.path.getsize(self._filename)
+        assert self._size >= self._header_size, "%s is too small for a packfile (%d < %d)" % (self._filename, self._size, self._header_size)
+        return self._size
  
     def __len__(self):
         """Returns the number of objects in this pack."""
@@ -519,11 +536,14 @@ class PackData(object):
 
         :return: 20-byte binary SHA1 digest
         """
-        map, map_offset = simple_mmap(self._file, 0, self._size - 20)
-        try:
-            return make_sha(map[map_offset:self._size-20]).digest()
-        finally:
-            map.close()
+        s = make_sha()
+        self._file.seek(0)
+        todo = self._get_size() - 20
+        while todo > 0:
+            x = self._file.read(min(todo, 1<<16))
+            s.update(x)
+            todo -= len(x)
+        return s.digest()
 
     def resolve_object(self, offset, type, obj, get_ref, get_offset=None):
         """Resolve an object, possibly resolving deltas when necessary.
@@ -566,10 +586,7 @@ class PackData(object):
                 self.i = 0
                 self.offset = pack._header_size
                 self.num = len(pack)
-                self.map, _ = simple_mmap(pack._file, 0, pack._size)
-
-            def __del__(self):
-                self.map.close()
+                self.map = pack._file
 
             def __iter__(self):
                 return self
@@ -580,8 +597,10 @@
             def next(self):
                 if self.i == self.num:
                     raise StopIteration
-                (type, obj, total_size) = unpack_object(self.map, self.offset)
-                crc32 = zlib.crc32(self.map[self.offset:self.offset+total_size]) & 0xffffffff
+                self.map.seek(self.offset)
+                (type, obj, total_size, unused) = unpack_object(self.map.read)
+                self.map.seek(self.offset)
+                crc32 = zlib.crc32(self.map.read(total_size)) & 0xffffffff
                 ret = (self.offset, type, obj, crc32)
                 self.offset += total_size
                 if progress:
@@ -687,7 +706,8 @@ class PackData(object):
  
     def get_stored_checksum(self):
         """Return the expected checksum stored in this pack."""
-        return self._stored_checksum
+        self._file.seek(self._get_size()-20)
+        return self._file.read(20)
  
     def check(self):
         """Check the consistency of this pack."""
@@ -705,12 +725,8 @@ class PackData(object):
         assert isinstance(offset, long) or isinstance(offset, int),\
                 "offset was %r" % offset
         assert offset >= self._header_size
-        map, map_offset = simple_mmap(self._file, offset, self._size-offset)
-        try:
-            ret = unpack_object(map, map_offset)[:2]
-            return ret
-        finally:
-            map.close()
+        self._file.seek(offset)
+        return unpack_object(self._file.read)[:2]
 
 
 class SHA1Reader(object):
@@ -809,7 +825,7 @@ def write_pack(filename, objects, num_objects):
     :param objects: Iterable over (object, path) tuples to write
     :param num_objects: Number of objects to write
     """
-    f = open(filename + ".pack", 'wb')
+    f = GitFile(filename + ".pack", 'wb')
     try:
         entries, data_sum = write_pack_data(f, objects, num_objects)
     finally:
@@ -873,7 +889,7 @@ def write_pack_index_v1(filename, entries, pack_checksum):
             crc32_checksum.
     :param pack_checksum: Checksum of the pack file.
    """
-    f = open(filename, 'wb')
+    f = GitFile(filename, 'wb')
     f = SHA1Writer(f)
     fan_out_table = defaultdict(lambda: 0)
     for (name, offset, entry_checksum) in entries:
@@ -1021,7 +1037,7 @@ def write_pack_index_v2(filename, entries, pack_checksum):
             crc32_checksum.
    :param pack_checksum: Checksum of the pack file.
     """
-    f = open(filename, 'wb')
+    f = GitFile(filename, 'wb')
     f = SHA1Writer(f)
     f.write('\377tOc') # Magic!
     f.write(struct.pack(">L", 2))
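
A hedged sketch of the reworked PackData entry points above: pack data can now be built from a path or from an arbitrary file-like object, and the stored checksum is read with seek()/read() instead of mmap'ing the whole file. The pack path below is a placeholder:

    import os
    from dulwich.pack import PackData

    path = '/path/to/pack-1234567890abcdef1234567890abcdef12345678.pack'  # placeholder
    p = PackData.from_path(path)
    print len(p)                        # object count, taken from the pack header
    print repr(p.get_stored_checksum()) # last 20 bytes of the file, read via seek()

    f = open(path, 'rb')
    p2 = PackData.from_file(f, os.path.getsize(path))   # size passed explicitly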

+ 33 - 5
dulwich/protocol.py

@@ -28,6 +28,10 @@ from dulwich.errors import (
 
 TCP_GIT_PORT = 9418
 
+SINGLE_ACK = 0
+MULTI_ACK = 1
+MULTI_ACK_DETAILED = 2
+
 class ProtocolFile(object):
     """
     Some network ops are like file ops. The file ops expect to operate on
 
@@ -160,11 +164,35 @@ def extract_capabilities(text):
     """Extract a capabilities list from a string, if present.
 
     :param text: String to extract from
-    :return: Tuple with text with capabilities removed and list of 
-        capabilities or None (if no capabilities were present.
+    :return: Tuple with text with capabilities removed and list of capabilities
     """
     if not "\0" in text:
-        return text, None
-    capabilities = text.split("\0")
-    return (capabilities[0], capabilities[1:])
+        return text, []
+    text, capabilities = text.rstrip().split("\0")
+    return (text, capabilities.split(" "))
+
+
+def extract_want_line_capabilities(text):
+    """Extract a capabilities list from a want line, if present.
 
+    Note that want lines have capabilities separated from the rest of the line
+    by a space instead of a null byte. Thus want lines have the form:
+
+        want obj-id cap1 cap2 ...
+
+    :param text: Want line to extract from
+    :return: Tuple with text with capabilities removed and list of capabilities
+    """
+    split_text = text.rstrip().split(" ")
+    if len(split_text) < 3:
+        return text, []
+    return (" ".join(split_text[:2]), split_text[2:])
+
+
+def ack_type(capabilities):
+    """Extract the ack type from a capabilities list."""
+    if 'multi_ack_detailed' in capabilities:
+        return MULTI_ACK_DETAILED
+    elif 'multi_ack' in capabilities:
+        return MULTI_ACK
+    return SINGLE_ACK
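Taken together, these helpers split the capability list off a ref advertisement or want line and pick which ack protocol to speak. A rough usage sketch against the functions as added here (the SHA1 is just the fixture commit reused as an example value):

    from dulwich.protocol import (
        extract_capabilities,
        extract_want_line_capabilities,
        ack_type,
        MULTI_ACK_DETAILED,
        )

    # Ref advertisement lines carry capabilities after a NUL byte.
    refline, caps = extract_capabilities(
        "42d06bd4b77fed026b154d16493e5deab78f02ec refs/heads/master\x00multi_ack thin-pack\n")
    # caps == ["multi_ack", "thin-pack"]

    # Want lines use a plain space instead.
    wantline, caps = extract_want_line_capabilities(
        "want 42d06bd4b77fed026b154d16493e5deab78f02ec multi_ack_detailed ofs-delta\n")
    assert ack_type(caps) == MULTI_ACK_DETAILED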

+ 565 - 166
dulwich/repo.py

@@ -22,15 +22,21 @@
 """Repository access."""
 """Repository access."""
 
 
 
 
+import errno
 import os
 import os
-import stat
 
 
 from dulwich.errors import (
 from dulwich.errors import (
     MissingCommitError, 
     MissingCommitError, 
+    NoIndexPresent,
     NotBlobError, 
     NotBlobError, 
     NotCommitError, 
     NotCommitError, 
     NotGitRepository,
     NotGitRepository,
     NotTreeError, 
     NotTreeError, 
+    PackedRefsException,
+    )
+from dulwich.file import (
+    ensure_dir_exists,
+    GitFile,
     )
     )
 from dulwich.object_store import (
 from dulwich.object_store import (
     DiskObjectStore,
     DiskObjectStore,
@@ -41,6 +47,7 @@ from dulwich.objects import (
     ShaFile,
     ShaFile,
     Tag,
     Tag,
     Tree,
     Tree,
+    hex_to_sha,
     )
     )
 
 
 OBJECTDIR = 'objects'
 OBJECTDIR = 'objects'
@@ -50,104 +57,228 @@ REFSDIR_TAGS = 'tags'
 REFSDIR_HEADS = 'heads'
 REFSDIR_HEADS = 'heads'
 INDEX_FILENAME = "index"
 INDEX_FILENAME = "index"
 
 
+BASE_DIRECTORIES = [
+    [OBJECTDIR], 
+    [OBJECTDIR, "info"], 
+    [OBJECTDIR, "pack"],
+    ["branches"],
+    [REFSDIR],
+    [REFSDIR, REFSDIR_TAGS],
+    [REFSDIR, REFSDIR_HEADS],
+    ["hooks"],
+    ["info"]
+    ]
+
+
+def read_info_refs(f):
+    ret = {}
+    for l in f.readlines():
+        (sha, name) = l.rstrip("\n").split("\t", 1)
+        ret[name] = sha
+    return ret
+
+
+def check_ref_format(refname):
+    """Check if a refname is correctly formatted.
 
 
-def follow_ref(container, name):
-    """Follow a ref back to a SHA1.
-    
-    :param container: Ref container to use for looking up refs.
-    :param name: Name of the original ref.
+    Implements all the same rules as git-check-ref-format[1].
+
+    [1] http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html
+
+    :param refname: The refname to check
+    :return: True if refname is valid, False otherwise
     """
     """
-    contents = container[name]
-    if contents.startswith(SYMREF):
-        ref = contents[len(SYMREF):]
-        if ref[-1] == '\n':
-            ref = ref[:-1]
-        return follow_ref(container, ref)
-    assert len(contents) == 40, 'Invalid ref in %s' % name
-    return contents
+    # These could be combined into one big expression, but are listed separately
+    # to parallel [1].
+    if '/.' in refname or refname.startswith('.'):
+        return False
+    if '/' not in refname:
+        return False
+    if '..' in refname:
+        return False
+    for c in refname:
+        if ord(c) < 040 or c in '\177 ~^:?*[':
+            return False
+    if refname[-1] in '/.':
+        return False
+    if refname.endswith('.lock'):
+        return False
+    if '@{' in refname:
+        return False
+    if '\\' in refname:
+        return False
+    return True
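A few concrete inputs and what these rules make of them (callers strip the leading 'refs/' before calling, as _check_refname below does):

    from dulwich.repo import check_ref_format

    check_ref_format('heads/master')     # True
    check_ref_format('tags/refs-0.1')    # True
    check_ref_format('HEAD')             # False: no '/' component
    check_ref_format('heads/foo.lock')   # False: '.lock' suffix is reserved
    check_ref_format('heads/v@{1}')      # False: '@{' is forbidden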
 
 
 
 
 class RefsContainer(object):
 class RefsContainer(object):
     """A container for refs."""
     """A container for refs."""
 
 
-    def as_dict(self, base):
-        """Return the contents of this ref container under base as a dict."""
-        raise NotImplementedError(self.as_dict)
-
-    def follow(self, name):
-        """Follow a ref name back to a SHA1.
-        
-        :param name: Name of the ref
-        """
-        return follow_ref(self, name)
-
     def set_ref(self, name, other):
     def set_ref(self, name, other):
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
         :param name: Name of the ref to set
         :param name: Name of the ref to set
         :param other: Name of the ref to point at
         :param other: Name of the ref to point at
         """
         """
-        self[name] = "ref: %s\n" % other
+        self[name] = SYMREF + other + '\n'
+
+    def get_packed_refs(self):
+        """Get contents of the packed-refs file.
+
+        :return: Dictionary mapping ref names to SHA1s
+
+        :note: Will return an empty dictionary when no packed-refs file is
+            present.
+        """
+        raise NotImplementedError(self.get_packed_refs)
 
 
     def import_refs(self, base, other):
     def import_refs(self, base, other):
         for name, value in other.iteritems():
         for name, value in other.iteritems():
             self["%s/%s" % (base, name)] = value
             self["%s/%s" % (base, name)] = value
 
 
+    def keys(self, base=None):
+        """Refs present in this container.
+
+        :param base: An optional base to return refs under
+        :return: An unsorted set of valid refs in this container, including
+            packed refs.
+        """
+        if base is not None:
+            return self.subkeys(base)
+        else:
+            return self.allkeys()
+
+    def subkeys(self, base):
+        keys = set()
+        for refname in self.allkeys():
+            if refname.startswith(base):
+                keys.add(refname)
+        return keys
+
+    def as_dict(self, base=None):
+        """Return the contents of this container as a dictionary.
+
+        """
+        ret = {}
+        keys = self.keys(base)
+        if base is None:
+            base = ""
+        for key in keys:
+            try:
+                ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            except KeyError:
+                continue # Unable to resolve
+
+        return ret
+
+    def _check_refname(self, name):
+        """Ensure a refname is valid and lives in refs or is HEAD.
+
+        HEAD is not a valid refname according to git-check-ref-format, but this
+        class needs to be able to touch HEAD. Also, check_ref_format expects
+        refnames without the leading 'refs/', but this class requires the
+        leading 'refs/' so that it cannot touch anything outside the refs dir
+        (or HEAD).
+
+        :param name: The name of the reference.
+        :raises KeyError: if a refname is not HEAD or is otherwise not valid.
+        """
+        if name == 'HEAD':
+            return
+        if not name.startswith('refs/') or not check_ref_format(name[5:]):
+            raise KeyError(name)
+
+    def read_loose_ref(self, name):
+        """Read a loose reference and return its contents.
+
+        :param name: the refname to read
+        :return: The contents of the ref file, or None if it does 
+            not exist.
+        """
+        raise NotImplementedError(self.read_loose_ref)
+
+    def _follow(self, name):
+        """Follow a reference name.
+
+        :return: a tuple of (refname, sha), where refname is the name of the
+            last reference in the symbolic reference chain
+        """
+        self._check_refname(name)
+        contents = SYMREF + name
+        depth = 0
+        while contents.startswith(SYMREF):
+            refname = contents[len(SYMREF):]
+            contents = self.read_loose_ref(refname)
+            if not contents:
+                contents = self.get_packed_refs().get(refname, None)
+                if not contents:
+                    break
+            depth += 1
+            if depth > 5:
+                raise KeyError(name)
+        return refname, contents
+
+    def __getitem__(self, name):
+        """Get the SHA1 for a reference name.
+
+        This method follows all symbolic references.
+        """
+        _, sha = self._follow(name)
+        if sha is None:
+            raise KeyError(name)
+        return sha
+
+
+class DictRefsContainer(RefsContainer):
+
+    def __init__(self, refs):
+        self._refs = refs
+
+    def allkeys(self):
+        return self._refs.keys()
+
+    def read_loose_ref(self, name):
+        return self._refs[name]
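DictRefsContainer is the in-memory counterpart used by the tests; together with _follow() it resolves symbolic reference chains without touching disk. A small sketch (SYMREF is the 'ref: ' prefix defined earlier in this module):

    from dulwich.repo import DictRefsContainer, SYMREF

    refs = DictRefsContainer({
        'HEAD': SYMREF + 'refs/heads/master',
        'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
        })
    refs['HEAD']   # follows the symref -> '42d06bd4b77fed026b154d16493e5deab78f02ec'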
+
 
 
 class DiskRefsContainer(RefsContainer):
 class DiskRefsContainer(RefsContainer):
     """Refs container that reads refs from disk."""
     """Refs container that reads refs from disk."""
 
 
     def __init__(self, path):
     def __init__(self, path):
         self.path = path
         self.path = path
+        self._packed_refs = None
+        self._peeled_refs = {}
 
 
     def __repr__(self):
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.path)
         return "%s(%r)" % (self.__class__.__name__, self.path)
 
 
-    def keys(self, base=None):
-        """Refs present in this container."""
-        return list(self.iterkeys(base))
-
-    def iterkeys(self, base=None):
-        if base is not None:
-            return self.itersubkeys(base)
-        else:
-            return self.iterallkeys()
-
-    def itersubkeys(self, base):
+    def subkeys(self, base):
+        keys = set()
         path = self.refpath(base)
         path = self.refpath(base)
         for root, dirs, files in os.walk(path):
         for root, dirs, files in os.walk(path):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
-
-    def iterallkeys(self):
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                # check_ref_format requires at least one /, so we prepend the
+                # base before calling it.
+                if check_ref_format("%s/%s" % (base, refname)):
+                    keys.add(refname)
+        for key in self.get_packed_refs():
+            if key.startswith(base):
+                keys.add(key[len(base):].strip("/"))
+        return keys
+
+    def allkeys(self):
+        keys = set()
         if os.path.exists(self.refpath("HEAD")):
         if os.path.exists(self.refpath("HEAD")):
-            yield "HEAD"
+            keys.add("HEAD")
         path = self.refpath("")
         path = self.refpath("")
         for root, dirs, files in os.walk(self.refpath("refs")):
         for root, dirs, files in os.walk(self.refpath("refs")):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
-
-    def as_dict(self, base=None, follow=True):
-        """Return the contents of this container as a dictionary.
-
-        """
-        ret = {}
-        if base is None:
-            keys = self.iterkeys()
-            base = ""
-        else:
-            keys = self.itersubkeys(base)
-        for key in keys:
-                if follow:
-                    try:
-                        ret[key] = self.follow(("%s/%s" % (base, key)).strip("/"))
-                    except KeyError:
-                        continue # Unable to resolve
-                else:
-                    ret[key] = self[("%s/%s" % (base, key)).strip("/")]
-        return ret
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                if check_ref_format(refname):
+                    keys.add(refname)
+        keys.update(self.get_packed_refs())
+        return keys
 
 
     def refpath(self, name):
     def refpath(self, name):
         """Return the disk path of a ref.
         """Return the disk path of a ref.
@@ -157,90 +288,318 @@ class DiskRefsContainer(RefsContainer):
             name = name.replace("/", os.path.sep)
             name = name.replace("/", os.path.sep)
         return os.path.join(self.path, name)
         return os.path.join(self.path, name)
 
 
-    def __getitem__(self, name):
-        file = self.refpath(name)
-        if not os.path.exists(file):
-            raise KeyError(name)
-        f = open(file, 'rb')
+    def get_packed_refs(self):
+        """Get contents of the packed-refs file.
+
+        :return: Dictionary mapping ref names to SHA1s
+
+        :note: Will return an empty dictionary when no packed-refs file is
+            present.
+        """
+        # TODO: invalidate the cache on repacking
+        if self._packed_refs is None:
+            self._packed_refs = {}
+            path = os.path.join(self.path, 'packed-refs')
+            try:
+                f = GitFile(path, 'rb')
+            except IOError, e:
+                if e.errno == errno.ENOENT:
+                    return {}
+                raise
+            try:
+                first_line = iter(f).next().rstrip()
+                if (first_line.startswith("# pack-refs") and " peeled" in
+                        first_line):
+                    for sha, name, peeled in read_packed_refs_with_peeled(f):
+                        self._packed_refs[name] = sha
+                        if peeled:
+                            self._peeled_refs[name] = peeled
+                else:
+                    f.seek(0)
+                    for sha, name in read_packed_refs(f):
+                        self._packed_refs[name] = sha
+            finally:
+                f.close()
+        return self._packed_refs
+
+    def read_loose_ref(self, name):
+        """Read a reference file and return its contents.
+
+        If the reference file is a symbolic reference, only read the first line of
+        the file. Otherwise, only read the first 40 bytes.
+
+        :param name: the refname to read, relative to refpath
+        :return: The contents of the ref file, or None if the file does not
+            exist.
+        :raises IOError: if any other error occurs
+        """
+        filename = self.refpath(name)
+        try:
+            f = GitFile(filename, 'rb')
+            try:
+                header = f.read(len(SYMREF))
+                if header == SYMREF:
+                    # Read only the first line
+                    return header + iter(f).next().rstrip("\n")
+                else:
+                    # Read only the first 40 bytes
+                    return header + f.read(40-len(SYMREF))
+            finally:
+                f.close()
+        except IOError, e:
+            if e.errno == errno.ENOENT:
+                return None
+            raise
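Against the refs.git fixture added further down in this commit, the contract is roughly:

    from dulwich.repo import DiskRefsContainer

    refs = DiskRefsContainer('dulwich/tests/data/repos/refs.git')
    refs.read_loose_ref('HEAD')                # 'ref: refs/heads/master'
    refs.read_loose_ref('refs/heads/master')   # '42d06bd4b77fed026b154d16493e5deab78f02ec'
    refs.read_loose_ref('refs/heads/missing')  # None: ENOENT is swallowed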
+
+    def _remove_packed_ref(self, name):
+        if self._packed_refs is None:
+            return
+        filename = os.path.join(self.path, 'packed-refs')
+        # reread cached refs from disk, while holding the lock
+        f = GitFile(filename, 'wb')
         try:
         try:
-            return f.read().strip("\n")
+            self._packed_refs = None
+            self.get_packed_refs()
+
+            if name not in self._packed_refs:
+                return
+
+            del self._packed_refs[name]
+            if name in self._peeled_refs:
+                del self._peeled_refs[name]
+            write_packed_refs(f, self._packed_refs, self._peeled_refs)
+            f.close()
+        finally:
+            f.abort()
+
+    def set_if_equals(self, name, old_ref, new_ref):
+        """Set a refname to new_ref only if it currently equals old_ref.
+
+        This method follows all symbolic references, and can be used to perform
+        an atomic compare-and-swap operation.
+
+        :param name: The refname to set.
+        :param old_ref: The old sha the refname must refer to, or None to set
+            unconditionally.
+        :param new_ref: The new sha the refname will refer to.
+        :return: True if the set was successful, False otherwise.
+        """
+        try:
+            realname, _ = self._follow(name)
+        except KeyError:
+            realname = name
+        filename = self.refpath(realname)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
+        try:
+            if old_ref is not None:
+                try:
+                    # read again while holding the lock
+                    orig_ref = self.read_loose_ref(realname)
+                    if orig_ref is None:
+                        orig_ref = self.get_packed_refs().get(realname, None)
+                    if orig_ref != old_ref:
+                        f.abort()
+                        return False
+                except (OSError, IOError):
+                    f.abort()
+                    raise
+            try:
+                f.write(new_ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
         finally:
         finally:
             f.close()
             f.close()
+        return True
+
+    def add_if_new(self, name, ref):
+        """Add a new reference only if it does not already exist."""
+        self._check_refname(name)
+        filename = self.refpath(name)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
+        try:
+            if os.path.exists(filename) or name in self.get_packed_refs():
+                f.abort()
+                return False
+            try:
+                f.write(ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
+        finally:
+            f.close()
+        return True
 
 
     def __setitem__(self, name, ref):
     def __setitem__(self, name, ref):
-        file = self.refpath(name)
-        dirpath = os.path.dirname(file)
-        if not os.path.exists(dirpath):
-            os.makedirs(dirpath)
-        f = open(file, 'wb')
+        """Set a reference name to point to the given SHA1.
+
+        This method follows all symbolic references.
+
+        :note: This method unconditionally overwrites the contents of a reference
+            on disk. To update atomically only if the reference has not changed
+            on disk, use set_if_equals().
+        """
+        self.set_if_equals(name, None, ref)
+
+    def remove_if_equals(self, name, old_ref):
+        """Remove a refname only if it currently equals old_ref.
+
+        This method does not follow symbolic references. It can be used to
+        perform an atomic compare-and-delete operation.
+
+        :param name: The refname to delete.
+        :param old_ref: The old sha the refname must refer to, or None to delete
+            unconditionally.
+        :return: True if the delete was successful, False otherwise.
+        """
+        self._check_refname(name)
+        filename = self.refpath(name)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
         try:
         try:
-            f.write(ref+"\n")
+            if old_ref is not None:
+                orig_ref = self.read_loose_ref(name)
+                if orig_ref is None:
+                    orig_ref = self.get_packed_refs().get(name, None)
+                if orig_ref != old_ref:
+                    return False
+            # may only be packed
+            try:
+                os.remove(filename)
+            except OSError, e:
+                if e.errno != errno.ENOENT:
+                    raise
+            self._remove_packed_ref(name)
         finally:
         finally:
-            f.close()
+            # never write, we just wanted the lock
+            f.abort()
+        return True
 
 
     def __delitem__(self, name):
     def __delitem__(self, name):
-        file = self.refpath(name)
-        if os.path.exists(file):
-            os.remove(file)
+        """Remove a refname.
+
+        This method does not follow symbolic references.
+        :note: This method unconditionally deletes the contents of a reference
+            on disk. To delete atomically only if the reference has not changed
+            on disk, use remove_if_equals().
+        """
+        self.remove_if_equals(name, None)
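The conditional setters give callers an atomic compare-and-swap on top of GitFile's locking; a hedged sketch (the repository path and the target SHA1 are placeholders):

    from dulwich.repo import Repo

    refs = Repo('/path/to/repo').refs
    old = refs['refs/heads/master']
    new = 'a18114c31713746a33a2e70d9914d1ef3e781425'
    if not refs.set_if_equals('refs/heads/master', old, new):
        print "refs/heads/master changed under us, try again"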
+
+
+def _split_ref_line(line):
+    """Split a single ref line into a tuple of SHA1 and name."""
+    fields = line.rstrip("\n").split(" ")
+    if len(fields) != 2:
+        raise PackedRefsException("invalid ref line '%s'" % line)
+    sha, name = fields
+    try:
+        hex_to_sha(sha)
+    except (AssertionError, TypeError), e:
+        raise PackedRefsException(e)
+    if not check_ref_format(name):
+        raise PackedRefsException("invalid ref name '%s'" % name)
+    return (sha, name)
 
 
 
 
 def read_packed_refs(f):
 def read_packed_refs(f):
     """Read a packed refs file.
     """Read a packed refs file.
 
 
-    Yields tuples with ref names and SHA1s.
+    Yields tuples with SHA1s and ref names.
 
 
     :param f: file-like object to read from
     :param f: file-like object to read from
     """
     """
-    l = f.readline()
-    for l in f.readlines():
+    for l in f:
         if l[0] == "#":
         if l[0] == "#":
             # Comment
             # Comment
             continue
             continue
         if l[0] == "^":
         if l[0] == "^":
-            # FIXME: Return somehow
+            raise PackedRefsException(
+                "found peeled ref in packed-refs without peeled")
+        yield _split_ref_line(l)
+
+
+def read_packed_refs_with_peeled(f):
+    """Read a packed refs file including peeled refs.
+
+    Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
+    with SHA1s, ref names, and peeled SHA1s (or None).
+
+    :param f: file-like object to read from, seek'ed to the second line
+    """
+    last = None
+    for l in f:
+        if l[0] == "#":
             continue
             continue
-        yield tuple(l.rstrip("\n").split(" ", 2))
+        l = l.rstrip("\n")
+        if l[0] == "^":
+            if not last:
+                raise PackedRefsException("unexpected peeled ref line")
+            try:
+                hex_to_sha(l[1:])
+            except (AssertionError, TypeError), e:
+                raise PackedRefsException(e)
+            sha, name = _split_ref_line(last)
+            last = None
+            yield (sha, name, l[1:])
+        else:
+            if last:
+                sha, name = _split_ref_line(last)
+                yield (sha, name, None)
+            last = l
+    if last:
+        sha, name = _split_ref_line(last)
+        yield (sha, name, None)
+
 
 
+def write_packed_refs(f, packed_refs, peeled_refs=None):
+    """Write a packed refs file.
+
+    :param f: empty file-like object to write to
+    :param packed_refs: dict of refname to sha of packed refs to write
+    """
+    if peeled_refs is None:
+        peeled_refs = {}
+    else:
+        f.write('# pack-refs with: peeled\n')
+    for refname in sorted(packed_refs.iterkeys()):
+        f.write('%s %s\n' % (packed_refs[refname], refname))
+        if refname in peeled_refs:
+            f.write('^%s\n' % peeled_refs[refname])
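These parsers round-trip the packed-refs format cgit writes, including peeled tag lines. A sketch using the fixture contents added below:

    from cStringIO import StringIO
    from dulwich.repo import read_packed_refs_with_peeled

    f = StringIO("# pack-refs with: peeled \n"
                 "df6800012397fb85c56e7418dd4eb9405dee075c refs/tags/refs-0.1\n"
                 "^42d06bd4b77fed026b154d16493e5deab78f02ec\n")
    f.readline()   # callers consume the header line first, as get_packed_refs does
    list(read_packed_refs_with_peeled(f))
    # [('df6800012397fb85c56e7418dd4eb9405dee075c', 'refs/tags/refs-0.1',
    #   '42d06bd4b77fed026b154d16493e5deab78f02ec')]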
+
+
+class BaseRepo(object):
+    """Base class for a git repository.
 
 
-class Repo(object):
-    """A local git repository.
-    
-    :ivar refs: Dictionary with the refs in this repository
     :ivar object_store: Dictionary-like object for accessing
     :ivar object_store: Dictionary-like object for accessing
         the objects
         the objects
+    :ivar refs: Dictionary-like object with the refs in this repository
     """
     """
 
 
-    def __init__(self, root):
-        if os.path.isdir(os.path.join(root, ".git", OBJECTDIR)):
-            self.bare = False
-            self._controldir = os.path.join(root, ".git")
-        elif (os.path.isdir(os.path.join(root, OBJECTDIR)) and
-              os.path.isdir(os.path.join(root, REFSDIR))):
-            self.bare = True
-            self._controldir = root
-        else:
-            raise NotGitRepository(root)
-        self.path = root
-        self.refs = DiskRefsContainer(self.controldir())
-        self.object_store = DiskObjectStore(
-            os.path.join(self.controldir(), OBJECTDIR))
+    def __init__(self, object_store, refs):
+        self.object_store = object_store
+        self.refs = refs
 
 
-    def controldir(self):
-        """Return the path of the control directory."""
-        return self._controldir
+    def get_named_file(self, path):
+        """Get a file from the control dir with a specific name.
 
 
-    def index_path(self):
-        """Return path to the index file."""
-        return os.path.join(self.controldir(), INDEX_FILENAME)
+        Although the filename should be interpreted as a filename relative to
+        the control dir in a disk-baked Repo, the object returned need not be
+        pointing to a file in that location.
 
 
-    def open_index(self):
-        """Open the index for this repository."""
-        from dulwich.index import Index
-        return Index(self.index_path())
+        :param path: The path to the file, relative to the control dir.
+        :return: An open file object, or None if the file does not exist.
+        """
+        raise NotImplementedError(self.get_named_file)
 
 
-    def has_index(self):
-        """Check if an index is present."""
-        return os.path.exists(self.index_path())
+    def open_index(self):
+        """Open the index for this repository.
+        
+        :raises NoIndexPresent: If no index is present
+        :return: Index instance
+        """
+        raise NotImplementedError(self.open_index)
 
 
     def fetch(self, target, determine_wants=None, progress=None):
     def fetch(self, target, determine_wants=None, progress=None):
         """Fetch objects into another repository.
         """Fetch objects into another repository.
@@ -250,6 +609,8 @@ class Repo(object):
             fetch.
             fetch.
         :param progress: Optional progress function
         :param progress: Optional progress function
         """
         """
+        if determine_wants is None:
+            determine_wants = lambda heads: heads.values()
         target.object_store.add_objects(
         target.object_store.add_objects(
             self.fetch_objects(determine_wants, target.get_graph_walker(),
             self.fetch_objects(determine_wants, target.get_graph_walker(),
                 progress))
                 progress))
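With this default, fetch() pulls everything reachable from the source's heads when no determine_wants callback is supplied. A hedged sketch (the source path is a placeholder; the bare target directory must already exist, which mkdtemp guarantees):

    import tempfile
    from dulwich.repo import Repo

    src = Repo('/path/to/source')
    dst = Repo.init_bare(tempfile.mkdtemp())
    src.fetch(dst)   # copies the reachable objects into dst's object store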
@@ -279,46 +640,15 @@ class Repo(object):
 
 
     def ref(self, name):
     def ref(self, name):
         """Return the SHA1 a ref is pointing to."""
         """Return the SHA1 a ref is pointing to."""
-        try:
-            return self.refs.follow(name)
-        except KeyError:
-            return self.get_packed_refs()[name]
+        return self.refs[name]
 
 
     def get_refs(self):
     def get_refs(self):
         """Get dictionary with all refs."""
         """Get dictionary with all refs."""
-        ret = {}
-        try:
-            if self.head():
-                ret['HEAD'] = self.head()
-        except KeyError:
-            pass
-        ret.update(self.refs.as_dict())
-        ret.update(self.get_packed_refs())
-        return ret
-
-    def get_packed_refs(self):
-        """Get contents of the packed-refs file.
-
-        :return: Dictionary mapping ref names to SHA1s
-
-        :note: Will return an empty dictionary when no packed-refs file is 
-            present.
-        """
-        path = os.path.join(self.controldir(), 'packed-refs')
-        if not os.path.exists(path):
-            return {}
-        ret = {}
-        f = open(path, 'rb')
-        try:
-            for entry in read_packed_refs(f):
-                ret[entry[1]] = entry[0]
-            return ret
-        finally:
-            f.close()
+        return self.refs.as_dict()
 
 
     def head(self):
     def head(self):
         """Return the SHA1 pointed at by HEAD."""
         """Return the SHA1 pointed at by HEAD."""
-        return self.refs.follow('HEAD')
+        return self.refs['HEAD']
 
 
     def _get_object(self, sha, cls):
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
         assert len(sha) in (20, 40)
@@ -340,6 +670,13 @@ class Repo(object):
     def get_parents(self, sha):
     def get_parents(self, sha):
         return self.commit(sha).parents
         return self.commit(sha).parents
 
 
+    def get_config(self):
+        import ConfigParser
+        p = ConfigParser.RawConfigParser()
+        p.read(os.path.join(self._controldir, 'config'))
+        return dict((section, dict(p.items(section)))
+                    for section in p.sections())
+
     def commit(self, sha):
     def commit(self, sha):
         return self._get_object(sha, Commit)
         return self._get_object(sha, Commit)
 
 
@@ -386,14 +723,11 @@ class Repo(object):
         history.reverse()
         history.reverse()
         return history
         return history
 
 
-    def __repr__(self):
-        return "<Repo at %r>" % self.path
-
     def __getitem__(self, name):
     def __getitem__(self, name):
         if len(name) in (20, 40):
         if len(name) in (20, 40):
             return self.object_store[name]
             return self.object_store[name]
         return self.object_store[self.refs[name]]
         return self.object_store[self.refs[name]]
-    
+
     def __setitem__(self, name, value):
     def __setitem__(self, name, value):
         if name.startswith("refs/") or name == "HEAD":
         if name.startswith("refs/") or name == "HEAD":
             if isinstance(value, ShaFile):
             if isinstance(value, ShaFile):
@@ -455,6 +789,74 @@ class Repo(object):
         self.refs["HEAD"] = c.id
         self.refs["HEAD"] = c.id
         return c.id
         return c.id
 
 
+
+class Repo(BaseRepo):
+    """A git repository backed by local disk."""
+
+    def __init__(self, root):
+        if os.path.isdir(os.path.join(root, ".git", OBJECTDIR)):
+            self.bare = False
+            self._controldir = os.path.join(root, ".git")
+        elif (os.path.isdir(os.path.join(root, OBJECTDIR)) and
+              os.path.isdir(os.path.join(root, REFSDIR))):
+            self.bare = True
+            self._controldir = root
+        else:
+            raise NotGitRepository(root)
+        self.path = root
+        object_store = DiskObjectStore(
+            os.path.join(self.controldir(), OBJECTDIR))
+        refs = DiskRefsContainer(self.controldir())
+        BaseRepo.__init__(self, object_store, refs)
+
+    def controldir(self):
+        """Return the path of the control directory."""
+        return self._controldir
+
+    def _put_named_file(self, path, contents):
+        """Write a file from the control dir with a specific name and contents.
+        """
+        f = GitFile(os.path.join(self.controldir(), path), 'wb')
+        try:
+            f.write(contents)
+        finally:
+            f.close()
+
+    def get_named_file(self, path):
+        """Get a file from the control dir with a specific name.
+
+        Although the filename should be interpreted as a filename relative to
+        the control dir in a disk-backed Repo, the object returned need not be
+        pointing to a file in that location.
+
+        :param path: The path to the file, relative to the control dir.
+        :return: An open file object, or None if the file does not exist.
+        """
+        try:
+            return open(os.path.join(self.controldir(), path.lstrip('/')), 'rb')
+        except (IOError, OSError), e:
+            if e.errno == errno.ENOENT:
+                return None
+            raise
+
+    def index_path(self):
+        """Return path to the index file."""
+        return os.path.join(self.controldir(), INDEX_FILENAME)
+
+    def open_index(self):
+        """Open the index for this repository."""
+        from dulwich.index import Index
+        if not self.has_index():
+            raise NoIndexPresent()
+        return Index(self.index_path())
+
+    def has_index(self):
+        """Check if an index is present."""
+        return os.path.exists(self.index_path())
+
+    def __repr__(self):
+        return "<Repo at %r>" % self.path
+
     @classmethod
     @classmethod
     def init(cls, path, mkdir=True):
     def init(cls, path, mkdir=True):
         controldir = os.path.join(path, ".git")
         controldir = os.path.join(path, ".git")
@@ -464,21 +866,18 @@ class Repo(object):
 
 
     @classmethod
     @classmethod
     def init_bare(cls, path, mkdir=True):
     def init_bare(cls, path, mkdir=True):
-        for d in [[OBJECTDIR], 
-                  [OBJECTDIR, "info"], 
-                  [OBJECTDIR, "pack"],
-                  ["branches"],
-                  [REFSDIR],
-                  [REFSDIR, REFSDIR_TAGS],
-                  [REFSDIR, REFSDIR_HEADS],
-                  ["hooks"],
-                  ["info"]]:
+        for d in BASE_DIRECTORIES:
             os.mkdir(os.path.join(path, *d))
             os.mkdir(os.path.join(path, *d))
         ret = cls(path)
         ret = cls(path)
         ret.refs.set_ref("HEAD", "refs/heads/master")
         ret.refs.set_ref("HEAD", "refs/heads/master")
-        open(os.path.join(path, 'description'), 'wb').write("Unnamed repository")
-        open(os.path.join(path, 'info', 'excludes'), 'wb').write("")
+        ret._put_named_file('description', "Unnamed repository")
+        ret._put_named_file('config', """[core]
+    repositoryformatversion = 0
+    filemode = true
+    bare = false
+    logallrefupdates = true
+""")
+        ret._put_named_file(os.path.join('info', 'excludes'), '')
         return ret
         return ret
 
 
     create = init_bare
     create = init_bare
-
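Pulling the Repo changes together: init_bare() now lays out BASE_DIRECTORIES and writes the default description, config and info/excludes through _put_named_file. A short sketch:

    import os, tempfile
    from dulwich.repo import Repo

    path = tempfile.mkdtemp()
    repo = Repo.init_bare(path)
    repo.bare                    # True
    sorted(os.listdir(path))     # ['HEAD', 'branches', 'config', 'description',
                                 #  'hooks', 'info', 'objects', 'refs']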

+ 414 - 99
dulwich/server.py

@@ -17,17 +17,37 @@
 # MA  02110-1301, USA.
 # MA  02110-1301, USA.
 
 
 
 
-"""Git smart network protocol server implementation."""
+"""Git smart network protocol server implementation.
 
 
+For more details on the network protocol, see the
+Documentation/technical directory in the cgit distribution, and in particular:
+    Documentation/technical/protocol-capabilities.txt
+    Documentation/technical/pack-protocol.txt
+"""
 
 
+
+import collections
 import SocketServer
 import SocketServer
 import tempfile
 import tempfile
 
 
+from dulwich.errors import (
+    ApplyDeltaError,
+    ChecksumMismatch,
+    GitProtocolError,
+    )
+from dulwich.objects import (
+    hex_to_sha,
+    )
 from dulwich.protocol import (
 from dulwich.protocol import (
     Protocol,
     Protocol,
     ProtocolFile,
     ProtocolFile,
     TCP_GIT_PORT,
     TCP_GIT_PORT,
     extract_capabilities,
     extract_capabilities,
+    extract_want_line_capabilities,
+    SINGLE_ACK,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
+    ack_type,
     )
     )
 from dulwich.repo import (
 from dulwich.repo import (
     Repo,
     Repo,
@@ -65,31 +85,65 @@ class Backend(object):
 
 
 class GitBackend(Backend):
 class GitBackend(Backend):
 
 
-    def __init__(self, gitdir=None):
-        self.gitdir = gitdir
-
-        if not self.gitdir:
-            self.gitdir = tempfile.mkdtemp()
-            Repo.create(self.gitdir)
-
-        self.repo = Repo(self.gitdir)
+    def __init__(self, repo=None):
+        if repo is None:
+            repo = Repo.init_bare(tempfile.mkdtemp())
+        self.repo = repo
+        self.object_store = self.repo.object_store
         self.fetch_objects = self.repo.fetch_objects
         self.fetch_objects = self.repo.fetch_objects
         self.get_refs = self.repo.get_refs
         self.get_refs = self.repo.get_refs
 
 
     def apply_pack(self, refs, read):
     def apply_pack(self, refs, read):
         f, commit = self.repo.object_store.add_thin_pack()
         f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
+        status = []
+        unpack_error = None
+        # TODO: more informative error messages than just the exception string
+        try:
+            # TODO: decode the pack as we stream to avoid blocking reads beyond
+            # the end of data (when using HTTP/1.1 chunked encoding)
+            while True:
+                data = read(10240)
+                if not data:
+                    break
+                f.write(data)
+        except all_exceptions, e:
+            unpack_error = str(e).replace('\n', '')
         try:
         try:
-            f.write(read())
-        finally:
             commit()
             commit()
+        except all_exceptions, e:
+            if not unpack_error:
+                unpack_error = str(e).replace('\n', '')
+
+        if unpack_error:
+            status.append(('unpack', unpack_error))
+        else:
+            status.append(('unpack', 'ok'))
 
 
         for oldsha, sha, ref in refs:
         for oldsha, sha, ref in refs:
-            if ref == "0" * 40:
-                del self.repo.refs[ref]
+            # TODO: check refname
+            ref_error = None
+            try:
+                if sha == "0" * 40:
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_error = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_error = 'failed to write'
+            except KeyError, e:
+                ref_error = 'bad ref'
+            if ref_error:
+                status.append((ref, ref_error))
             else:
             else:
-                self.repo.refs[ref] = sha
+                status.append((ref, 'ok'))
+
 
 
         print "pack applied"
         print "pack applied"
+        return status
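apply_pack() now returns a status list instead of assuming success, which ReceivePackHandler turns into a report-status reply further down. For a push that updated one ref and failed another, the shape would be roughly:

    # status as returned by apply_pack():
    status = [('unpack', 'ok'),
              ('refs/heads/master', 'ok'),
              ('refs/heads/broken', 'failed to write')]
    # which the handler reports (before pkt-line framing) as:
    #   unpack ok
    #   ok refs/heads/master
    #   ng refs/heads/broken failed to write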
 
 
 
 
 class Handler(object):
 class Handler(object):
@@ -106,102 +160,354 @@ class Handler(object):
 class UploadPackHandler(Handler):
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
     """Protocol handler for uploading a pack to the server."""
 
 
+    def __init__(self, backend, read, write,
+                 stateless_rpc=False, advertise_refs=False):
+        Handler.__init__(self, backend, read, write)
+        self._client_capabilities = None
+        self._graph_walker = None
+        self.stateless_rpc = stateless_rpc
+        self.advertise_refs = advertise_refs
+
     def default_capabilities(self):
     def default_capabilities(self):
-        return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
+        return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
+                "ofs-delta")
 
 
-    def handle(self):
-        def determine_wants(heads):
-            keys = heads.keys()
-            if keys:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % ( heads[keys[0]], keys[0], self.capabilities()))
-                for k in keys[1:]:
-                    self.proto.write_pkt_line("%s %s\n" % (heads[k], k))
+    def set_client_capabilities(self, caps):
+        my_caps = self.default_capabilities()
+        for cap in caps:
+            if '_ack' in cap and cap not in my_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        self._client_capabilities = caps
 
 
-            # i'm done..
-            self.proto.write("0000")
+    def get_client_capabilities(self):
+        return self._client_capabilities
 
 
-            # Now client will either send "0000", meaning that it doesnt want to pull.
-            # or it will start sending want want want commands
-            want = self.proto.read_pkt_line()
-            if want == None:
-                return []
+    client_capabilities = property(get_client_capabilities,
+                                   set_client_capabilities)
 
 
-            want, self.client_capabilities = extract_capabilities(want)
-
-            want_revs = []
-            while want and want[:4] == 'want':
-                want_revs.append(want[5:45])
-                want = self.proto.read_pkt_line()
-                if want == None:
-                    self.proto.write_pkt_line("ACK %s\n" % want_revs[-1])
-            return want_revs
+    def handle(self):
 
 
         progress = lambda x: self.proto.write_sideband(2, x)
         progress = lambda x: self.proto.write_sideband(2, x)
         write = lambda x: self.proto.write_sideband(1, x)
         write = lambda x: self.proto.write_sideband(1, x)
 
 
-        class ProtocolGraphWalker(object):
+        graph_walker = ProtocolGraphWalker(self)
+        objects_iter = self.backend.fetch_objects(
+          graph_walker.determine_wants, graph_walker, progress)
 
 
-            def __init__(self, proto):
-                self.proto = proto
-                self._last_sha = None
-                self._cached = False
-                self._cache = []
-                self._cache_index = 0
+        # Do they want any objects?
+        if len(objects_iter) == 0:
+            return
 
 
-            def ack(self, have_ref):
-                self.proto.write_pkt_line("ACK %s continue\n" % have_ref)
+        progress("dul-daemon says what\n")
+        progress("counting objects: %d, done.\n" % len(objects_iter))
+        write_pack_data(ProtocolFile(None, write), objects_iter, 
+                        len(objects_iter))
+        progress("how was that, then?\n")
+        # we are done
+        self.proto.write("0000")
 
 
-            def reset(self):
-                self._cached = True
-                self._cache_index = 0
 
 
-            def next(self):
-                if not self._cached:
-                    return self.next_from_proto()
-                self._cache_index = self._cache_index + 1
-                if self._cache_index > len(self._cache):
-                    return None
-                return self._cache[self._cache_index]
+class ProtocolGraphWalker(object):
+    """A graph walker that knows the git protocol.
+
+    As a graph walker, this class implements ack(), next(), and reset(). It also
+    contains some base methods for interacting with the wire and walking the
+    commit tree.
+
+    The work of determining which acks to send is passed on to the
+    implementation instance stored in _impl. The reason for this is that we do
+    not know at object creation time what ack level the protocol requires. A
+    call to set_ack_level() is required to set up the implementation, before any
+    calls to next() or ack() are made.
+    """
+    def __init__(self, handler):
+        self.handler = handler
+        self.store = handler.backend.object_store
+        self.proto = handler.proto
+        self.stateless_rpc = handler.stateless_rpc
+        self.advertise_refs = handler.advertise_refs
+        self._wants = []
+        self._cached = False
+        self._cache = []
+        self._cache_index = 0
+        self._impl = None
+
+    def determine_wants(self, heads):
+        """Determine the wants for a set of heads.
+
+        The given heads are advertised to the client, who then specifies which
+        refs he wants using 'want' lines. This portion of the protocol is the
+        same regardless of ack type, and in fact is used to set the ack type of
+        the ProtocolGraphWalker.
+
+        :param heads: a dict of refname->SHA1 to advertise
+        :return: a list of SHA1s requested by the client
+        """
+        if not heads:
+            raise GitProtocolError('No heads found')
+        values = set(heads.itervalues())
+        if self.advertise_refs or not self.stateless_rpc:
+            for i, (ref, sha) in enumerate(heads.iteritems()):
+                line = "%s %s" % (sha, ref)
+                if not i:
+                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                self.proto.write_pkt_line("%s\n" % line)
+                # TODO: include peeled value of any tags
 
 
-            def next_from_proto(self):
-                have = self.proto.read_pkt_line()
-                if have is None:
-                    self.proto.write_pkt_line("ACK %s\n" % self._last_sha)
-                    return None
+            # i'm done..
+            self.proto.write_pkt_line(None)
 
 
-                if have[:4] == 'have':
-                    self._cache.append(have[5:45])
-                    return have[5:45]
+            if self.advertise_refs:
+                return []
 
 
+        # Now the client will start sending 'want' lines
+        want = self.proto.read_pkt_line()
+        if not want:
+            return []
+        line, caps = extract_want_line_capabilities(want)
+        self.handler.client_capabilities = caps
+        self.set_ack_type(ack_type(caps))
+        command, sha = self._split_proto_line(line)
+
+        want_revs = []
+        while command != None:
+            if command != 'want':
+                raise GitProtocolError(
+                    'Protocol got unexpected command %s' % command)
+            if sha not in values:
+                raise GitProtocolError(
+                    'Client wants invalid object %s' % sha)
+            want_revs.append(sha)
+            command, sha = self.read_proto_line()
+
+        self.set_wants(want_revs)
+        return want_revs
+
+    def ack(self, have_ref):
+        return self._impl.ack(have_ref)
+
+    def reset(self):
+        self._cached = True
+        self._cache_index = 0
+
+    def next(self):
+        if not self._cached:
+            if not self._impl and self.stateless_rpc:
+                return None
+            return self._impl.next()
+        self._cache_index += 1
+        if self._cache_index > len(self._cache):
+            return None
+        return self._cache[self._cache_index]
+
+    def _split_proto_line(self, line):
+        fields = line.rstrip('\n').split(' ', 1)
+        if len(fields) == 1 and fields[0] == 'done':
+            return ('done', None)
+        elif len(fields) == 2 and fields[0] in ('want', 'have'):
+            try:
+                hex_to_sha(fields[1])
+                return tuple(fields)
+            except (TypeError, AssertionError), e:
+                raise GitProtocolError(e)
+        raise GitProtocolError('Received invalid line from client:\n%s' % line)
+
+    def read_proto_line(self):
+        """Read a line from the wire.
+
+        :return: a tuple having one of the following forms:
+            ('want', obj_id)
+            ('have', obj_id)
+            ('done', None)
+            (None, None)  (for a flush-pkt)
+
+        :raise GitProtocolError: if the line cannot be parsed into one of the
+            possible return values.
+        """
+        line = self.proto.read_pkt_line()
+        if not line:
+            return (None, None)
+        return self._split_proto_line(line)
 
 
-                #if have[:4] == 'done':
-                #    return None
+    def send_ack(self, sha, ack_type=''):
+        if ack_type:
+            ack_type = ' %s' % ack_type
+        self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
 
 
-                if self._last_sha:
-                    # Oddness: Git seems to resend the last ACK, without the "continue" statement
-                    self.proto.write_pkt_line("ACK %s\n" % self._last_sha)
+    def send_nak(self):
+        self.proto.write_pkt_line('NAK\n')
 
 
-                # The exchange finishes with a NAK
-                self.proto.write_pkt_line("NAK\n")
+    def set_wants(self, wants):
+        self._wants = wants
 
 
-        graph_walker = ProtocolGraphWalker(self.proto)
-        objects_iter = self.backend.fetch_objects(determine_wants, graph_walker, progress)
+    def _is_satisfied(self, haves, want, earliest):
+        """Check whether a want is satisfied by a set of haves.
 
 
-        # Do they want any objects?
-        if len(objects_iter) == 0:
-            return
+        A want, typically a branch tip, is "satisfied" only if there exists a
+        path back from that want to one of the haves.
 
 
-        progress("dul-daemon says what\n")
-        progress("counting objects: %d, done.\n" % len(objects_iter))
-        write_pack_data(ProtocolFile(None, write), objects_iter, 
-                        len(objects_iter))
-        progress("how was that, then?\n")
-        # we are done
-        self.proto.write("0000")
+        :param haves: A set of commits we know the client has.
+        :param want: The want to check satisfaction for.
+        :param earliest: A timestamp beyond which the search for haves will be
+            terminated, presumably because we're searching too far down the
+            wrong branch.
+        """
+        o = self.store[want]
+        pending = collections.deque([o])
+        while pending:
+            commit = pending.popleft()
+            if commit.id in haves:
+                return True
+            if not getattr(commit, 'get_parents', None):
+                # non-commit wants are assumed to be satisfied
+                continue
+            for parent in commit.get_parents():
+                parent_obj = self.store[parent]
+                # TODO: handle parents with later commit times than children
+                if parent_obj.commit_time >= earliest:
+                    pending.append(parent_obj)
+        return False
+
+    def all_wants_satisfied(self, haves):
+        """Check whether all the current wants are satisfied by a set of haves.
+
+        :param haves: A set of commits we know the client has.
+        :note: Wants are specified with set_wants rather than passed in since
+            in the current interface they are determined outside this class.
+        """
+        haves = set(haves)
+        earliest = min([self.store[h].commit_time for h in haves])
+        for want in self._wants:
+            if not self._is_satisfied(haves, want, earliest):
+                return False
+        return True
+
+    def set_ack_type(self, ack_type):
+        impl_classes = {
+            MULTI_ACK: MultiAckGraphWalkerImpl,
+            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
+            SINGLE_ACK: SingleAckGraphWalkerImpl,
+            }
+        self._impl = impl_classes[ack_type](self)
+
+
+class SingleAckGraphWalkerImpl(object):
+    """Graph walker implementation that speaks the single-ack protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._sent_ack = False
+
+    def ack(self, have_ref):
+        if not self._sent_ack:
+            self.walker.send_ack(have_ref)
+            self._sent_ack = True
+
+    def next(self):
+        command, sha = self.walker.read_proto_line()
+        if command in (None, 'done'):
+            if not self._sent_ack:
+                self.walker.send_nak()
+            return None
+        elif command == 'have':
+            return sha
+
+
+class MultiAckGraphWalkerImpl(object):
+    """Graph walker implementation that speaks the multi-ack protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._found_base = False
+        self._common = []
+
+    def ack(self, have_ref):
+        self._common.append(have_ref)
+        if not self._found_base:
+            self.walker.send_ack(have_ref, 'continue')
+            if self.walker.all_wants_satisfied(self._common):
+                self._found_base = True
+        # else we blind ack within next
+
+    def next(self):
+        while True:
+            command, sha = self.walker.read_proto_line()
+            if command is None:
+                self.walker.send_nak()
+                # in multi-ack mode, a flush-pkt indicates the client wants to
+                # flush but more have lines are still coming
+                continue
+            elif command == 'done':
+                # don't nak unless no common commits were found, even if not
+                # everything is satisfied
+                if self._common:
+                    self.walker.send_ack(self._common[-1])
+                else:
+                    self.walker.send_nak()
+                return None
+            elif command == 'have':
+                if self._found_base:
+                    # blind ack
+                    self.walker.send_ack(sha, 'continue')
+                return sha
+
+
+class MultiAckDetailedGraphWalkerImpl(object):
+    """Graph walker implementation speaking the multi-ack-detailed protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._found_base = False
+        self._common = []
+
+    def ack(self, have_ref):
+        self._common.append(have_ref)
+        if not self._found_base:
+            self.walker.send_ack(have_ref, 'common')
+            if self.walker.all_wants_satisfied(self._common):
+                self._found_base = True
+                self.walker.send_ack(have_ref, 'ready')
+        # else we blind ack within next
+
+    def next(self):
+        while True:
+            command, sha = self.walker.read_proto_line()
+            if command is None:
+                self.walker.send_nak()
+                if self.walker.stateless_rpc:
+                    return None
+                continue
+            elif command == 'done':
+                # don't nak unless no common commits were found, even if not
+                # everything is satisfied
+                if self._common:
+                    self.walker.send_ack(self._common[-1])
+                else:
+                    self.walker.send_nak()
+                return None
+            elif command == 'have':
+                if self._found_base:
+                    # blind ack; can happen if the client has more requests
+                    # inflight
+                    self.walker.send_ack(sha, 'ready')
+                return sha
 
 
 
 
 class ReceivePackHandler(Handler):
 class ReceivePackHandler(Handler):
-    """Protocol handler for downloading a pack to the client."""
+    """Protocol handler for downloading a pack from the client."""
+
+    def __init__(self, backend, read, write,
+                 stateless_rpc=False, advertise_refs=False):
+        Handler.__init__(self, backend, read, write)
+        self.stateless_rpc = stateless_rpc
+        self.advertise_refs = advertise_refs
+
 
 
     def default_capabilities(self):
     def default_capabilities(self):
         return ("report-status", "delete-refs")
         return ("report-status", "delete-refs")
@@ -209,15 +515,18 @@ class ReceivePackHandler(Handler):
     def handle(self):
     def handle(self):
         refs = self.backend.get_refs().items()
         refs = self.backend.get_refs().items()
 
 
-        if refs:
-            self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
-            for i in range(1, len(refs)):
-                ref = refs[i]
-                self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
-        else:
-            self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+        if self.advertise_refs or not self.stateless_rpc:
+            if refs:
+                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                for i in range(1, len(refs)):
+                    ref = refs[i]
+                    self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
+            else:
+                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
 
 
-        self.proto.write("0000")
+            self.proto.write("0000")
+            if self.advertise_refs:
+                return
 
 
         client_refs = []
         client_refs = []
         ref = self.proto.read_pkt_line()
         ref = self.proto.read_pkt_line()
@@ -234,11 +543,19 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
             ref = self.proto.read_pkt_line()
 
 
         # backend can now deal with this refs and read a pack using self.read
         # backend can now deal with this refs and read a pack using self.read
-        self.backend.apply_pack(client_refs, self.proto.read)
+        status = self.backend.apply_pack(client_refs, self.proto.read)
 
 
-        # when we have read all the pack from the client, it assumes 
-        # everything worked OK.
-        # there is NO ack from the server before it reports victory.
+        # when we have read all the pack from the client, send a status report
+        # if the client asked for it
+        if 'report-status' in client_capabilities:
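+            # one pkt-line per ref: 'unpack <status>', then 'ok <ref>' or
+            # 'ng <ref> <reason>', terminated by a flush-pkt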
+            for name, msg in status:
+                if name == 'unpack':
+                    self.proto.write_pkt_line('unpack %s\n' % msg)
+                elif msg == 'ok':
+                    self.proto.write_pkt_line('ok %s\n' % name)
+                else:
+                    self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
+            self.proto.write_pkt_line(None)
 
 
 
 
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
@@ -267,5 +584,3 @@ class TCPGitServer(SocketServer.TCPServer):
     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
         self.backend = backend
         self.backend = backend
         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)
         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)
-
-

+ 1 - 0
dulwich/tests/data/repos/refs.git/HEAD

@@ -0,0 +1 @@
+ref: refs/heads/master

BIN
dulwich/tests/data/repos/refs.git/objects/3b/9e5457140e738c2dcd39bf6d7acf88379b90d1


BIN
dulwich/tests/data/repos/refs.git/objects/42/d06bd4b77fed026b154d16493e5deab78f02ec


BIN
dulwich/tests/data/repos/refs.git/objects/a1/8114c31713746a33a2e70d9914d1ef3e781425


BIN
dulwich/tests/data/repos/refs.git/objects/df/6800012397fb85c56e7418dd4eb9405dee075c


+ 3 - 0
dulwich/tests/data/repos/refs.git/packed-refs

@@ -0,0 +1,3 @@
+# pack-refs with: peeled 
+df6800012397fb85c56e7418dd4eb9405dee075c refs/tags/refs-0.1
+^42d06bd4b77fed026b154d16493e5deab78f02ec

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/loop

@@ -0,0 +1 @@
+ref: refs/heads/loop

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/master

@@ -0,0 +1 @@
+42d06bd4b77fed026b154d16493e5deab78f02ec

+ 131 - 0
dulwich/tests/test_file.py

@@ -0,0 +1,131 @@
+# test_file.py -- Test for git files
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) a later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+import errno
+import os
+import shutil
+import tempfile
+import unittest
+
+from dulwich.file import GitFile
+
+class GitFileTests(unittest.TestCase):
+    def setUp(self):
+        self._tempdir = tempfile.mkdtemp()
+        f = open(self.path('foo'), 'wb')
+        f.write('foo contents')
+        f.close()
+
+    def tearDown(self):
+        shutil.rmtree(self._tempdir)
+
+    def path(self, filename):
+        return os.path.join(self._tempdir, filename)
+
+    def test_invalid(self):
+        foo = self.path('foo')
+        self.assertRaises(IOError, GitFile, foo, mode='r')
+        self.assertRaises(IOError, GitFile, foo, mode='ab')
+        self.assertRaises(IOError, GitFile, foo, mode='r+b')
+        self.assertRaises(IOError, GitFile, foo, mode='w+b')
+        self.assertRaises(IOError, GitFile, foo, mode='a+bU')
+
+    def test_readonly(self):
+        f = GitFile(self.path('foo'), 'rb')
+        self.assertTrue(isinstance(f, file))
+        self.assertEquals('foo contents', f.read())
+        self.assertEquals('', f.read())
+        f.seek(4)
+        self.assertEquals('contents', f.read())
+        f.close()
+
+    def test_write(self):
+        foo = self.path('foo')
+        foo_lock = '%s.lock' % foo
+
+        orig_f = open(foo, 'rb')
+        self.assertEquals(orig_f.read(), 'foo contents')
+        orig_f.close()
+
+        self.assertFalse(os.path.exists(foo_lock))
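+        # GitFile opened for writing takes foo.lock; the new contents replace
+        # foo only when close() succeeds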
+        f = GitFile(foo, 'wb')
+        self.assertFalse(f.closed)
+        self.assertRaises(AttributeError, getattr, f, 'not_a_file_property')
+
+        self.assertTrue(os.path.exists(foo_lock))
+        f.write('new stuff')
+        f.seek(4)
+        f.write('contents')
+        f.close()
+        self.assertFalse(os.path.exists(foo_lock))
+
+        new_f = open(foo, 'rb')
+        self.assertEquals('new contents', new_f.read())
+        new_f.close()
+
+    def test_open_twice(self):
+        foo = self.path('foo')
+        f1 = GitFile(foo, 'wb')
+        f1.write('new')
+        try:
+            f2 = GitFile(foo, 'wb')
+            self.fail()
+        except OSError, e:
+            self.assertEquals(errno.EEXIST, e.errno)
+        f1.write(' contents')
+        f1.close()
+
+        # Ensure trying to open twice doesn't affect original.
+        f = open(foo, 'rb')
+        self.assertEquals('new contents', f.read())
+        f.close()
+
+    def test_abort(self):
+        foo = self.path('foo')
+        foo_lock = '%s.lock' % foo
+
+        orig_f = open(foo, 'rb')
+        self.assertEquals(orig_f.read(), 'foo contents')
+        orig_f.close()
+
+        f = GitFile(foo, 'wb')
+        f.write('new contents')
+        f.abort()
+        self.assertTrue(f.closed)
+        self.assertFalse(os.path.exists(foo_lock))
+
+        new_orig_f = open(foo, 'rb')
+        self.assertEquals(new_orig_f.read(), 'foo contents')
+        new_orig_f.close()
+
+    def test_abort_close(self):
+        foo = self.path('foo')
+        f = GitFile(foo, 'wb')
+        f.abort()
+        try:
+            f.close()
+        except (IOError, OSError):
+            self.fail()
+
+        f = GitFile(foo, 'wb')
+        f.close()
+        try:
+            f.abort()
+        except (IOError, OSError):
+            self.fail()

+ 35 - 0
dulwich/tests/test_objects.py

@@ -203,6 +203,41 @@ class CommitSerializationTests(unittest.TestCase):
         self.assertTrue(" -0100\n" in c.as_raw_string())
         self.assertTrue(" -0100\n" in c.as_raw_string())
 
 
 
 
+class CommitDeserializationTests(unittest.TestCase):
+
+    def test_simple(self):
+        c = Commit.from_string(
+                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
+                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
+                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
+                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
+                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
+                '\n'
+                'Merge ../b\n')
+        self.assertEquals('Merge ../b\n', c.message)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
+            c.author)
+        self.assertEquals('James Westby <jw+debian@jameswestby.net>',
+            c.committer)
+        self.assertEquals('d80c186a03f423a81b39df39dc87fd269736ca86',
+            c.tree)
+        self.assertEquals(['ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd',
+                          '4cffe90e0a41ad3f5190079d7c8f036bde29cbe6'],
+            c.parents)
+
+    def test_custom(self):
+        c = Commit.from_string(
+                'tree d80c186a03f423a81b39df39dc87fd269736ca86\n'
+                'parent ab64bbdcc51b170d21588e5c5d391ee5c0c96dfd\n'
+                'parent 4cffe90e0a41ad3f5190079d7c8f036bde29cbe6\n'
+                'author James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
+                'committer James Westby <jw+debian@jameswestby.net> 1174773719 +0000\n'
+                'extra-field data\n'
+                '\n'
+                'Merge ../b\n')
+        self.assertEquals([('extra-field', 'data')], c.extra)
+
+
 class TreeSerializationTests(unittest.TestCase):
 class TreeSerializationTests(unittest.TestCase):
 
 
     def test_simple(self):
     def test_simple(self):

+ 3 - 2
dulwich/tests/test_pack.py

@@ -21,6 +21,7 @@
 """Tests for Dulwich packs."""
 """Tests for Dulwich packs."""
 
 
 
 
+from cStringIO import StringIO
 import os
 import os
 import unittest
 import unittest
 
 
@@ -282,6 +283,6 @@ TEST_COMP1 = """\x78\x9c\x9d\x8e\xc1\x0a\xc2\x30\x10\x44\xef\xf9\x8a\xbd\xa9\x08
 class ZlibTests(unittest.TestCase):
 class ZlibTests(unittest.TestCase):
 
 
     def test_simple_decompress(self):
     def test_simple_decompress(self):
-        self.assertEquals(("tree 4ada885c9196b6b6fa08744b5862bf92896fc002\nparent None\nauthor Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\ncommitter Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n\nProvide replacement for mmap()'s offset argument.", 158), 
-        read_zlib(TEST_COMP1, 0, 229))
+        self.assertEquals(("tree 4ada885c9196b6b6fa08744b5862bf92896fc002\nparent None\nauthor Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\ncommitter Jelmer Vernooij <jelmer@samba.org> 1228980214 +0000\n\nProvide replacement for mmap()'s offset argument.", 158, 'Z'), 
+        read_zlib(StringIO(TEST_COMP1).read, 229))
 
 

+ 28 - 3
dulwich/tests/test_protocol.py

@@ -26,6 +26,11 @@ from unittest import TestCase
 from dulwich.protocol import (
 from dulwich.protocol import (
     Protocol,
     Protocol,
     extract_capabilities,
     extract_capabilities,
+    extract_want_line_capabilities,
+    ack_type,
+    SINGLE_ACK,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
     )
     )
 
 
 class ProtocolTests(TestCase):
 class ProtocolTests(TestCase):
@@ -77,10 +82,30 @@ class ProtocolTests(TestCase):
         self.assertRaises(AssertionError, self.proto.read_cmd)
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
 
 
-class ExtractCapabilitiesTestCase(TestCase):
+class CapabilitiesTestCase(TestCase):
 
 
     def test_plain(self):
     def test_plain(self):
-        self.assertEquals(("bla", None), extract_capabilities("bla"))
+        self.assertEquals(("bla", []), extract_capabilities("bla"))
 
 
     def test_caps(self):
     def test_caps(self):
-        self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la\0la"))
+        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la"))
+        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la\n"))
+        self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la la"))
+
+    def test_plain_want_line(self):
+        self.assertEquals(("want bla", []), extract_want_line_capabilities("want bla"))
+
+    def test_caps_want_line(self):
+        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la"))
+        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la\n"))
+        self.assertEquals(("want bla", ["la", "la"]), extract_want_line_capabilities("want bla la la"))
+
+    def test_ack_type(self):
+        self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
+        self.assertEquals(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
+        self.assertEquals(MULTI_ACK_DETAILED,
+                          ack_type(['foo', 'bar', 'multi_ack_detailed']))
+        # choose detailed when both present
+        self.assertEquals(MULTI_ACK_DETAILED,
+                          ack_type(['foo', 'bar', 'multi_ack',
+                                    'multi_ack_detailed']))

+ 324 - 31
dulwich/tests/test_repository.py

@@ -20,95 +20,141 @@
 
 
 """Tests for the repository."""
 """Tests for the repository."""
 
 
-
+from cStringIO import StringIO
 import os
 import os
+import shutil
+import tempfile
 import unittest
 import unittest
 
 
 from dulwich import errors
 from dulwich import errors
-from dulwich.repo import Repo
+from dulwich.repo import (
+    check_ref_format,
+    Repo,
+    read_packed_refs,
+    read_packed_refs_with_peeled,
+    write_packed_refs,
+    _split_ref_line,
+    )
 
 
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 
 
+
+def open_repo(name):
+    """Open a copy of a repo in a temporary directory.
+
+    Use this function for accessing repos in dulwich/tests/data/repos to avoid
+    accidentally or intentionally modifying those repos in place. Use
+    tear_down_repo to delete any temp files created.
+
+    :param name: The name of the repository, relative to
+        dulwich/tests/data/repos
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    temp_dir = tempfile.mkdtemp()
+    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
+    return Repo(temp_repo_dir)
+
+def tear_down_repo(repo):
+    """Tear down a test repository."""
+    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
+    shutil.rmtree(temp_dir)
+
+
+
+class CreateRepositoryTests(unittest.TestCase):
+
+    def test_create(self):
+        tmp_dir = tempfile.mkdtemp()
+        try:
+            repo = Repo.init_bare(tmp_dir)
+            self.assertEquals(tmp_dir, repo._controldir)
+        finally:
+            shutil.rmtree(tmp_dir)
+
+
 class RepositoryTests(unittest.TestCase):
 class RepositoryTests(unittest.TestCase):
-  
-    def open_repo(self, name):
-        return Repo(os.path.join(os.path.dirname(__file__),
-                          'data', 'repos', name))
-  
+
+    def setUp(self):
+        self._repo = None
+
+    def tearDown(self):
+        if self._repo is not None:
+            tear_down_repo(self._repo)
+
     def test_simple_props(self):
     def test_simple_props(self):
-        r = self.open_repo('a.git')
-        basedir = os.path.join(os.path.dirname(__file__), 
-                os.path.join('data', 'repos', 'a.git'))
-        self.assertEqual(r.controldir(), basedir)
+        r = self._repo = open_repo('a.git')
+        self.assertEqual(r.controldir(), r.path)
   
   
     def test_ref(self):
     def test_ref(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.ref('refs/heads/master'),
         self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
   
     def test_get_refs(self):
     def test_get_refs(self):
-        r = self.open_repo('a.git')
-        self.assertEquals({
+        r = self._repo = open_repo('a.git')
+        self.assertEqual({
             'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
             'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
             'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
             'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
             }, r.get_refs())
             }, r.get_refs())
   
   
     def test_head(self):
     def test_head(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
   
     def test_get_object(self):
     def test_get_object(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
         obj = r.get_object(r.head())
         self.assertEqual(obj._type, 'commit')
         self.assertEqual(obj._type, 'commit')
   
   
     def test_get_object_non_existant(self):
     def test_get_object_non_existant(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
         self.assertRaises(KeyError, r.get_object, missing_sha)
   
   
     def test_commit(self):
     def test_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.commit(r.head())
         obj = r.commit(r.head())
         self.assertEqual(obj._type, 'commit')
         self.assertEqual(obj._type, 'commit')
   
   
     def test_commit_not_commit(self):
     def test_commit_not_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotCommitError,
         self.assertRaises(errors.NotCommitError,
                           r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
                           r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
   
   
     def test_tree(self):
     def test_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         commit = r.commit(r.head())
         tree = r.tree(commit.tree)
         tree = r.tree(commit.tree)
         self.assertEqual(tree._type, 'tree')
         self.assertEqual(tree._type, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
   
   
     def test_tree_not_tree(self):
     def test_tree_not_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotTreeError, r.tree, r.head())
         self.assertRaises(errors.NotTreeError, r.tree, r.head())
   
   
     def test_get_blob(self):
     def test_get_blob(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         commit = r.commit(r.head())
-        tree = r.tree(commit.tree())
+        tree = r.tree(commit.tree)
         blob_sha = tree.entries()[0][2]
         blob_sha = tree.entries()[0][2]
         blob = r.get_blob(blob_sha)
         blob = r.get_blob(blob_sha)
         self.assertEqual(blob._type, 'blob')
         self.assertEqual(blob._type, 'blob')
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
   
   
-    def test_get_blob(self):
-        r = self.open_repo('a.git')
+    def test_get_blob_notblob(self):
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
     
     
     def test_linear_history(self):
     def test_linear_history(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
   
   
     def test_merge_history(self):
     def test_merge_history(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
@@ -118,13 +164,13 @@ class RepositoryTests(unittest.TestCase):
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
   
   
     def test_revision_history_missing_commit(self):
     def test_revision_history_missing_commit(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         self.assertRaises(errors.MissingCommitError, r.revision_history,
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
                           missing_sha)
   
   
     def test_out_of_order_merge(self):
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
         """Test that revision history is ordered by date, not parent order."""
-        r = self.open_repo('ooo_merge.git')
+        r = self._repo = open_repo('ooo_merge.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
@@ -133,5 +179,252 @@ class RepositoryTests(unittest.TestCase):
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
   
   
     def test_get_tags_empty(self):
     def test_get_tags_empty(self):
-        r = self.open_repo('ooo_merge.git')
-        self.assertEquals({}, r.refs.as_dict('refs/tags'))
+        r = self._repo = open_repo('ooo_merge.git')
+        self.assertEqual({}, r.refs.as_dict('refs/tags'))
+
+    def test_get_config(self):
+        r = self._repo = open_repo('ooo_merge.git')
+        self.assertEquals({}, r.get_config())
+
+
+class CheckRefFormatTests(unittest.TestCase):
+    """Tests for the check_ref_format function.
+
+    These are the same tests as in the git test suite.
+    """
+
+    def test_valid(self):
+        self.assertTrue(check_ref_format('heads/foo'))
+        self.assertTrue(check_ref_format('foo/bar/baz'))
+        self.assertTrue(check_ref_format('refs///heads/foo'))
+        self.assertTrue(check_ref_format('foo./bar'))
+        self.assertTrue(check_ref_format('heads/foo@bar'))
+        self.assertTrue(check_ref_format('heads/fix.lock.error'))
+
+    def test_invalid(self):
+        self.assertFalse(check_ref_format('foo'))
+        self.assertFalse(check_ref_format('heads/foo/'))
+        self.assertFalse(check_ref_format('./foo'))
+        self.assertFalse(check_ref_format('.refs/foo'))
+        self.assertFalse(check_ref_format('heads/foo..bar'))
+        self.assertFalse(check_ref_format('heads/foo?bar'))
+        self.assertFalse(check_ref_format('heads/foo.lock'))
+        self.assertFalse(check_ref_format('heads/v@{ation'))
+        self.assertFalse(check_ref_format('heads/foo\bar'))
+
+
+ONES = "1" * 40
+TWOS = "2" * 40
+THREES = "3" * 40
+FOURS = "4" * 40
+
+class PackedRefsFileTests(unittest.TestCase):
+
+    def test_split_ref_line_errors(self):
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'singlefield')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'badsha name')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          '%s bad/../refname' % ONES)
+
+    def test_read_without_peeled(self):
+        f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
+        self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
+                         list(read_packed_refs(f)))
+
+    def test_read_without_peeled_errors(self):
+        f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+    def test_read_with_peeled(self):
+        f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
+            ONES, TWOS, THREES, FOURS))
+        self.assertEqual([
+            (ONES, 'ref/1', None),
+            (TWOS, 'ref/2', THREES),
+            (FOURS, 'ref/4', None),
+            ], list(read_packed_refs_with_peeled(f)))
+
+    def test_read_with_peeled_errors(self):
+        f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+        f = StringIO('%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+    def test_write_with_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
+                          {'ref/1': THREES})
+        self.assertEqual(
+            "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
+            ONES, THREES, TWOS), f.getvalue())
+
+    def test_write_without_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
+        self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
+
+
+class RefsContainerTests(unittest.TestCase):
+
+    def setUp(self):
+        self._repo = open_repo('refs.git')
+        self._refs = self._repo.refs
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_get_packed_refs(self):
+        self.assertEqual(
+            {'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c'},
+            self._refs.get_packed_refs())
+
+    def test_keys(self):
+        self.assertEqual([
+            'HEAD',
+            'refs/heads/loop',
+            'refs/heads/master',
+            'refs/tags/refs-0.1',
+            ], sorted(list(self._refs.keys())))
+        self.assertEqual(['loop', 'master'],
+                         sorted(self._refs.keys('refs/heads')))
+        self.assertEqual(['refs-0.1'], list(self._refs.keys('refs/tags')))
+
+    def test_as_dict(self):
+        # refs/heads/loop does not show up
+        self.assertEqual({
+            'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+            }, self._refs.as_dict())
+
+    def test_setitem(self):
+        self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/some/ref'])
+        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          f.read()[:40])
+        f.close()
+
+    def test_setitem_symbolic(self):
+        ones = '1' * 40
+        self._refs['HEAD'] = ones
+        self.assertEqual(ones, self._refs['HEAD'])
+
+        # ensure HEAD was not modified
+        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
+        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
+        f.close()
+
+        # ensure the symbolic link was written through
+        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
+        self.assertEqual(ones, f.read()[:40])
+        f.close()
+
+    def test_set_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        self.assertTrue(self._refs.set_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
+        self.assertEqual(nines, self._refs['HEAD'])
+
+        # ensure symref was followed
+        self.assertEqual(nines, self._refs['refs/heads/master'])
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+    def test_add_if_new(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+
+        self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
+        self.assertEqual(nines, self._refs['refs/some/ref'])
+
+        # don't overwrite packed ref
+        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+
+    def test_check_refname(self):
+        try:
+            self._refs._check_refname('HEAD')
+        except KeyError:
+            self.fail()
+
+        try:
+            self._refs._check_refname('refs/heads/foo')
+        except KeyError:
+            self.fail()
+
+        self.assertRaises(KeyError, self._refs._check_refname, 'refs')
+        self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
+
+    def test_follow(self):
+        self.assertEquals(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('HEAD'))
+        self.assertEquals(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('refs/heads/master'))
+        self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
+        self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
+
+    def test_delitem(self):
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          self._refs['refs/heads/master'])
+        del self._refs['refs/heads/master']
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
+        self.assertFalse(os.path.exists(ref_file))
+        self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
+
+    def test_delitem_symbolic(self):
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs.read_loose_ref('HEAD'))
+        del self._refs['HEAD']
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
+
+    def test_remove_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        # HEAD is a symref, so shouldn't equal its dereferenced value
+        self.assertFalse(self._refs.remove_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertTrue(self._refs.remove_if_equals(
+            'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+
+        # HEAD is now a broken symref
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs.read_loose_ref('HEAD'))
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+        # test removing ref that is only packed
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+        self.assertTrue(
+            self._refs.remove_if_equals('refs/tags/refs-0.1',
+            'df6800012397fb85c56e7418dd4eb9405dee075c'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])

+ 519 - 0
dulwich/tests/test_server.py

@@ -0,0 +1,519 @@
+# test_server.py -- Tests for the git server
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) any later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+"""Tests for the smart protocol server."""
+
+
+from cStringIO import StringIO
+from unittest import TestCase
+
+from dulwich.errors import (
+    GitProtocolError,
+    )
+from dulwich.server import (
+    UploadPackHandler,
+    ProtocolGraphWalker,
+    SingleAckGraphWalkerImpl,
+    MultiAckGraphWalkerImpl,
+    MultiAckDetailedGraphWalkerImpl,
+    )
+
+from dulwich.protocol import (
+    SINGLE_ACK,
+    MULTI_ACK,
+    )
+
+ONE = '1' * 40
+TWO = '2' * 40
+THREE = '3' * 40
+FOUR = '4' * 40
+FIVE = '5' * 40
+
+class TestProto(object):
+    def __init__(self):
+        self._output = []
+        self._received = {0: [], 1: [], 2: [], 3: []}
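+        # (band 0 collects plain pkt-lines; bands 1-3 collect sideband data)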
+
+    def set_output(self, output_lines):
+        self._output = ['%s\n' % line.rstrip() for line in output_lines]
+
+    def read_pkt_line(self):
+        if self._output:
+            return self._output.pop(0)
+        else:
+            return None
+
+    def write_sideband(self, band, data):
+        self._received[band].append(data)
+
+    def write_pkt_line(self, data):
+        if data is None:
+            data = 'None'
+        self._received[0].append(data)
+
+    def get_received_line(self, band=0):
+        lines = self._received[band]
+        if lines:
+            return lines.pop(0)
+        else:
+            return None
+
+
+class UploadPackHandlerTestCase(TestCase):
+    def setUp(self):
+        self._handler = UploadPackHandler(None, None, None)
+
+    def test_set_client_capabilities(self):
+        try:
+            self._handler.set_client_capabilities([])
+        except GitProtocolError:
+            self.fail()
+
+        try:
+            self._handler.set_client_capabilities([
+                'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
+        except GitProtocolError:
+            self.fail()
+
+    def test_set_client_capabilities_error(self):
+        self.assertRaises(GitProtocolError,
+                          self._handler.set_client_capabilities,
+                          ['weird_ack_level', 'ofs-delta'])
+        try:
+            self._handler.set_client_capabilities(['include-tag'])
+        except GitProtocolError:
+            self.fail()
+
+
+class TestCommit(object):
+    def __init__(self, sha, parents, commit_time):
+        self.id = sha
+        self._parents = parents
+        self.commit_time = commit_time
+
+    def get_parents(self):
+        return self._parents
+
+    def __repr__(self):
+        return '%s(%s)' % (self.__class__.__name__, self.id)
+
+
+class TestBackend(object):
+    def __init__(self, objects):
+        self.object_store = objects
+
+
+class TestHandler(object):
+    def __init__(self, objects, proto):
+        self.backend = TestBackend(objects)
+        self.proto = proto
+        self.stateless_rpc = False
+        self.advertise_refs = False
+
+    def capabilities(self):
+        return 'multi_ack'
+
+
+class ProtocolGraphWalkerTestCase(TestCase):
+    def setUp(self):
+        # Create the following commit tree:
+        #   3---5
+        #  /
+        # 1---2---4
+        self._objects = {
+            ONE: TestCommit(ONE, [], 111),
+            TWO: TestCommit(TWO, [ONE], 222),
+            THREE: TestCommit(THREE, [ONE], 333),
+            FOUR: TestCommit(FOUR, [TWO], 444),
+            FIVE: TestCommit(FIVE, [THREE], 555),
+            }
+        self._walker = ProtocolGraphWalker(
+            TestHandler(self._objects, TestProto()))
+
+    def test_is_satisfied_no_haves(self):
+        self.assertFalse(self._walker._is_satisfied([], ONE, 0))
+        self.assertFalse(self._walker._is_satisfied([], TWO, 0))
+        self.assertFalse(self._walker._is_satisfied([], THREE, 0))
+
+    def test_is_satisfied_have_root(self):
+        self.assertTrue(self._walker._is_satisfied([ONE], ONE, 0))
+        self.assertTrue(self._walker._is_satisfied([ONE], TWO, 0))
+        self.assertTrue(self._walker._is_satisfied([ONE], THREE, 0))
+
+    def test_is_satisfied_have_branch(self):
+        self.assertTrue(self._walker._is_satisfied([TWO], TWO, 0))
+        # wrong branch
+        self.assertFalse(self._walker._is_satisfied([TWO], THREE, 0))
+
+    def test_all_wants_satisfied(self):
+        self._walker.set_wants([FOUR, FIVE])
+        # trivial case: wants == haves
+        self.assertTrue(self._walker.all_wants_satisfied([FOUR, FIVE]))
+        # cases that require walking the commit tree
+        self.assertTrue(self._walker.all_wants_satisfied([ONE]))
+        self.assertFalse(self._walker.all_wants_satisfied([TWO]))
+        self.assertFalse(self._walker.all_wants_satisfied([THREE]))
+        self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
+
+    def test_read_proto_line(self):
+        self._walker.proto.set_output([
+            'want %s' % ONE,
+            'want %s' % TWO,
+            'have %s' % THREE,
+            'foo %s' % FOUR,
+            'bar',
+            'done',
+            ])
+        self.assertEquals(('want', ONE), self._walker.read_proto_line())
+        self.assertEquals(('want', TWO), self._walker.read_proto_line())
+        self.assertEquals(('have', THREE), self._walker.read_proto_line())
+        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
+        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
+        self.assertEquals(('done', None), self._walker.read_proto_line())
+        self.assertEquals((None, None), self._walker.read_proto_line())
+
+    def test_determine_wants(self):
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
+
+        self._walker.proto.set_output([
+            'want %s multi_ack' % ONE,
+            'want %s' % TWO,
+            ])
+        heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
+        self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
+
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+        self._walker.proto.set_output([])
+        self.assertEquals([], self._walker.determine_wants(heads))
+
+        self._walker.proto.set_output(['want %s multi_ack' % ONE, 'foo'])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+    # TODO: test commit time cutoff
+
+
+class TestProtocolGraphWalker(object):
+    def __init__(self):
+        self.acks = []
+        self.lines = []
+        self.done = False
+        self.stateless_rpc = False
+        self.advertise_refs = False
+
+    def read_proto_line(self):
+        return self.lines.pop(0)
+
+    def send_ack(self, sha, ack_type=''):
+        self.acks.append((sha, ack_type))
+
+    def send_nak(self):
+        self.acks.append((None, 'nak'))
+
+    def all_wants_satisfied(self, haves):
+        return self.done
+
+    def pop_ack(self):
+        if not self.acks:
+            return None
+        return self.acks.pop(0)
+
+
+class AckGraphWalkerImplTestCase(TestCase):
+    """Base setup and asserts for AckGraphWalker tests."""
+    def setUp(self):
+        self._walker = TestProtocolGraphWalker()
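+        # simulate a client that sends 'have' for TWO, ONE and THREE, then 'done'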
+        self._walker.lines = [
+            ('have', TWO),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self._impl = self.impl_cls(self._walker)
+
+    def assertNoAck(self):
+        self.assertEquals(None, self._walker.pop_ack())
+
+    def assertAcks(self, acks):
+        for sha, ack_type in acks:
+            self.assertEquals((sha, ack_type), self._walker.pop_ack())
+        self.assertNoAck()
+
+    def assertAck(self, sha, ack_type=''):
+        self.assertAcks([(sha, ack_type)])
+
+    def assertNak(self):
+        self.assertAck(None, 'nak')
+
+    def assertNextEquals(self, sha):
+        self.assertEquals(sha, self._impl.next())
+
+
+class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+    impl_cls = SingleAckGraphWalkerImpl
+
+    def test_single_ack(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE)
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNoAck()
+
+    def test_single_ack_flush(self):
+        # same as ack test but ends with a flush-pkt instead of done
+        self._walker.lines[-1] = (None, None)
+
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE)
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNoAck()
+
+    def test_single_ack_nak(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_single_ack_nak_flush(self):
+        # same as nak test but ends with a flush-pkt instead of done
+        self._walker.lines[-1] = (None, None)
+
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+    impl_cls = MultiAckGraphWalkerImpl
+
+    def test_multi_ack(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'continue')
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'continue')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    def test_multi_ack_partial(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'continue')
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        # done, re-send ack of last common
+        self.assertAck(ONE)
+
+    def test_multi_ack_flush(self):
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNak() # nak the flush-pkt
+
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'continue')
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'continue')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    def test_multi_ack_nak(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+    impl_cls = MultiAckDetailedGraphWalkerImpl
+
+    def test_multi_ack(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'ready')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    def test_multi_ack_partial(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'common')
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        # done, re-send ack of last common
+        self.assertAck(ONE)
+
+    def test_multi_ack_flush(self):
+        # same as ack test but contains a flush-pkt in the middle
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNak() # nak the flush-pkt
+
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'ready')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    def test_multi_ack_nak(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_multi_ack_nak_flush(self):
+        # same as nak test but contains a flush-pkt in the middle
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNak()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_multi_ack_stateless(self):
+        # transmission ends with a flush-pkt
+        self._walker.lines[-1] = (None, None)
+        self._walker.stateless_rpc = True
+
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()

+ 289 - 0
dulwich/tests/test_web.py

@@ -0,0 +1,289 @@
+# test_web.py -- Tests for the git HTTP server
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) any later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Tests for the Git HTTP server."""
+
+from cStringIO import StringIO
+import re
+from unittest import TestCase
+
+from dulwich.objects import (
+    type_map,
+    Tag,
+    Blob,
+    )
+from dulwich.web import (
+    HTTP_OK,
+    HTTP_NOT_FOUND,
+    HTTP_FORBIDDEN,
+    send_file,
+    get_info_refs,
+    handle_service_request,
+    _LengthLimitedFile,
+    HTTPGitRequest,
+    HTTPGitApplication,
+    )
+
+
+class WebTestCase(TestCase):
+    """Base TestCase that sets up some useful instance vars."""
+    def setUp(self):
+        self._environ = {}
+        self._req = HTTPGitRequest(self._environ, self._start_response)
+        self._status = None
+        self._headers = []
+
+    def _start_response(self, status, headers):
+        self._status = status
+        self._headers = list(headers)
+
+
+class DumbHandlersTestCase(WebTestCase):
+
+    def test_send_file_not_found(self):
+        list(send_file(self._req, None, 'text/plain'))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+
+    def test_send_file(self):
+        f = StringIO('foobar')
+        output = ''.join(send_file(self._req, f, 'text/plain'))
+        self.assertEquals('foobar', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertTrue(f.closed)
+
+    def test_send_file_buffered(self):
+        bufsize = 10240
+        xs = 'x' * bufsize
+        f = StringIO(2 * xs)
+        self.assertEquals([xs, xs],
+                          list(send_file(self._req, f, 'text/plain')))
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertTrue(f.closed)
+
+    def test_send_file_error(self):
+        class TestFile(object):
+            def __init__(self):
+                self.closed = False
+
+            def read(self, size=-1):
+                raise IOError
+
+            def close(self):
+                self.closed = True
+
+        f = TestFile()
+        list(send_file(self._req, f, 'text/plain'))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        self.assertTrue(f.closed)
+
+    def test_get_info_refs(self):
+        self._environ['QUERY_STRING'] = ''
+
+        class TestTag(object):
+            type = Tag().type
+
+            def __init__(self, sha, obj_type, obj_sha):
+                self.sha = lambda: sha
+                self.object = (obj_type, obj_sha)
+
+        class TestBlob(object):
+            type = Blob().type
+
+            def __init__(self, sha):
+                self.sha = lambda: sha
+
+        blob1 = TestBlob('111')
+        blob2 = TestBlob('222')
+        blob3 = TestBlob('333')
+
+        tag1 = TestTag('aaa', TestTag.type, 'bbb')
+        tag2 = TestTag('bbb', TestBlob.type, '222')
+
+        class TestBackend(object):
+            def __init__(self):
+                objects = [blob1, blob2, blob3, tag1, tag2]
+                self.repo = dict((o.sha(), o) for o in objects)
+
+            def get_refs(self):
+                return {
+                    'HEAD': '000',
+                    'refs/heads/master': blob1.sha(),
+                    'refs/tags/tag-tag': tag1.sha(),
+                    'refs/tags/blob-tag': blob3.sha(),
+                    }
+
+        self.assertEquals(['111\trefs/heads/master\n',
+                           '333\trefs/tags/blob-tag\n',
+                           'aaa\trefs/tags/tag-tag\n',
+                           '222\trefs/tags/tag-tag^{}\n'],
+                          list(get_info_refs(self._req, TestBackend(), None)))
+
+
+class SmartHandlersTestCase(WebTestCase):
+
+    class TestProtocol(object):
+        def __init__(self, handler):
+            self._handler = handler
+
+        def write_pkt_line(self, line):
+            if line is None:
+                self._handler.write('flush-pkt\n')
+            else:
+                self._handler.write('pkt-line: %s' % line)
+
+    class _TestUploadPackHandler(object):
+        def __init__(self, backend, read, write, stateless_rpc=False,
+                     advertise_refs=False):
+            self.read = read
+            self.write = write
+            self.proto = SmartHandlersTestCase.TestProtocol(self)
+            self.stateless_rpc = stateless_rpc
+            self.advertise_refs = advertise_refs
+
+        def handle(self):
+            self.write('handled input: %s' % self.read())
+
+    def _MakeHandler(self, *args, **kwargs):
+        self._handler = self._TestUploadPackHandler(*args, **kwargs)
+        return self._handler
+
+    def services(self):
+        return {'git-upload-pack': self._MakeHandler}
+
+    def test_handle_service_request_unknown(self):
+        mat = re.search('.*', '/git-evil-handler')
+        list(handle_service_request(self._req, 'backend', mat))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+
+    def test_handle_service_request(self):
+        self._environ['wsgi.input'] = StringIO('foo')
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(handle_service_request(self._req, 'backend', mat,
+                                                services=self.services()))
+        self.assertEqual('handled input: foo', output)
+        response_type = 'application/x-git-upload-pack-response'
+        self.assertTrue(('Content-Type', response_type) in self._headers)
+        self.assertFalse(self._handler.advertise_refs)
+        self.assertTrue(self._handler.stateless_rpc)
+
+    def test_handle_service_request_with_length(self):
+        self._environ['wsgi.input'] = StringIO('foobar')
+        self._environ['CONTENT_LENGTH'] = 3
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(handle_service_request(self._req, 'backend', mat,
+                                                services=self.services()))
+        self.assertEqual('handled input: foo', output)
+        response_type = 'application/x-git-upload-pack-response'
+        self.assertTrue(('Content-Type', response_type) in self._headers)
+
+    def test_get_info_refs_unknown(self):
+        self._environ['QUERY_STRING'] = 'service=git-evil-handler'
+        list(get_info_refs(self._req, 'backend', None,
+                           services=self.services()))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+
+    def test_get_info_refs(self):
+        self._environ['wsgi.input'] = StringIO('foo')
+        self._environ['QUERY_STRING'] = 'service=git-upload-pack'
+
+        output = ''.join(get_info_refs(self._req, 'backend', None,
+                                       services=self.services()))
+        self.assertEquals(('pkt-line: # service=git-upload-pack\n'
+                           'flush-pkt\n'
+                           # input is ignored by the handler
+                           'handled input: '), output)
+        self.assertTrue(self._handler.advertise_refs)
+        self.assertTrue(self._handler.stateless_rpc)
+
+
+class LengthLimitedFileTestCase(TestCase):
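+    # _LengthLimitedFile caps reads from the wrapped file-like object at the
+    # given byte length.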
+    def test_no_cutoff(self):
+        f = _LengthLimitedFile(StringIO('foobar'), 1024)
+        self.assertEquals('foobar', f.read())
+
+    def test_cutoff(self):
+        f = _LengthLimitedFile(StringIO('foobar'), 3)
+        self.assertEquals('foo', f.read())
+        self.assertEquals('', f.read())
+
+    def test_multiple_reads(self):
+        f = _LengthLimitedFile(StringIO('foobar'), 3)
+        self.assertEquals('fo', f.read(2))
+        self.assertEquals('o', f.read(2))
+        self.assertEquals('', f.read())
+
+
+class HTTPGitRequestTestCase(WebTestCase):
+    def test_not_found(self):
+        self._req.cache_forever()  # cache headers should be discarded
+        message = 'Something not found'
+        self.assertEquals(message, self._req.not_found(message))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        self.assertEquals(set([('Content-Type', 'text/plain')]),
+                          set(self._headers))
+
+    def test_forbidden(self):
+        self._req.cache_forever()  # cache headers should be discarded
+        message = 'Something not found'
+        self.assertEquals(message, self._req.forbidden(message))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+        self.assertEquals(set([('Content-Type', 'text/plain')]),
+                          set(self._headers))
+
+    def test_respond_ok(self):
+        self._req.respond()
+        self.assertEquals([], self._headers)
+        self.assertEquals(HTTP_OK, self._status)
+
+    def test_respond(self):
+        self._req.nocache()
+        self._req.respond(status=402, content_type='some/type',
+                          headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
+        self.assertEquals(set([
+            ('X-Foo', 'foo'),
+            ('X-Bar', 'bar'),
+            ('Content-Type', 'some/type'),
+            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+            ('Pragma', 'no-cache'),
+            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+            ]), set(self._headers))
+        self.assertEquals(402, self._status)
+
+
+class HTTPGitApplicationTestCase(TestCase):
+    def setUp(self):
+        self._app = HTTPGitApplication('backend')
+
+    def test_call(self):
+        def test_handler(req, backend, mat):
+            # tests interface used by all handlers
+            self.assertEquals(environ, req.environ)
+            self.assertEquals('backend', backend)
+            self.assertEquals('/foo', mat.group(0))
+            return 'output'
+
+        self._app.services = {
+            ('GET', re.compile('/foo$')): test_handler,
+        }
+        environ = {
+            'PATH_INFO': '/foo',
+            'REQUEST_METHOD': 'GET',
+            }
+        self.assertEquals('output', self._app(environ, None))

+ 311 - 0
dulwich/web.py

@@ -0,0 +1,311 @@
+# web.py -- WSGI smart-http server
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) any later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""HTTP server for dulwich that implements the git smart HTTP protocol."""
+
+from cStringIO import StringIO
+import cgi
+import os
+import re
+import time
+
+from dulwich.objects import (
+    Tag,
+    num_type_map,
+    )
+from dulwich.repo import (
+    Repo,
+    )
+from dulwich.server import (
+    GitBackend,
+    ReceivePackHandler,
+    UploadPackHandler,
+    )
+
+HTTP_OK = '200 OK'
+HTTP_NOT_FOUND = '404 Not Found'
+HTTP_FORBIDDEN = '403 Forbidden'
+
+
+def date_time_string(timestamp=None):
+    # Based on BaseHTTPServer.py in python2.5
+    weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+    months = [None,
+              'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
+              'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+    if timestamp is None:
+        timestamp = time.time()
+    year, month, day, hh, mm, ss, wd, y, z = time.gmtime(timestamp)
+    return '%s, %02d %3s %4d %02d:%02d:%02d GMT' % (
+            weekdays[wd], day, months[month], year, hh, mm, ss)
+
+
+def send_file(req, f, content_type):
+    """Send a file-like object to the request output.
+
+    :param req: The HTTPGitRequest object to send output to.
+    :param f: An open file-like object to send; will be closed.
+    :param content_type: The MIME type for the file.
+    :yield: The contents of the file.
+    """
+    if f is None:
+        yield req.not_found('File not found')
+        return
+    try:
+        try:
+            req.respond(HTTP_OK, content_type)
+            while True:
+                data = f.read(10240)
+                if not data:
+                    break
+                yield data
+        except IOError:
+            yield req.not_found('Error reading file')
+    finally:
+        f.close()
+
+
+def get_text_file(req, backend, mat):
+    req.nocache()
+    return send_file(req, backend.repo.get_named_file(mat.group()),
+                     'text/plain')
+
+
+def get_loose_object(req, backend, mat):
+    sha = mat.group(1) + mat.group(2)
+    object_store = backend.object_store
+    if not object_store.contains_loose(sha):
+        yield req.not_found('Object not found')
+        return
+    try:
+        data = object_store[sha].as_legacy_object()
+    except IOError:
+        yield req.not_found('Error reading object')
+        return
+    req.cache_forever()
+    req.respond(HTTP_OK, 'application/x-git-loose-object')
+    yield data
+
+
+def get_pack_file(req, backend, mat):
+    req.cache_forever()
+    return send_file(req, backend.repo.get_named_file(mat.group()),
+                     'application/x-git-packed-objects')
+
+
+def get_idx_file(req, backend, mat):
+    req.cache_forever()
+    return send_file(req, backend.repo.get_named_file(mat.group()),
+                     'application/x-git-packed-objects-toc')
+
+
+default_services = {'git-upload-pack': UploadPackHandler,
+                    'git-receive-pack': ReceivePackHandler}
+def get_info_refs(req, backend, mat, services=None):
+    if services is None:
+        services = default_services
+    params = cgi.parse_qs(req.environ['QUERY_STRING'])
+    service = params.get('service', [None])[0]
+    if service:
+        handler_cls = services.get(service, None)
+        if handler_cls is None:
+            yield req.forbidden('Unsupported service %s' % service)
+            return
+        req.nocache()
+        req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
+        output = StringIO()
+        dummy_input = StringIO()  # GET request, handler doesn't need to read
+        handler = handler_cls(backend, dummy_input.read, output.write,
+                              stateless_rpc=True, advertise_refs=True)
+        handler.proto.write_pkt_line('# service=%s\n' % service)
+        handler.proto.write_pkt_line(None)
+        handler.handle()
+        yield output.getvalue()
+    else:
+        # non-smart fallback
+        # TODO: select_getanyfile() (see http-backend.c)
+        req.nocache()
+        req.respond(HTTP_OK, 'text/plain')
+        refs = backend.get_refs()
+        for name in sorted(refs.iterkeys()):
+            # get_refs() includes HEAD as a special case, but we don't want to
+            # advertise it
+            if name == 'HEAD':
+                continue
+            sha = refs[name]
+            o = backend.repo[sha]
+            if not o:
+                continue
+            yield '%s\t%s\n' % (sha, name)
+            obj_type = num_type_map[o.type]
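+            # peel annotated tags so the dereferenced object is also
+            # advertised as "<refname>^{}", matching git's info/refs output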
+            if obj_type == Tag:
+                while obj_type == Tag:
+                    num_type, sha = o.object
+                    obj_type = num_type_map[num_type]
+                    o = backend.repo[sha]
+                if not o:
+                    continue
+                yield '%s\t%s^{}\n' % (o.sha(), name)
+
+
+def get_info_packs(req, backend, mat):
+    req.nocache()
+    req.respond(HTTP_OK, 'text/plain')
+    for pack in backend.object_store.packs:
+        yield 'P pack-%s.pack\n' % pack.name()
+
+
+class _LengthLimitedFile(object):
+    """Wrapper class to limit the length of reads from a file-like object.
+
+    This is used to ensure EOF is read from the wsgi.input object once
+    Content-Length bytes are read. This behavior is required by the WSGI spec
+    but not implemented in wsgiref as of 2.5.
+    """
+    def __init__(self, input, max_bytes):
+        self._input = input
+        self._bytes_avail = max_bytes
+
+    def read(self, size=-1):
+        if self._bytes_avail <= 0:
+            return ''
+        if size == -1 or size > self._bytes_avail:
+            size = self._bytes_avail
+        self._bytes_avail -= size
+        return self._input.read(size)
+
+    # TODO: support more methods as necessary
+
+def handle_service_request(req, backend, mat, services=None):
+    if services is None:
+        services = default_services
+    service = mat.group().lstrip('/')
+    handler_cls = services.get(service, None)
+    if handler_cls is None:
+        yield req.forbidden('Unsupported service %s' % service)
+        return
+    req.nocache()
+    req.respond(HTTP_OK, 'application/x-%s-response' % service)
+
+    output = StringIO()
+    input = req.environ['wsgi.input']
+    # This is not necessary if this app is run from a conforming WSGI server.
+    # Unfortunately, there's no way to tell that at this point.
+    # TODO: git may use HTTP/1.1 chunked encoding instead of specifying
+    # content-length
+    if 'CONTENT_LENGTH' in req.environ:
+        input = _LengthLimitedFile(input, int(req.environ['CONTENT_LENGTH']))
+    handler = handler_cls(backend, input.read, output.write, stateless_rpc=True)
+    handler.handle()
+    yield output.getvalue()
+
+
+class HTTPGitRequest(object):
+    """Class encapsulating the state of a single git HTTP request.
+
+    :ivar environ: the WSGI environment for the request.
+    """
+
+    def __init__(self, environ, start_response):
+        self.environ = environ
+        self._start_response = start_response
+        self._cache_headers = []
+        self._headers = []
+
+    def add_header(self, name, value):
+        """Add a header to the response."""
+        self._headers.append((name, value))
+
+    def respond(self, status=HTTP_OK, content_type=None, headers=None):
+        """Begin a response with the given status and other headers."""
+        if headers:
+            self._headers.extend(headers)
+        if content_type:
+            self._headers.append(('Content-Type', content_type))
+        self._headers.extend(self._cache_headers)
+
+        self._start_response(status, self._headers)
+
+    def not_found(self, message):
+        """Begin a HTTP 404 response and return the text of a message."""
+        self._cache_headers = []
+        self.respond(HTTP_NOT_FOUND, 'text/plain')
+        return message
+
+    def forbidden(self, message):
+        """Begin a HTTP 403 response and return the text of a message."""
+        self._cache_headers = []
+        self.respond(HTTP_FORBIDDEN, 'text/plain')
+        return message
+
+    def nocache(self):
+        """Set the response to never be cached by the client."""
+        self._cache_headers = [
+            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+            ('Pragma', 'no-cache'),
+            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+            ]
+
+    def cache_forever(self):
+        """Set the response to be cached forever by the client."""
+        now = time.time()
+        self._cache_headers = [
+            ('Date', date_time_string(now)),
+            ('Expires', date_time_string(now + 31536000)),
+            ('Cache-Control', 'public, max-age=31536000'),
+            ]
+
+
+class HTTPGitApplication(object):
+    """Class encapsulating the state of a git WSGI application.
+
+    :ivar backend: the Backend object backing this application
+    """
+
+    services = {
+        ('GET', re.compile('/HEAD$')): get_text_file,
+        ('GET', re.compile('/info/refs$')): get_info_refs,
+        ('GET', re.compile('/objects/info/alternates$')): get_text_file,
+        ('GET', re.compile('/objects/info/http-alternates$')): get_text_file,
+        ('GET', re.compile('/objects/info/packs$')): get_info_packs,
+        ('GET', re.compile('/objects/([0-9a-f]{2})/([0-9a-f]{38})$')): get_loose_object,
+        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.pack$')): get_pack_file,
+        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.idx$')): get_idx_file,
+
+        ('POST', re.compile('/git-upload-pack$')): handle_service_request,
+        ('POST', re.compile('/git-receive-pack$')): handle_service_request,
+    }
+
+    def __init__(self, backend):
+        self.backend = backend
+
+    def __call__(self, environ, start_response):
+        path = environ['PATH_INFO']
+        method = environ['REQUEST_METHOD']
+        req = HTTPGitRequest(environ, start_response)
+        # environ['QUERY_STRING'] has qs args
+        handler = None
+        for smethod, spath in self.services.iterkeys():
+            if smethod != method:
+                continue
+            mat = spath.search(path)
+            if mat:
+                handler = self.services[smethod, spath]
+                break
+        if handler is None:
+            return req.not_found('Sorry, that method is not supported')
+        return handler(req, self.backend, mat)
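
For reference, a minimal sketch of serving this WSGI application (roughly what
bin/dul-web is presumably for). The GitBackend(Repo(...)) construction is an
assumption based on the handlers' use of backend.repo and backend.object_store;
adjust it to the actual Backend API if it differs.

    from wsgiref.simple_server import make_server

    from dulwich.repo import Repo
    from dulwich.server import GitBackend
    from dulwich.web import HTTPGitApplication

    # Assumed constructor: a GitBackend wrapping a single Repo.
    backend = GitBackend(Repo('/path/to/repo.git'))
    app = HTTPGitApplication(backend)

    # Serve on localhost:8000; a client could then run
    # "git clone http://localhost:8000/".
    httpd = make_server('localhost', 8000, app)
    httpd.serve_forever()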

+ 25 - 7
setup.py

@@ -1,14 +1,14 @@
 #!/usr/bin/python
 # Setup file for bzr-git
-# Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
+# Copyright (C) 2008-2010 Jelmer Vernooij <jelmer@samba.org>
 
 try:
-    from setuptools import setup
+    from setuptools import setup, Extension
 except ImportError:
-    from distutils.core import setup
-from distutils.extension import Extension
+    from distutils.core import setup, Extension
+from distutils.core import Distribution
 
-dulwich_version_string = '0.4.1'
+dulwich_version_string = '0.5.0'
 
 include_dirs = []
 # Windows MSVC support
@@ -17,6 +17,22 @@ if sys.platform == 'win32':
     include_dirs.append('dulwich')
 
 
+class DulwichDistribution(Distribution):
+
+    def is_pure(self):
+        if self.pure:
+            return True
+
+    def has_ext_modules(self):
+        return not self.pure
+
+    global_options = Distribution.global_options + [
+        ('pure', None, 
+            "use pure (slower) Python code instead of C extensions")]
+
+    pure = False
+
+        
 setup(name='dulwich',
       description='Pure-Python Git Library',
       keywords='git',
@@ -32,10 +48,12 @@ setup(name='dulwich',
       in one of the Monty Python sketches.
       """,
       packages=['dulwich', 'dulwich.tests'],
-      ext_modules=[
+      scripts=['bin/dulwich', 'bin/dul-daemon', 'bin/dul-web'],
+      ext_modules = [
           Extension('dulwich._objects', ['dulwich/_objects.c'],
                     include_dirs=include_dirs),
           Extension('dulwich._pack', ['dulwich/_pack.c'],
-                    include_dirs=include_dirs),
+              include_dirs=include_dirs),
           ],
+      distclass=DulwichDistribution,
       )
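
With DulwichDistribution registered via distclass, a pure-Python build or
install would presumably be requested by passing the new global option before
the command, e.g. "python setup.py --pure install"; has_ext_modules() then
returns False, so the two C extensions listed above are skipped. This reading
is inferred from the is_pure()/has_ext_modules() overrides, not stated
explicitly in the patch.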