瀏覽代碼

merge improvements from Dave to the server, file locking and repository abstraction.

Jelmer Vernooij 15 年之前
父節點
當前提交
ba676109c2

+ 13 - 0
NEWS

@@ -4,6 +4,14 @@
 
   * Support custom fields in commits.
 
+  * Improved ref handling. (Dave Borowitz)
+
+  * Rework server protocol to be smarter and interoperate with cgit client.
+    (Dave Borowitz)
+
+  * Add a GitFile class that uses the same locking protocol for writes as 
+    cgit. (Dave Borowitz)
+
  FEATURES
 
   * --without-speedups option to setup.py to allow building/installing 
@@ -11,6 +19,11 @@
 
   * Implement Repo.get_config(). (Jelmer Vernooij)
 
+  * HTTP dumb and smart server. (Dave Borowitz)
+
+  * Add abstract baseclass for Repo that does not require file system 
+    operations. (Dave Borowitz)
+
 0.4.1	2010-01-03
 
  FEATURES

+ 2 - 1
bin/dul-daemon

@@ -18,6 +18,7 @@
 # MA  02110-1301, USA.
 
 import sys
+from dulwich.repo import Repo
 from dulwich.server import GitBackend, TCPGitServer
 
 if __name__ == "__main__":
@@ -25,6 +26,6 @@ if __name__ == "__main__":
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
 
-    backend = GitBackend(gitdir)
+    backend = GitBackend(Repo(gitdir))
     server = TCPGitServer(backend, 'localhost')
     server.serve_forever()

+ 2 - 1
bin/dul-receive-pack

@@ -18,6 +18,7 @@
 # MA  02110-1301, USA.
 
 import sys
+from dulwich.repo import Repo
 from dulwich.server import GitBackend, ReceivePackHandler
 
 def send_fn(data):
@@ -29,6 +30,6 @@ if __name__ == "__main__":
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
 
-    backend = GitBackend(gitdir)
+    backend = GitBackend(Repo(gitdir))
     handler = ReceivePackHandler(backend, sys.stdin.read, send_fn)
     handler.handle()

+ 2 - 1
bin/dul-upload-pack

@@ -18,6 +18,7 @@
 # MA  02110-1301, USA.
 
 import sys
+from dulwich.repo import Repo
 from dulwich.server import GitBackend, UploadPackHandler
 
 def send_fn(data):
@@ -29,6 +30,6 @@ if __name__ == "__main__":
     if len(sys.argv) > 1:
         gitdir = sys.argv[1]
 
-    backend = GitBackend(gitdir)
+    backend = GitBackend(Repo(gitdir))
     handler = UploadPackHandler(backend, sys.stdin.read, send_fn)
     handler.handle()

+ 37 - 0
bin/dul-web

@@ -0,0 +1,37 @@
+#!/usr/bin/python
+# dul-web - HTTP-based git server
+# Copyright (C) 2010 David Borowitz <dborowitz@google.com>
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+import os
+import sys
+from dulwich.repo import Repo
+from dulwich.server import GitBackend
+from dulwich.web import HTTPGitApplication
+from wsgiref.simple_server import make_server
+
+if __name__ == "__main__":
+    # Serve the repository named by the first command-line argument, falling
+    # back to the current working directory when none is given.
+    args = sys.argv[1:]
+    if args:
+        gitdir = args[0]
+    else:
+        gitdir = os.getcwd()
+
+    app = HTTPGitApplication(GitBackend(Repo(gitdir)))
+    # TODO: allow serving on other ports via command-line flag
+    make_server('', 8000, app).serve_forever()

+ 8 - 0
dulwich/errors.py

@@ -108,3 +108,11 @@ class HangupException(GitProtocolError):
     def __init__(self):
         Exception.__init__(self,
             "The remote server unexpectedly closed the connection.")
+
+
+class FileFormatException(Exception):
+    """Base class for exceptions relating to reading git file formats.
+
+    Catching this class also catches the more specific file format errors
+    that derive from it.
+    """
+
+
+class PackedRefsException(FileFormatException):
+    """Indicates an error parsing a packed-refs file.
+
+    As a FileFormatException subclass it is also caught by handlers written
+    for that broader class.
+    """

+ 138 - 0
dulwich/file.py

@@ -0,0 +1,138 @@
+# file.py -- Safe access to git files
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) a later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+"""Safe access to git files."""
+
+
+import errno
+import os
+
+def ensure_dir_exists(dirname):
+    """Create dirname (and any missing parents) unless it already exists.
+
+    :param dirname: Path of the directory that should exist afterwards
+    :raises OSError: if creation fails for any reason other than the path
+        already existing
+    """
+    try:
+        os.makedirs(dirname)
+    except OSError, err:
+        # An already existing path is the expected case; everything else is
+        # a genuine failure and is passed on to the caller.
+        if err.errno == errno.EEXIST:
+            return
+        raise
+
+def GitFile(filename, mode='rb', bufsize=-1):
+    """Create a file object that obeys the git file locking protocol.
+
+    See _GitFile for a description of the file locking protocol.
+
+    Only read-only and write-only (binary) modes are supported; r+, w+, and a
+    are not.  To read and write from the same file, you can take advantage of
+    the fact that opening a file for write does not actually open the file you
+    request:
+
+    >>> write_file = GitFile('filename', 'wb')
+    >>> read_file = GitFile('filename', 'rb')
+    >>> read_file.readlines()
+    ['contents\n', 'of\n', 'the\n', 'file\n']
+    >>> write_file.write('foo')
+    >>> read_file.close()
+    >>> other_file = GitFile('filename', 'wb')
+    Traceback (most recent call last):
+        ...
+    OSError: [Errno 17] File exists: 'filename.lock'
+    >>> write_file.close()
+    >>> new_file = GitFile('filename', 'rb')
+    >>> new_file.read()
+    'foo'
+    >>> new_file.close()
+
+    :param filename: Path of the file to read, or to replace when writing
+    :param mode: Binary read or binary write mode. Defaults to 'rb', since a
+        mode without 'b' is rejected below.
+    :param bufsize: Buffer size passed through to the underlying file object
+    :raises IOError: if mode asks for append, read/write or text access
+    :return: a builtin file object or a _GitFile object
+    """
+    if 'a' in mode:
+        raise IOError('append mode not supported for Git files')
+    if '+' in mode:
+        raise IOError('read/write mode not supported for Git files')
+    if 'b' not in mode:
+        raise IOError('text mode not supported for Git files')
+    if 'w' in mode:
+        # Writers go through the lockfile; the target itself is not opened.
+        return _GitFile(filename, mode, bufsize)
+    # Readers need no lock and get a plain file object.
+    return open(filename, mode, bufsize)
+
+
+class _GitFile(object):
+    """File that follows the git locking protocol for writes.
+
+    All writes to a file foo will be written into foo.lock in the same
+    directory, and the lockfile will be renamed to overwrite the original file
+    on close.
+
+    :note: You *must* call close() or abort() on a _GitFile for the lock to be
+        released. Typically this will happen in a finally block.
+    """
+
+    # Attribute names that __getattr__ forwards to the underlying lockfile.
+    PROXY_PROPERTIES = set(['closed', 'encoding', 'errors', 'mode', 'name',
+                            'newlines', 'softspace'])
+    # Methods of the underlying lockfile that are copied onto each instance.
+    PROXY_METHODS = ('__iter__', 'flush', 'fileno', 'isatty', 'next', 'read',
+                     'readline', 'readlines', 'xreadlines', 'seek', 'tell',
+                     'truncate', 'write', 'writelines')
+
+    def __init__(self, filename, mode, bufsize):
+        """Create and open the lockfile for filename.
+
+        :param filename: Path of the file that close() will replace
+        :param mode: Mode used to wrap the lockfile descriptor
+        :param bufsize: Buffer size used to wrap the lockfile descriptor
+        :raises OSError: if the lockfile cannot be created, in particular
+            when it already exists
+        """
+        self._filename = filename
+        self._lockfilename = '%s.lock' % self._filename
+        # O_EXCL makes creation fail when the lockfile is already present.
+        fd = os.open(self._lockfilename, os.O_RDWR | os.O_CREAT | os.O_EXCL)
+        self._file = os.fdopen(fd, mode, bufsize)
+        self._closed = False
+
+        for method in self.PROXY_METHODS:
+            setattr(self, method, getattr(self._file, method))
+
+    def abort(self):
+        """Close and discard the lockfile without overwriting the target.
+
+        If the file is already closed, this is a no-op.
+
+        :raises OSError: if the lockfile exists but cannot be removed; the
+            object is then not marked as closed.
+        """
+        if self._closed:
+            return
+        self._file.close()
+        try:
+            os.remove(self._lockfilename)
+        except OSError, e:
+            # The file may have been removed already, which is ok.
+            if e.errno != errno.ENOENT:
+                raise
+        # Reached both when the lockfile was removed here and when it was
+        # already gone (for instance renamed by close()), so that later
+        # close()/abort() calls really are no-ops.
+        self._closed = True
+
+    def close(self):
+        """Close this file, saving the lockfile over the original.
+
+        If the file is already closed, this is a no-op.
+
+        :note: If this method fails, it will attempt to delete the lockfile.
+            However, it is not guaranteed to do so (e.g. if a filesystem becomes
+            suddenly read-only), which will prevent future writes to this file
+            until the lockfile is removed manually.
+        :raises OSError: if the original file could not be overwritten. The lock
+            file is still closed, so further attempts to write to the same file
+            object will raise ValueError.
+        """
+        if self._closed:
+            return
+        self._file.close()
+        try:
+            os.rename(self._lockfilename, self._filename)
+        finally:
+            # After a successful rename there is no lockfile left and abort()
+            # only marks this object closed; after a failed rename it removes
+            # the stale lockfile.
+            self.abort()
+
+    def __getattr__(self, name):
+        """Proxy property calls to the underlying file."""
+        if name in self.PROXY_PROPERTIES:
+            return getattr(self._file, name)
+        raise AttributeError(name)

+ 3 - 2
dulwich/index.py

@@ -22,6 +22,7 @@ import os
 import stat
 import struct
 
+from dulwich.file import GitFile
 from dulwich.objects import (
     S_IFGITLINK,
     S_ISGITLINK,
@@ -173,7 +174,7 @@ class Index(object):
 
     def write(self):
         """Write current contents of index to disk."""
-        f = open(self._filename, 'wb')
+        f = GitFile(self._filename, 'wb')
         try:
             f = SHA1Writer(f)
             write_index_dict(f, self._byname)
@@ -182,7 +183,7 @@ class Index(object):
 
     def read(self):
         """Read current contents of index from disk."""
-        f = open(self._filename, 'rb')
+        f = GitFile(self._filename, 'rb')
         try:
             f = SHA1Reader(f)
             for x in read_index(f):

+ 47 - 13
dulwich/object_store.py

@@ -20,6 +20,7 @@
 """Git object store interfaces and implementation."""
 
 
+import errno
 import itertools
 import os
 import stat
@@ -29,6 +30,7 @@ import urllib2
 from dulwich.errors import (
     NotTreeError,
     )
+from dulwich.file import GitFile
 from dulwich.objects import (
     Commit,
     ShaFile,
@@ -65,9 +67,25 @@ class BaseObjectStore(object):
         """
         return ObjectStoreIterator(self, shas)
 
+    def contains_loose(self, sha):
+        """Check if a particular object is present by SHA1 and is loose.
+
+        This base implementation always raises NotImplementedError.
+
+        :param sha: SHA1 of the object to look for
+        """
+        raise NotImplementedError(self.contains_loose)
+
+    def contains_packed(self, sha):
+        """Check if a particular object is present by SHA1 and is packed.
+
+        This base implementation always raises NotImplementedError.
+
+        :param sha: SHA1 of the object to look for
+        """
+        raise NotImplementedError(self.contains_packed)
+
     def __contains__(self, sha):
-        """Check if a particular object is present by SHA1."""
-        raise NotImplementedError(self.__contains__)
+        """Check if a particular object is present by SHA1.
+
+        This method makes no distinction between loose and packed objects.
+        """
+        return self.contains_packed(sha) or self.contains_loose(sha)
+
+    @property
+    def packs(self):
+        """Iterable of pack objects.
+
+        This base implementation always raises NotImplementedError.
+        """
+        raise NotImplementedError
 
     def get_raw(self, name):
         """Obtain the raw text for an object.
@@ -232,14 +250,15 @@ class DiskObjectStore(BaseObjectStore):
         self._pack_cache = None
         self.pack_dir = os.path.join(self.path, PACKDIR)
 
-    def __contains__(self, sha):
-        """Check if a particular object is present by SHA1."""
+    def contains_loose(self, sha):
+        """Check if a particular object is present by SHA1 and is loose.
+
+        :param sha: SHA1 of the object to look for
+        :return: True if self._get_shafile(sha) returns something other than
+            None, False otherwise
+        """
+        return self._get_shafile(sha) is not None
+
+    def contains_packed(self, sha):
+        """Check if a particular object is present by SHA1 and is packed."""
         for pack in self.packs:
             if sha in pack:
                 return True
-        ret = self._get_shafile(sha)
-        if ret is not None:
-            return True
         return False
 
     def __iter__(self):
@@ -251,15 +270,21 @@ class DiskObjectStore(BaseObjectStore):
     def packs(self):
         """List with pack objects."""
         if self._pack_cache is None:
-            self._pack_cache = list(self._load_packs())
+            self._pack_cache = self._load_packs()
         return self._pack_cache
 
     def _load_packs(self):
         if not os.path.exists(self.pack_dir):
-            return
+            return []
+        pack_files = []
         for name in os.listdir(self.pack_dir):
+            # TODO: verify that idx exists first
             if name.startswith("pack-") and name.endswith(".pack"):
-                yield Pack(os.path.join(self.pack_dir, name[:-len(".pack")]))
+                filename = os.path.join(self.pack_dir, name)
+                pack_files.append((os.stat(filename).st_mtime, filename))
+        pack_files.sort(reverse=True)
+        suffix_len = len(".pack")
+        return [Pack(f[:-suffix_len]) for _, f in pack_files]
 
     def _add_known_pack(self, path):
         """Add a newly appeared pack to the cache by path.
@@ -294,7 +319,7 @@ class DiskObjectStore(BaseObjectStore):
         path = os.path.join(dir, sha[2:])
         if os.path.exists(path):
             return # Already there, no need to write again
-        f = open(path, 'w+')
+        f = GitFile(path, 'wb')
         try:
             f.write(o.as_legacy_object())
         finally:
@@ -427,14 +452,23 @@ class MemoryObjectStore(BaseObjectStore):
         super(MemoryObjectStore, self).__init__()
         self._data = {}
 
-    def __contains__(self, sha):
-        """Check if the object with a particular SHA is present."""
+    def contains_loose(self, sha):
+        """Check if a particular object is present by SHA1 and is loose."""
         return sha in self._data
 
+    def contains_packed(self, sha):
+        """Check if a particular object is present by SHA1 and is packed.
+
+        :param sha: SHA1 of the object to look for (not inspected)
+        :return: Always False
+        """
+        return False
+
     def __iter__(self):
         """Iterate over the SHAs that are present in this store."""
         return self._data.iterkeys()
 
+    @property
+    def packs(self):
+        """List with pack objects.
+
+        :return: Always an empty list
+        """
+        return []
+
     def get_raw(self, name):
         """Obtain the raw text for an object.
         

+ 2 - 2
dulwich/objects.py

@@ -36,6 +36,7 @@ from dulwich.errors import (
     NotCommitError,
     NotTreeError,
     )
+from dulwich.file import GitFile
 from dulwich.misc import (
     make_sha,
     )
@@ -184,7 +185,7 @@ class ShaFile(object):
     def from_file(cls, filename):
         """Get the contents of a SHA file on disk"""
         size = os.path.getsize(filename)
-        f = open(filename, 'rb')
+        f = GitFile(filename, 'rb')
         try:
             map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ)
             shafile = cls._parse_file(map)
@@ -637,4 +638,3 @@ try:
     from dulwich._objects import parse_tree
 except ImportError:
     pass
-

+ 7 - 6
dulwich/pack.py

@@ -55,6 +55,7 @@ from dulwich.errors import (
     ApplyDeltaError,
     ChecksumMismatch,
     )
+from dulwich.file import GitFile
 from dulwich.lru_cache import (
     LRUSizeCache,
     )
@@ -150,7 +151,7 @@ def load_pack_index(filename):
 
     :param filename: Path to the index file
     """
-    f = open(filename, 'rb')
+    f = GitFile(filename, 'rb')
     if f.read(4) == '\377tOc':
         version = struct.unpack(">L", f.read(4))[0]
         if version == 2:
@@ -211,7 +212,7 @@ class PackIndex(object):
         # ensure that it hasn't changed.
         self._size = os.path.getsize(filename)
         if file is None:
-            self._file = open(filename, 'rb')
+            self._file = GitFile(filename, 'rb')
         else:
             self._file = file
         self._contents, map_offset = simple_mmap(self._file, 0, self._size)
@@ -497,7 +498,7 @@ class PackData(object):
         self._size = os.path.getsize(filename)
         self._header_size = 12
         assert self._size >= self._header_size, "%s is too small for a packfile (%d < %d)" % (filename, self._size, self._header_size)
-        self._file = open(self._filename, 'rb')
+        self._file = GitFile(self._filename, 'rb')
         self._read_header()
         self._offset_cache = LRUSizeCache(1024*1024*20, 
             compute_size=_compute_object_size)
@@ -809,7 +810,7 @@ def write_pack(filename, objects, num_objects):
     :param objects: Iterable over (object, path) tuples to write
     :param num_objects: Number of objects to write
     """
-    f = open(filename + ".pack", 'wb')
+    f = GitFile(filename + ".pack", 'wb')
     try:
         entries, data_sum = write_pack_data(f, objects, num_objects)
     finally:
@@ -873,7 +874,7 @@ def write_pack_index_v1(filename, entries, pack_checksum):
             crc32_checksum.
     :param pack_checksum: Checksum of the pack file.
     """
-    f = open(filename, 'wb')
+    f = GitFile(filename, 'wb')
     f = SHA1Writer(f)
     fan_out_table = defaultdict(lambda: 0)
     for (name, offset, entry_checksum) in entries:
@@ -1021,7 +1022,7 @@ def write_pack_index_v2(filename, entries, pack_checksum):
             crc32_checksum.
     :param pack_checksum: Checksum of the pack file.
     """
-    f = open(filename, 'wb')
+    f = GitFile(filename, 'wb')
     f = SHA1Writer(f)
     f.write('\377tOc') # Magic!
     f.write(struct.pack(">L", 2))

+ 33 - 5
dulwich/protocol.py

@@ -28,6 +28,10 @@ from dulwich.errors import (
 
 TCP_GIT_PORT = 9418
 
+# Acknowledgement modes, as returned by ack_type() depending on which
+# multi_ack capability (if any) appears in a capabilities list.
+SINGLE_ACK = 0
+MULTI_ACK = 1
+MULTI_ACK_DETAILED = 2
+
 class ProtocolFile(object):
     """
     Some network ops are like file ops. The file ops expect to operate on
@@ -160,11 +164,35 @@ def extract_capabilities(text):
     """Extract a capabilities list from a string, if present.
 
     :param text: String to extract from
-    :return: Tuple with text with capabilities removed and list of 
-        capabilities or None (if no capabilities were present.
+    :return: Tuple with text with capabilities removed and list of capabilities
     """
     if not "\0" in text:
-        return text, None
-    capabilities = text.split("\0")
-    return (capabilities[0], capabilities[1:])
+        return text, []
+    text, capabilities = text.rstrip().split("\0")
+    return (text, capabilities.split(" "))
+
+
+def extract_want_line_capabilities(text):
+    """Split a want line into the bare want command and its capabilities.
 
+    Unlike other lines, a want line separates its capabilities from the rest
+    of the line with a space rather than a null byte, giving the form:
+
+        want obj-id cap1 cap2 ...
+
+    A line with fewer than three space-separated fields carries no
+    capabilities and is returned exactly as given, trailing whitespace
+    included.
+
+    :param text: Want line to extract from
+    :return: Tuple with text with capabilities removed and list of capabilities
+    """
+    fields = text.rstrip().split(" ")
+    if len(fields) >= 3:
+        # The first two fields are the command and object id; the rest are
+        # capabilities.
+        return (" ".join(fields[:2]), fields[2:])
+    return text, []
+
+
+def ack_type(capabilities):
+    """Extract the ack type from a capabilities list.
+
+    :param capabilities: Collection of capability names, tested with ``in``
+    :return: MULTI_ACK_DETAILED if 'multi_ack_detailed' is present, otherwise
+        MULTI_ACK if 'multi_ack' is present, otherwise SINGLE_ACK
+    """
+    # The detailed variant is checked first so it wins when both are listed.
+    if 'multi_ack_detailed' in capabilities:
+        return MULTI_ACK_DETAILED
+    elif 'multi_ack' in capabilities:
+        return MULTI_ACK
+    return SINGLE_ACK

+ 463 - 137
dulwich/repo.py

@@ -22,6 +22,7 @@
 """Repository access."""
 
 
+import errno
 import os
 import stat
 
@@ -31,6 +32,11 @@ from dulwich.errors import (
     NotCommitError, 
     NotGitRepository,
     NotTreeError, 
+    PackedRefsException,
+    )
+from dulwich.file import (
+    ensure_dir_exists,
+    GitFile,
     )
 from dulwich.object_store import (
     DiskObjectStore,
@@ -41,6 +47,7 @@ from dulwich.objects import (
     ShaFile,
     Tag,
     Tree,
+    hex_to_sha,
     )
 
 OBJECTDIR = 'objects'
@@ -51,20 +58,36 @@ REFSDIR_HEADS = 'heads'
 INDEX_FILENAME = "index"
 
 
-def follow_ref(container, name):
-    """Follow a ref back to a SHA1.
-    
-    :param container: Ref container to use for looking up refs.
-    :param name: Name of the original ref.
+def check_ref_format(refname):
+    """Check if a refname is correctly formatted.
+
+    Implements all the same rules as git-check-ref-format[1].
+
+    [1] http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html
+
+    :param refname: The refname to check
+    :return: True if refname is valid, False otherwise
     """
-    contents = container[name]
-    if contents.startswith(SYMREF):
-        ref = contents[len(SYMREF):]
-        if ref[-1] == '\n':
-            ref = ref[:-1]
-        return follow_ref(container, ref)
-    assert len(contents) == 40, 'Invalid ref in %s' % name
-    return contents
+    # These could be combined into one big expression, but are listed separately
+    # to parallel [1].
+    if '/.' in refname or refname.startswith('.'):
+        return False
+    if '/' not in refname:
+        return False
+    if '..' in refname:
+        return False
+    for c in refname:
+        if ord(c) < 040 or c in '\177 ~^:?*[':
+            return False
+    if refname[-1] in '/.':
+        return False
+    if refname.endswith('.lock'):
+        return False
+    if '@{' in refname:
+        return False
+    if '\\' in refname:
+        return False
+    return True
 
 
 class RefsContainer(object):
@@ -74,13 +97,6 @@ class RefsContainer(object):
         """Return the contents of this ref container under base as a dict."""
         raise NotImplementedError(self.as_dict)
 
-    def follow(self, name):
-        """Follow a ref name back to a SHA1.
-        
-        :param name: Name of the ref
-        """
-        return follow_ref(self, name)
-
     def set_ref(self, name, other):
         """Make a ref point at another ref.
 
@@ -99,54 +115,68 @@ class DiskRefsContainer(RefsContainer):
 
     def __init__(self, path):
         self.path = path
+        self._packed_refs = None
+        self._peeled_refs = {}
 
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.path)
 
     def keys(self, base=None):
-        """Refs present in this container."""
-        return list(self.iterkeys(base))
+        """Refs present in this container.
 
-    def iterkeys(self, base=None):
+        :param base: An optional base to return refs under
+        :return: An unsorted set of valid refs in this container, including
+            packed refs.
+        """
         if base is not None:
-            return self.itersubkeys(base)
+            return self.subkeys(base)
         else:
-            return self.iterallkeys()
+            return self.allkeys()
 
-    def itersubkeys(self, base):
+    def subkeys(self, base):
+        keys = set()
         path = self.refpath(base)
         for root, dirs, files in os.walk(path):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
-
-    def iterallkeys(self):
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                # check_ref_format requires at least one /, so we prepend the
+                # base before calling it.
+                if check_ref_format("%s/%s" % (base, refname)):
+                    keys.add(refname)
+        for key in self.get_packed_refs():
+            if key.startswith(base):
+                keys.add(key[len(base):].strip("/"))
+        return keys
+
+    def allkeys(self):
+        keys = set()
         if os.path.exists(self.refpath("HEAD")):
-            yield "HEAD"
+            keys.add("HEAD")
         path = self.refpath("")
         for root, dirs, files in os.walk(self.refpath("refs")):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                if check_ref_format(refname):
+                    keys.add(refname)
+        keys.update(self.get_packed_refs())
+        return keys
 
-    def as_dict(self, base=None, follow=True):
+    def as_dict(self, base=None):
         """Return the contents of this container as a dictionary.
 
         """
         ret = {}
+        keys = self.keys(base)
         if base is None:
-            keys = self.iterkeys()
             base = ""
-        else:
-            keys = self.itersubkeys(base)
         for key in keys:
-                if follow:
-                    try:
-                        ret[key] = self.follow(("%s/%s" % (base, key)).strip("/"))
-                    except KeyError:
-                        continue # Unable to resolve
-                else:
-                    ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            try:
+                ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            except KeyError:
+                continue # Unable to resolve
+
         return ret
 
     def refpath(self, name):
@@ -157,90 +187,353 @@ class DiskRefsContainer(RefsContainer):
             name = name.replace("/", os.path.sep)
         return os.path.join(self.path, name)
 
+    def get_packed_refs(self):
+        """Get contents of the packed-refs file.
+
+        :return: Dictionary mapping ref names to SHA1s
+
+        :note: Will return an empty dictionary when no packed-refs file is
+            present.
+        """
+        # TODO: invalidate the cache on repacking
+        if self._packed_refs is None:
+            self._packed_refs = {}
+            path = os.path.join(self.path, 'packed-refs')
+            try:
+                f = GitFile(path, 'rb')
+            except IOError, e:
+                if e.errno == errno.ENOENT:
+                    return {}
+                raise
+            try:
+                first_line = iter(f).next().rstrip()
+                if (first_line.startswith("# pack-refs") and " peeled" in
+                        first_line):
+                    for sha, name, peeled in read_packed_refs_with_peeled(f):
+                        self._packed_refs[name] = sha
+                        if peeled:
+                            self._peeled_refs[name] = peeled
+                else:
+                    f.seek(0)
+                    for sha, name in read_packed_refs(f):
+                        self._packed_refs[name] = sha
+            finally:
+                f.close()
+        return self._packed_refs
+
+    def _check_refname(self, name):
+        """Ensure a refname is valid and lives in refs or is HEAD.
+
+        HEAD is not a valid refname according to git-check-ref-format, but this
+        class needs to be able to touch HEAD. Also, check_ref_format expects
+        refnames without the leading 'refs/', but this class requires that
+        so it cannot touch anything outside the refs dir (or HEAD).
+
+        :param name: The name of the reference.
+        :raises KeyError: if a refname is not HEAD or is otherwise not valid.
+        """
+        if name == 'HEAD':
+            return
+        if not name.startswith('refs/') or not check_ref_format(name[5:]):
+            raise KeyError(name)
+
+    def _read_ref_file(self, name):
+        """Read a reference file and return its contents.
+
+        If the reference file is a symbolic reference, only read the first
+        line of the file. Otherwise, only read the first 40 bytes.
+
+        :param name: the refname to read, relative to refpath
+        :return: The contents of the ref file, or None if the file does not
+            exist.
+        :raises IOError: if any other error occurs
+        """
+        filename = self.refpath(name)
+        try:
+            f = GitFile(filename, 'rb')
+            try:
+                # Read just enough bytes to compare against the SYMREF prefix.
+                header = f.read(len(SYMREF))
+                if header == SYMREF:
+                    # Read only the first line
+                    return header + iter(f).next().rstrip("\n")
+                else:
+                    # Read only the first 40 bytes
+                    # (len(SYMREF) of them were already consumed above)
+                    return header + f.read(40-len(SYMREF))
+            finally:
+                f.close()
+        except IOError, e:
+            # A missing file is reported as None; anything else propagates.
+            if e.errno == errno.ENOENT:
+                return None
+            raise
+
+    def _follow(self, name):
+        """Follow a reference name.
+
+        :return: a tuple of (refname, sha), where refname is the name of the
+            last reference in the symbolic reference chain
+        """
+        self._check_refname(name)
+        contents = SYMREF + name
+        depth = 0
+        while contents.startswith(SYMREF):
+            refname = contents[len(SYMREF):]
+            contents = self._read_ref_file(refname)
+            if not contents:
+                contents = self.get_packed_refs().get(refname, None)
+                if not contents:
+                    break
+            depth += 1
+            if depth > 5:
+                raise KeyError(name)
+        return refname, contents
+
     def __getitem__(self, name):
-        file = self.refpath(name)
-        if not os.path.exists(file):
+        """Get the SHA1 for a reference name.
+
+        This method follows all symbolic references.
+        """
+        _, sha = self._follow(name)
+        if sha is None:
             raise KeyError(name)
-        f = open(file, 'rb')
+        return sha
+
+    def _remove_packed_ref(self, name):
+        if self._packed_refs is None:
+            return
+        filename = os.path.join(self.path, 'packed-refs')
+        # reread cached refs from disk, while holding the lock
+        f = GitFile(filename, 'wb')
+        try:
+            self._packed_refs = None
+            self.get_packed_refs()
+
+            if name not in self._packed_refs:
+                return
+
+            del self._packed_refs[name]
+            if name in self._peeled_refs:
+                del self._peeled_refs[name]
+            write_packed_refs(f, self._packed_refs, self._peeled_refs)
+            f.close()
+        finally:
+            f.abort()
+
+    def set_if_equals(self, name, old_ref, new_ref):
+        """Set a refname to new_ref only if it currently equals old_ref.
+
+        This method follows all symbolic references, and can be used to perform
+        an atomic compare-and-swap operation.
+
+        :param name: The refname to set.
+        :param old_ref: The old sha the refname must refer to, or None to set
+            unconditionally.
+        :param new_ref: The new sha the refname will refer to.
+        :return: True if the set was successful, False otherwise.
+        """
+        try:
+            realname, _ = self._follow(name)
+        except KeyError:
+            realname = name
+        filename = self.refpath(realname)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
         try:
-            return f.read().strip("\n")
+            if old_ref is not None:
+                try:
+                    # read again while holding the lock
+                    orig_ref = self._read_ref_file(realname)
+                    if orig_ref is None:
+                        orig_ref = self.get_packed_refs().get(realname, None)
+                    if orig_ref != old_ref:
+                        f.abort()
+                        return False
+                except (OSError, IOError):
+                    f.abort()
+                    raise
+            try:
+                f.write(new_ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
         finally:
             f.close()
+        return True
+
+    def add_if_new(self, name, ref):
+        """Add a new reference only if it does not already exist."""
+        self._check_refname(name)
+        filename = self.refpath(name)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
+        try:
+            if os.path.exists(filename) or name in self.get_packed_refs():
+                f.abort()
+                return False
+            try:
+                f.write(ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
+        finally:
+            f.close()
+        return True
 
     def __setitem__(self, name, ref):
-        file = self.refpath(name)
-        dirpath = os.path.dirname(file)
-        if not os.path.exists(dirpath):
-            os.makedirs(dirpath)
-        f = open(file, 'wb')
+        """Set a reference name to point to the given SHA1.
+
+        This method follows all symbolic references.
+
+        :note: This method unconditionally overwrites the contents of a reference
+            on disk. To update atomically only if the reference has not changed
+            on disk, use set_if_equals().
+        """
+        self.set_if_equals(name, None, ref)
+
+    def remove_if_equals(self, name, old_ref):
+        """Remove a refname only if it currently equals old_ref.
+
+        This method does not follow symbolic references. It can be used to
+        perform an atomic compare-and-delete operation.
+
+        :param name: The refname to delete.
+        :param old_ref: The old sha the refname must refer to, or None to delete
+            unconditionally.
+        :return: True if the delete was successful, False otherwise.
+        """
+        self._check_refname(name)
+        filename = self.refpath(name)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
         try:
-            f.write(ref+"\n")
+            if old_ref is not None:
+                orig_ref = self._read_ref_file(name)
+                if orig_ref is None:
+                    orig_ref = self.get_packed_refs().get(name, None)
+                if orig_ref != old_ref:
+                    return False
+            # may only be packed
+            if os.path.exists(filename):
+                os.remove(filename)
+            self._remove_packed_ref(name)
         finally:
-            f.close()
+            # never write, we just wanted the lock
+            f.abort()
+        return True
 
     def __delitem__(self, name):
-        file = self.refpath(name)
-        if os.path.exists(file):
-            os.remove(file)
+        """Remove a refname.
+
+        This method does not follow symbolic references.
+        :note: This method unconditionally deletes the contents of a reference
+            on disk. To delete atomically only if the reference has not changed
+            on disk, use remove_if_equals().
+        """
+        self.remove_if_equals(name, None)
+
+
+def _split_ref_line(line):
+    """Split a single ref line into a tuple of SHA1 and name."""
+    fields = line.rstrip("\n").split(" ")
+    if len(fields) != 2:
+        raise PackedRefsException("invalid ref line '%s'" % line)
+    sha, name = fields
+    try:
+        hex_to_sha(sha)
+    except (AssertionError, TypeError), e:
+        raise PackedRefsException(e)
+    if not check_ref_format(name):
+        raise PackedRefsException("invalid ref name '%s'" % name)
+    return (sha, name)
 
 
 def read_packed_refs(f):
     """Read a packed refs file.
 
-    Yields tuples with ref names and SHA1s.
+    Yields tuples with SHA1s and ref names.
 
     :param f: file-like object to read from
     """
-    l = f.readline()
-    for l in f.readlines():
+    for l in f:
         if l[0] == "#":
             # Comment
             continue
         if l[0] == "^":
-            # FIXME: Return somehow
+            raise PackedRefsException(
+                "found peeled ref in packed-refs without peeled")
+        yield _split_ref_line(l)
+
+
+def read_packed_refs_with_peeled(f):
+    """Read a packed refs file including peeled refs.
+
+    Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
+    with SHA1s, ref names, and peeled SHA1s (or None).
+
+    :param f: file-like object to read from, seek'ed to the second line
+    :raises PackedRefsException: if a peeled line has no preceding ref line
+        or does not contain a valid hex SHA1.
+    """
+    # Holds the most recent ref line until we know whether a "^" (peeled)
+    # line follows it.
+    last = None
+    for l in f:
+        if l[0] == "#":
             continue
-        yield tuple(l.rstrip("\n").split(" ", 2))
+        l = l.rstrip("\n")
+        if l[0] == "^":
+            if not last:
+                raise PackedRefsException("unexpected peeled ref line")
+            try:
+                hex_to_sha(l[1:])
+            except (AssertionError, TypeError), e:
+                raise PackedRefsException(e)
+            sha, name = _split_ref_line(last)
+            last = None
+            yield (sha, name, l[1:])
+        else:
+            # A new ref line: the pending one (if any) had no peeled value.
+            if last:
+                sha, name = _split_ref_line(last)
+                yield (sha, name, None)
+            last = l
+    # Flush the final pending ref line.
+    if last:
+        sha, name = _split_ref_line(last)
+        yield (sha, name, None)
+
 
+def write_packed_refs(f, packed_refs, peeled_refs=None):
+    """Write a packed refs file.
+
+    Refs are written sorted by name, one "<sha> <refname>" line each.
+
+    :param f: empty file-like object to write to
+    :param packed_refs: dict of refname to sha of packed refs to write
+    :param peeled_refs: optional dict of refname to peeled sha. When given
+        (even if empty), the "# pack-refs with: peeled" header is written and
+        each ref found in it is followed by a "^<peeled sha>" line.
+    """
+    if peeled_refs is None:
+        peeled_refs = {}
+    else:
+        f.write('# pack-refs with: peeled\n')
+    for refname in sorted(packed_refs.iterkeys()):
+        f.write('%s %s\n' % (packed_refs[refname], refname))
+        if refname in peeled_refs:
+            f.write('^%s\n' % peeled_refs[refname])
+
+class BaseRepo(object):
+    """Base class for a git repository.
 
-class Repo(object):
-    """A local git repository.
-    
-    :ivar refs: Dictionary with the refs in this repository
     :ivar object_store: Dictionary-like object for accessing
         the objects
+    :ivar refs: Dictionary-like object with the refs in this repository
     """
 
-    def __init__(self, root):
-        if os.path.isdir(os.path.join(root, ".git", OBJECTDIR)):
-            self.bare = False
-            self._controldir = os.path.join(root, ".git")
-        elif (os.path.isdir(os.path.join(root, OBJECTDIR)) and
-              os.path.isdir(os.path.join(root, REFSDIR))):
-            self.bare = True
-            self._controldir = root
-        else:
-            raise NotGitRepository(root)
-        self.path = root
-        self.refs = DiskRefsContainer(self.controldir())
-        self.object_store = DiskObjectStore(
-            os.path.join(self.controldir(), OBJECTDIR))
+    def __init__(self, object_store, refs):
+        self.object_store = object_store
+        self.refs = refs
 
-    def controldir(self):
-        """Return the path of the control directory."""
-        return self._controldir
+    def get_named_file(self, path):
+        """Get a file from the control dir with a specific name.
 
-    def index_path(self):
-        """Return path to the index file."""
-        return os.path.join(self.controldir(), INDEX_FILENAME)
-
-    def open_index(self):
-        """Open the index for this repository."""
-        from dulwich.index import Index
-        return Index(self.index_path())
+        Although the filename should be interpreted as a filename relative to
+        the control dir in a disk-backed Repo, the object returned need not be
+        pointing to a file in that location.
 
-    def has_index(self):
-        """Check if an index is present."""
-        return os.path.exists(self.index_path())
+        :param path: The path to the file, relative to the control dir.
+        :return: An open file object, or None if the file does not exist.
+        """
+        raise NotImplementedError(self.get_named_file)
 
     def fetch(self, target, determine_wants=None, progress=None):
         """Fetch objects into another repository.
@@ -281,46 +574,15 @@ class Repo(object):
 
     def ref(self, name):
         """Return the SHA1 a ref is pointing to."""
-        try:
-            return self.refs.follow(name)
-        except KeyError:
-            return self.get_packed_refs()[name]
+        return self.refs[name]
 
     def get_refs(self):
         """Get dictionary with all refs."""
-        ret = {}
-        try:
-            if self.head():
-                ret['HEAD'] = self.head()
-        except KeyError:
-            pass
-        ret.update(self.refs.as_dict())
-        ret.update(self.get_packed_refs())
-        return ret
-
-    def get_packed_refs(self):
-        """Get contents of the packed-refs file.
-
-        :return: Dictionary mapping ref names to SHA1s
-
-        :note: Will return an empty dictionary when no packed-refs file is 
-            present.
-        """
-        path = os.path.join(self.controldir(), 'packed-refs')
-        if not os.path.exists(path):
-            return {}
-        ret = {}
-        f = open(path, 'rb')
-        try:
-            for entry in read_packed_refs(f):
-                ret[entry[1]] = entry[0]
-            return ret
-        finally:
-            f.close()
+        return self.refs.as_dict()
 
     def head(self):
         """Return the SHA1 pointed at by HEAD."""
-        return self.refs.follow('HEAD')
+        return self.refs['HEAD']
 
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
@@ -392,14 +654,11 @@ class Repo(object):
         history.reverse()
         return history
 
-    def __repr__(self):
-        return "<Repo at %r>" % self.path
-
     def __getitem__(self, name):
         if len(name) in (20, 40):
             return self.object_store[name]
         return self.object_store[self.refs[name]]
-    
+
     def __setitem__(self, name, value):
         if name.startswith("refs/") or name == "HEAD":
             if isinstance(value, ShaFile):
@@ -415,6 +674,63 @@ class Repo(object):
             del self.refs[name]
         raise ValueError(name)
 
+
+class Repo(BaseRepo):
+    """A git repository backed by local disk."""
+
+    def __init__(self, root):
+        if os.path.isdir(os.path.join(root, ".git", OBJECTDIR)):
+            self.bare = False
+            self._controldir = os.path.join(root, ".git")
+        elif (os.path.isdir(os.path.join(root, OBJECTDIR)) and
+              os.path.isdir(os.path.join(root, REFSDIR))):
+            self.bare = True
+            self._controldir = root
+        else:
+            raise NotGitRepository(root)
+        self.path = root
+        object_store = DiskObjectStore(
+            os.path.join(self.controldir(), OBJECTDIR))
+        refs = DiskRefsContainer(self.controldir())
+        BaseRepo.__init__(self, object_store, refs)
+
+    def controldir(self):
+        """Return the path of the control directory."""
+        return self._controldir
+
+    def get_named_file(self, path):
+        """Get a file from the control dir with a specific name.
+
+        Although the filename should be interpreted as a filename relative to
+        the control dir in a disk-backed Repo, the object returned need not be
+        pointing to a file in that location.
+
+        :param path: The path to the file, relative to the control dir.
+        :return: An open file object, or None if the file does not exist.
+        """
+        try:
+            # Strip leading slashes so os.path.join keeps the path inside the
+            # control dir instead of treating it as absolute.
+            return open(os.path.join(self.controldir(), path.lstrip('/')), 'rb')
+        except (IOError, OSError), e:
+            if e.errno == errno.ENOENT:
+                return None
+            raise
+
+    def index_path(self):
+        """Return path to the index file."""
+        return os.path.join(self.controldir(), INDEX_FILENAME)
+
+    def open_index(self):
+        """Open the index for this repository."""
+        from dulwich.index import Index
+        return Index(self.index_path())
+
+    def has_index(self):
+        """Check if an index is present."""
+        return os.path.exists(self.index_path())
+
+    def __repr__(self):
+        return "<Repo at %r>" % self.path
+
     def do_commit(self, committer, message,
                   author=None, commit_timestamp=None,
                   commit_timezone=None, author_timestamp=None, 
@@ -482,15 +798,25 @@ class Repo(object):
             os.mkdir(os.path.join(path, *d))
         ret = cls(path)
         ret.refs.set_ref("HEAD", "refs/heads/master")
-        open(os.path.join(path, 'description'), 'wb').write("Unnamed repository")
-        open(os.path.join(path, 'info', 'excludes'), 'wb').write("")
-        open(os.path.join(path, 'config'), 'wb').write("""[core]
+        f = GitFile(os.path.join(path, 'description'), 'wb')
+        try:
+            f.write("Unnamed repository")
+        finally:
+            f.close()
+
+        f = GitFile(os.path.join(path, 'config'), 'wb')
+        try:
+            f.write("""[core]
     repositoryformatversion = 0
     filemode = true
     bare = false
     logallrefupdates = true
 """)
+        finally:
+            f.close()
+
+        f = GitFile(os.path.join(path, 'info', 'excludes'), 'wb')
+        f.close()
         return ret
 
     create = init_bare
-

+ 414 - 99
dulwich/server.py

@@ -17,17 +17,37 @@
 # MA  02110-1301, USA.
 
 
-"""Git smart network protocol server implementation."""
+"""Git smart network protocol server implementation.
 
+For more detailed implementation on the network protocol, see the
+Documentation/technical directory in the cgit distribution, and in particular:
+    Documentation/technical/protocol-capabilities.txt
+    Documentation/technical/pack-protocol.txt
+"""
 
+
+import collections
 import SocketServer
 import tempfile
 
+from dulwich.errors import (
+    ApplyDeltaError,
+    ChecksumMismatch,
+    GitProtocolError,
+    )
+from dulwich.objects import (
+    hex_to_sha,
+    )
 from dulwich.protocol import (
     Protocol,
     ProtocolFile,
     TCP_GIT_PORT,
     extract_capabilities,
+    extract_want_line_capabilities,
+    SINGLE_ACK,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
+    ack_type,
     )
 from dulwich.repo import (
     Repo,
@@ -65,31 +85,65 @@ class Backend(object):
 
 class GitBackend(Backend):
 
-    def __init__(self, gitdir=None):
-        self.gitdir = gitdir
-
-        if not self.gitdir:
-            self.gitdir = tempfile.mkdtemp()
-            Repo.create(self.gitdir)
-
-        self.repo = Repo(self.gitdir)
+    def __init__(self, repo=None):
+        if repo is None:
+            repo = Repo(tmpfile.mkdtemp())
+        self.repo = repo
+        self.object_store = self.repo.object_store
         self.fetch_objects = self.repo.fetch_objects
         self.get_refs = self.repo.get_refs
 
     def apply_pack(self, refs, read):
         f, commit = self.repo.object_store.add_thin_pack()
+        all_exceptions = (IOError, OSError, ChecksumMismatch, ApplyDeltaError)
+        status = []
+        unpack_error = None
+        # TODO: more informative error messages than just the exception string
+        try:
+            # TODO: decode the pack as we stream to avoid blocking reads beyond
+            # the end of data (when using HTTP/1.1 chunked encoding)
+            while True:
+                data = read(10240)
+                if not data:
+                    break
+                f.write(data)
+        except all_exceptions, e:
+            unpack_error = str(e).replace('\n', '')
         try:
-            f.write(read())
-        finally:
             commit()
+        except all_exceptions, e:
+            if not unpack_error:
+                unpack_error = str(e).replace('\n', '')
+
+        if unpack_error:
+            status.append(('unpack', unpack_error))
+        else:
+            status.append(('unpack', 'ok'))
 
         for oldsha, sha, ref in refs:
-            if ref == "0" * 40:
-                del self.repo.refs[ref]
+            # TODO: check refname
+            ref_error = None
+            try:
+                if ref == "0" * 40:
+                    try:
+                        del self.repo.refs[ref]
+                    except all_exceptions:
+                        ref_error = 'failed to delete'
+                else:
+                    try:
+                        self.repo.refs[ref] = sha
+                    except all_exceptions:
+                        ref_error = 'failed to write'
+            except KeyError, e:
+                ref_error = 'bad ref'
+            if ref_error:
+                status.append((ref, ref_error))
             else:
-                self.repo.refs[ref] = sha
+                status.append((ref, 'ok'))
+
 
         print "pack applied"
+        return status
 
 
 class Handler(object):
@@ -106,102 +160,354 @@ class Handler(object):
 class UploadPackHandler(Handler):
     """Protocol handler for uploading a pack to the server."""
 
+    def __init__(self, backend, read, write,
+                 stateless_rpc=False, advertise_refs=False):
+        Handler.__init__(self, backend, read, write)
+        self._client_capabilities = None
+        self._graph_walker = None
+        self.stateless_rpc = stateless_rpc
+        self.advertise_refs = advertise_refs
+
     def default_capabilities(self):
-        return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
+        return ("multi_ack_detailed", "multi_ack", "side-band-64k", "thin-pack",
+                "ofs-delta")
 
-    def handle(self):
-        def determine_wants(heads):
-            keys = heads.keys()
-            if keys:
-                self.proto.write_pkt_line("%s %s\x00%s\n" % ( heads[keys[0]], keys[0], self.capabilities()))
-                for k in keys[1:]:
-                    self.proto.write_pkt_line("%s %s\n" % (heads[k], k))
+    def set_client_capabilities(self, caps):
+        my_caps = self.default_capabilities()
+        for cap in caps:
+            if '_ack' in cap and cap not in my_caps:
+                raise GitProtocolError('Client asked for capability %s that '
+                                       'was not advertised.' % cap)
+        self._client_capabilities = caps
 
-            # i'm done..
-            self.proto.write("0000")
+    def get_client_capabilities(self):
+        return self._client_capabilities
 
-            # Now client will either send "0000", meaning that it doesnt want to pull.
-            # or it will start sending want want want commands
-            want = self.proto.read_pkt_line()
-            if want == None:
-                return []
+    client_capabilities = property(get_client_capabilities,
+                                   set_client_capabilities)
 
-            want, self.client_capabilities = extract_capabilities(want)
-
-            want_revs = []
-            while want and want[:4] == 'want':
-                want_revs.append(want[5:45])
-                want = self.proto.read_pkt_line()
-                if want == None:
-                    self.proto.write_pkt_line("ACK %s\n" % want_revs[-1])
-            return want_revs
+    def handle(self):
 
         progress = lambda x: self.proto.write_sideband(2, x)
         write = lambda x: self.proto.write_sideband(1, x)
 
-        class ProtocolGraphWalker(object):
+        graph_walker = ProtocolGraphWalker(self)
+        objects_iter = self.backend.fetch_objects(
+          graph_walker.determine_wants, graph_walker, progress)
 
-            def __init__(self, proto):
-                self.proto = proto
-                self._last_sha = None
-                self._cached = False
-                self._cache = []
-                self._cache_index = 0
+        # Do they want any objects?
+        if len(objects_iter) == 0:
+            return
 
-            def ack(self, have_ref):
-                self.proto.write_pkt_line("ACK %s continue\n" % have_ref)
+        progress("dul-daemon says what\n")
+        progress("counting objects: %d, done.\n" % len(objects_iter))
+        write_pack_data(ProtocolFile(None, write), objects_iter, 
+                        len(objects_iter))
+        progress("how was that, then?\n")
+        # we are done
+        self.proto.write("0000")
 
-            def reset(self):
-                self._cached = True
-                self._cache_index = 0
 
-            def next(self):
-                if not self._cached:
-                    return self.next_from_proto()
-                self._cache_index = self._cache_index + 1
-                if self._cache_index > len(self._cache):
-                    return None
-                return self._cache[self._cache_index]
+class ProtocolGraphWalker(object):
+    """A graph walker that knows the git protocol.
+
+    As a graph walker, this class implements ack(), next(), and reset(). It also
+    contains some base methods for interacting with the wire and walking the
+    commit tree.
+
+    The work of determining which acks to send is passed on to the
+    implementation instance stored in _impl. The reason for this is that we do
+    not know at object creation time what ack level the protocol requires. A
+    call to set_ack_level() is required to set up the implementation, before any
+    calls to next() or ack() are made.
+    """
+    def __init__(self, handler):
+        self.handler = handler
+        self.store = handler.backend.object_store
+        self.proto = handler.proto
+        self.stateless_rpc = handler.stateless_rpc
+        self.advertise_refs = handler.advertise_refs
+        self._wants = []
+        self._cached = False
+        self._cache = []
+        self._cache_index = 0
+        self._impl = None
+
+    def determine_wants(self, heads):
+        """Determine the wants for a set of heads.
+
+        The given heads are advertised to the client, who then specifies which
+        refs he wants using 'want' lines. This portion of the protocol is the
+        same regardless of ack type, and in fact is used to set the ack type of
+        the ProtocolGraphWalker.
+
+        :param heads: a dict of refname->SHA1 to advertise
+        :return: a list of SHA1s requested by the client
+        :raises GitProtocolError: if there are no heads, if the client sends a
+            command other than 'want', or if it wants a SHA1 that was not
+            advertised.
+        """
+        if not heads:
+            raise GitProtocolError('No heads found')
+        values = set(heads.itervalues())
+        if self.advertise_refs or not self.stateless_rpc:
+            for i, (ref, sha) in enumerate(heads.iteritems()):
+                line = "%s %s" % (sha, ref)
+                if not i:
+                    line = "%s\x00%s" % (line, self.handler.capabilities())
+                self.proto.write_pkt_line("%s\n" % line)
+                # TODO: include peeled value of any tags
 
-            def next_from_proto(self):
-                have = self.proto.read_pkt_line()
-                if have is None:
-                    self.proto.write_pkt_line("ACK %s\n" % self._last_sha)
-                    return None
+            # End of the ref advertisement.
+            self.proto.write_pkt_line(None)
 
-                if have[:4] == 'have':
-                    self._cache.append(have[5:45])
-                    return have[5:45]
+            if self.advertise_refs:
+                return []
 
+        # Now the client will send its 'want' lines.
+        want = self.proto.read_pkt_line()
+        if not want:
+            return []
+        line, caps = extract_want_line_capabilities(want)
+        self.handler.client_capabilities = caps
+        self.set_ack_type(ack_type(caps))
+        command, sha = self._split_proto_line(line)
+
+        want_revs = []
+        while command != None:
+            if command != 'want':
+                raise GitProtocolError(
+                    'Protocol got unexpected command %s' % command)
+            if sha not in values:
+                raise GitProtocolError(
+                    'Client wants invalid object %s' % sha)
+            want_revs.append(sha)
+            command, sha = self.read_proto_line()
+
+        self.set_wants(want_revs)
+        return want_revs
+
+    def ack(self, have_ref):
+        return self._impl.ack(have_ref)
+
+    def reset(self):
+        self._cached = True
+        self._cache_index = 0
+
+    def next(self):
+        if not self._cached:
+            if not self._impl and self.stateless_rpc:
+                return None
+            return self._impl.next()
+        self._cache_index += 1
+        if self._cache_index > len(self._cache):
+            return None
+        return self._cache[self._cache_index]
+
+    def _split_proto_line(self, line):
+        fields = line.rstrip('\n').split(' ', 1)
+        if len(fields) == 1 and fields[0] == 'done':
+            return ('done', None)
+        elif len(fields) == 2 and fields[0] in ('want', 'have'):
+            try:
+                hex_to_sha(fields[1])
+                return tuple(fields)
+            except (TypeError, AssertionError), e:
+                raise GitProtocolError(e)
+        raise GitProtocolError('Received invalid line from client:\n%s' % line)
+
+    def read_proto_line(self):
+        """Read a line from the wire.
+
+        :return: a tuple having one of the following forms:
+            ('want', obj_id)
+            ('have', obj_id)
+            ('done', None)
+            (None, None)  (for a flush-pkt)
+
+        :raise GitProtocolError: if the line cannot be parsed into one of the
+            possible return values.
+        """
+        line = self.proto.read_pkt_line()
+        if not line:
+            return (None, None)
+        return self._split_proto_line(line)
 
-                #if have[:4] == 'done':
-                #    return None
+    def send_ack(self, sha, ack_type=''):
+        if ack_type:
+            ack_type = ' %s' % ack_type
+        self.proto.write_pkt_line('ACK %s%s\n' % (sha, ack_type))
 
-                if self._last_sha:
-                    # Oddness: Git seems to resend the last ACK, without the "continue" statement
-                    self.proto.write_pkt_line("ACK %s\n" % self._last_sha)
+    def send_nak(self):
+        self.proto.write_pkt_line('NAK\n')
 
-                # The exchange finishes with a NAK
-                self.proto.write_pkt_line("NAK\n")
+    def set_wants(self, wants):
+        self._wants = wants
 
-        graph_walker = ProtocolGraphWalker(self.proto)
-        objects_iter = self.backend.fetch_objects(determine_wants, graph_walker, progress)
+    def _is_satisfied(self, haves, want, earliest):
+        """Check whether a want is satisfied by a set of haves.
 
-        # Do they want any objects?
-        if len(objects_iter) == 0:
-            return
+        A want, typically a branch tip, is "satisfied" only if there exists a
+        path back from that want to one of the haves.
 
-        progress("dul-daemon says what\n")
-        progress("counting objects: %d, done.\n" % len(objects_iter))
-        write_pack_data(ProtocolFile(None, write), objects_iter, 
-                        len(objects_iter))
-        progress("how was that, then?\n")
-        # we are done
-        self.proto.write("0000")
+        :param haves: A set of commits we know the client has.
+        :param want: The want to check satisfaction for.
+        :param earliest: A timestamp beyond which the search for haves will be
+            terminated, presumably because we're searching too far down the
+            wrong branch.
+        :return: True if a have was reached while walking back from want,
+            False otherwise.
+        """
+        o = self.store[want]
+        pending = collections.deque([o])
+        # shas that have already been queued; a commit reachable along
+        # several paths (e.g. through merges) is loaded and examined once
+        known = set([want])
+        while pending:
+            commit = pending.popleft()
+            if commit.id in haves:
+                return True
+            if not getattr(commit, 'get_parents', None):
+                # objects without get_parents have no ancestry to follow
+                continue
+            for parent in commit.get_parents():
+                if parent in known:
+                    continue
+                known.add(parent)
+                parent_obj = self.store[parent]
+                # TODO: handle parents with later commit times than children
+                if parent_obj.commit_time >= earliest:
+                    pending.append(parent_obj)
+        return False
+
+    def all_wants_satisfied(self, haves):
+        """Check whether all the current wants are satisfied by a set of haves.
+
+        :param haves: A set of commits we know the client has.
+        :note: Wants are specified with set_wants rather than passed in since
+            in the current interface they are determined outside this class.
+        """
+        haves = set(haves)
+        earliest = min([self.store[h].commit_time for h in haves])
+        for want in self._wants:
+            if not self._is_satisfied(haves, want, earliest):
+                return False
+        return True
+
+    def set_ack_type(self, ack_type):
+        impl_classes = {
+            MULTI_ACK: MultiAckGraphWalkerImpl,
+            MULTI_ACK_DETAILED: MultiAckDetailedGraphWalkerImpl,
+            SINGLE_ACK: SingleAckGraphWalkerImpl,
+            }
+        self._impl = impl_classes[ack_type](self)
+
+
+class SingleAckGraphWalkerImpl(object):
+    """Graph walker implementation that speaks the single-ack protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._sent_ack = False
+
+    def ack(self, have_ref):
+        if not self._sent_ack:
+            self.walker.send_ack(have_ref)
+            self._sent_ack = True
+
+    def next(self):
+        command, sha = self.walker.read_proto_line()
+        if command in (None, 'done'):
+            if not self._sent_ack:
+                self.walker.send_nak()
+            return None
+        elif command == 'have':
+            return sha
+
+
+class MultiAckGraphWalkerImpl(object):
+    """Graph walker implementation that speaks the multi-ack protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._found_base = False
+        self._common = []
+
+    def ack(self, have_ref):
+        self._common.append(have_ref)
+        if not self._found_base:
+            self.walker.send_ack(have_ref, 'continue')
+            if self.walker.all_wants_satisfied(self._common):
+                self._found_base = True
+        # else we blind ack within next
+
+    def next(self):
+        while True:
+            command, sha = self.walker.read_proto_line()
+            if command is None:
+                self.walker.send_nak()
+                # in multi-ack mode, a flush-pkt indicates the client wants to
+                # flush but more have lines are still coming
+                continue
+            elif command == 'done':
+                # don't nak unless no common commits were found, even if not
+                # everything is satisfied
+                if self._common:
+                    self.walker.send_ack(self._common[-1])
+                else:
+                    self.walker.send_nak()
+                return None
+            elif command == 'have':
+                if self._found_base:
+                    # blind ack
+                    self.walker.send_ack(sha, 'continue')
+                return sha
+
+
+class MultiAckDetailedGraphWalkerImpl(object):
+    """Graph walker implementation speaking the multi-ack-detailed protocol."""
+
+    def __init__(self, walker):
+        self.walker = walker
+        self._found_base = False
+        self._common = []
+
+    def ack(self, have_ref):
+        self._common.append(have_ref)
+        if not self._found_base:
+            self.walker.send_ack(have_ref, 'common')
+            if self.walker.all_wants_satisfied(self._common):
+                self._found_base = True
+                self.walker.send_ack(have_ref, 'ready')
+        # else we blind ack within next
+
+    def next(self):
+        while True:
+            command, sha = self.walker.read_proto_line()
+            if command is None:
+                self.walker.send_nak()
+                if self.walker.stateless_rpc:
+                    return None
+                continue
+            elif command == 'done':
+                # don't nak unless no common commits were found, even if not
+                # everything is satisfied
+                if self._common:
+                    self.walker.send_ack(self._common[-1])
+                else:
+                    self.walker.send_nak()
+                return None
+            elif command == 'have':
+                if self._found_base:
+                    # blind ack; can happen if the client has more requests
+                    # inflight
+                    self.walker.send_ack(sha, 'ready')
+                return sha
 
 
 class ReceivePackHandler(Handler):
-    """Protocol handler for downloading a pack to the client."""
+    """Protocol handler for downloading a pack from the client."""
+
+    def __init__(self, backend, read, write,
+                 stateless_rpc=False, advertise_refs=False):
+        """Set up a receive-pack session.
+
+        backend, read and write are passed through to Handler.__init__.
+
+        :param stateless_rpc: if set (and advertise_refs is not), handle()
+            skips the initial ref advertisement.
+        :param advertise_refs: if set, handle() returns right after the ref
+            advertisement.
+        """
+        Handler.__init__(self, backend, read, write)
+        # handle() reads these attributes under exactly these names
+        self.stateless_rpc = stateless_rpc
+        self.advertise_refs = advertise_refs
 
     def default_capabilities(self):
         return ("report-status", "delete-refs")
@@ -209,15 +515,18 @@ class ReceivePackHandler(Handler):
     def handle(self):
         refs = self.backend.get_refs().items()
 
-        if refs:
-            self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
-            for i in range(1, len(refs)):
-                ref = refs[i]
-                self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
-        else:
-            self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
+        if self.advertise_refs or not self.stateless_rpc:
+            if refs:
+                self.proto.write_pkt_line("%s %s\x00%s\n" % (refs[0][1], refs[0][0], self.capabilities()))
+                for i in range(1, len(refs)):
+                    ref = refs[i]
+                    self.proto.write_pkt_line("%s %s\n" % (ref[1], ref[0]))
+            else:
+                self.proto.write_pkt_line("0000000000000000000000000000000000000000 capabilities^{} %s" % self.capabilities())
 
-        self.proto.write("0000")
+            self.proto.write("0000")
+            if self.advertise_refs:
+                return
 
         client_refs = []
         ref = self.proto.read_pkt_line()
@@ -234,11 +543,19 @@ class ReceivePackHandler(Handler):
             ref = self.proto.read_pkt_line()
 
         # backend can now deal with this refs and read a pack using self.read
-        self.backend.apply_pack(client_refs, self.proto.read)
+        status = self.backend.apply_pack(client_refs, self.proto.read)
 
-        # when we have read all the pack from the client, it assumes 
-        # everything worked OK.
-        # there is NO ack from the server before it reports victory.
+        # when we have read all the pack from the client, send a status report
+        # if the client asked for it
+        if 'report-status' in client_capabilities:
+            for name, msg in status:
+                if name == 'unpack':
+                    self.proto.write_pkt_line('unpack %s\n' % msg)
+                elif msg == 'ok':
+                    self.proto.write_pkt_line('ok %s\n' % name)
+                else:
+                    self.proto.write_pkt_line('ng %s %s\n' % (name, msg))
+            self.proto.write_pkt_line(None)
 
 
 class TCPGitRequestHandler(SocketServer.StreamRequestHandler):
@@ -267,5 +584,3 @@ class TCPGitServer(SocketServer.TCPServer):
     def __init__(self, backend, listen_addr, port=TCP_GIT_PORT):
         self.backend = backend
         SocketServer.TCPServer.__init__(self, (listen_addr, port), TCPGitRequestHandler)
-
-

+ 1 - 0
dulwich/tests/data/repos/refs.git/HEAD

@@ -0,0 +1 @@
+ref: refs/heads/master

二進制
dulwich/tests/data/repos/refs.git/objects/3b/9e5457140e738c2dcd39bf6d7acf88379b90d1


二進制
dulwich/tests/data/repos/refs.git/objects/42/d06bd4b77fed026b154d16493e5deab78f02ec


二進制
dulwich/tests/data/repos/refs.git/objects/a1/8114c31713746a33a2e70d9914d1ef3e781425


二進制
dulwich/tests/data/repos/refs.git/objects/df/6800012397fb85c56e7418dd4eb9405dee075c


+ 3 - 0
dulwich/tests/data/repos/refs.git/packed-refs

@@ -0,0 +1,3 @@
+# pack-refs with: peeled 
+df6800012397fb85c56e7418dd4eb9405dee075c refs/tags/refs-0.1
+^42d06bd4b77fed026b154d16493e5deab78f02ec

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/loop

@@ -0,0 +1 @@
+ref: refs/heads/loop

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/master

@@ -0,0 +1 @@
+42d06bd4b77fed026b154d16493e5deab78f02ec

+ 131 - 0
dulwich/tests/test_file.py

@@ -0,0 +1,131 @@
+# test_file.py -- Test for git files
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# of the License or (at your option) a later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+import errno
+import os
+import shutil
+import tempfile
+import unittest
+
+from dulwich.file import GitFile
+
+class GitFileTests(unittest.TestCase):
+    def setUp(self):
+        self._tempdir = tempfile.mkdtemp()
+        f = open(self.path('foo'), 'wb')
+        f.write('foo contents')
+        f.close()
+
+    def tearDown(self):
+        shutil.rmtree(self._tempdir)
+
+    def path(self, filename):
+        return os.path.join(self._tempdir, filename)
+
+    def test_invalid(self):
+        foo = self.path('foo')
+        self.assertRaises(IOError, GitFile, foo, mode='r')
+        self.assertRaises(IOError, GitFile, foo, mode='ab')
+        self.assertRaises(IOError, GitFile, foo, mode='r+b')
+        self.assertRaises(IOError, GitFile, foo, mode='w+b')
+        self.assertRaises(IOError, GitFile, foo, mode='a+bU')
+
+    def test_readonly(self):
+        f = GitFile(self.path('foo'), 'rb')
+        self.assertTrue(isinstance(f, file))
+        self.assertEquals('foo contents', f.read())
+        self.assertEquals('', f.read())
+        f.seek(4)
+        self.assertEquals('contents', f.read())
+        f.close()
+
+    def test_write(self):
+        foo = self.path('foo')
+        foo_lock = '%s.lock' % foo
+
+        orig_f = open(foo, 'rb')
+        self.assertEquals(orig_f.read(), 'foo contents')
+        orig_f.close()
+
+        self.assertFalse(os.path.exists(foo_lock))
+        f = GitFile(foo, 'wb')
+        self.assertFalse(f.closed)
+        self.assertRaises(AttributeError, getattr, f, 'not_a_file_property')
+
+        self.assertTrue(os.path.exists(foo_lock))
+        f.write('new stuff')
+        f.seek(4)
+        f.write('contents')
+        f.close()
+        self.assertFalse(os.path.exists(foo_lock))
+
+        new_f = open(foo, 'rb')
+        self.assertEquals('new contents', new_f.read())
+        new_f.close()
+
+    def test_open_twice(self):
+        """Opening a file for write twice fails with EEXIST the second time.
+
+        The first writer must remain usable and its contents must end up in
+        the target file.
+        """
+        foo = self.path('foo')
+        f1 = GitFile(foo, 'wb')
+        f1.write('new')
+        try:
+            f2 = GitFile(foo, 'wb')
+            # reaching this line means the second open did not raise
+            self.fail()
+        except OSError, e:
+            self.assertEquals(errno.EEXIST, e.errno)
+        f1.write(' contents')
+        f1.close()
+
+        # Ensure trying to open twice doesn't affect original.
+        f = open(foo, 'rb')
+        self.assertEquals('new contents', f.read())
+        f.close()
+
+    def test_abort(self):
+        foo = self.path('foo')
+        foo_lock = '%s.lock' % foo
+
+        orig_f = open(foo, 'rb')
+        self.assertEquals(orig_f.read(), 'foo contents')
+        orig_f.close()
+
+        f = GitFile(foo, 'wb')
+        f.write('new contents')
+        f.abort()
+        self.assertTrue(f.closed)
+        self.assertFalse(os.path.exists(foo_lock))
+
+        new_orig_f = open(foo, 'rb')
+        self.assertEquals(new_orig_f.read(), 'foo contents')
+        new_orig_f.close()
+
+    def test_abort_close(self):
+        foo = self.path('foo')
+        f = GitFile(foo, 'wb')
+        f.abort()
+        try:
+            f.close()
+        except (IOError, OSError):
+            self.fail()
+
+        f = GitFile(foo, 'wb')
+        f.close()
+        try:
+            f.abort()
+        except (IOError, OSError):
+            self.fail()

+ 28 - 3
dulwich/tests/test_protocol.py

@@ -26,6 +26,11 @@ from unittest import TestCase
 from dulwich.protocol import (
     Protocol,
     extract_capabilities,
+    extract_want_line_capabilities,
+    ack_type,
+    SINGLE_ACK,
+    MULTI_ACK,
+    MULTI_ACK_DETAILED,
     )
 
 class ProtocolTests(TestCase):
@@ -77,10 +82,30 @@ class ProtocolTests(TestCase):
         self.assertRaises(AssertionError, self.proto.read_cmd)
 
 
-class ExtractCapabilitiesTestCase(TestCase):
+class CapabilitiesTestCase(TestCase):
 
     def test_plain(self):
-        self.assertEquals(("bla", None), extract_capabilities("bla"))
+        self.assertEquals(("bla", []), extract_capabilities("bla"))
 
     def test_caps(self):
-        self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la\0la"))
+        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la"))
+        self.assertEquals(("bla", ["la"]), extract_capabilities("bla\0la\n"))
+        self.assertEquals(("bla", ["la", "la"]), extract_capabilities("bla\0la la"))
+
+    def test_plain_want_line(self):
+        self.assertEquals(("want bla", []), extract_want_line_capabilities("want bla"))
+
+    def test_caps_want_line(self):
+        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la"))
+        self.assertEquals(("want bla", ["la"]), extract_want_line_capabilities("want bla la\n"))
+        self.assertEquals(("want bla", ["la", "la"]), extract_want_line_capabilities("want bla la la"))
+
+    def test_ack_type(self):
+        self.assertEquals(SINGLE_ACK, ack_type(['foo', 'bar']))
+        self.assertEquals(MULTI_ACK, ack_type(['foo', 'bar', 'multi_ack']))
+        self.assertEquals(MULTI_ACK_DETAILED,
+                          ack_type(['foo', 'bar', 'multi_ack_detailed']))
+        # choose detailed when both present
+        self.assertEquals(MULTI_ACK_DETAILED,
+                          ack_type(['foo', 'bar', 'multi_ack',
+                                    'multi_ack_detailed']))

+ 305 - 29
dulwich/tests/test_repository.py

@@ -20,75 +20,110 @@
 
 """Tests for the repository."""
 
-
+from cStringIO import StringIO
 import os
+import shutil
+import tempfile
 import unittest
 
 from dulwich import errors
-from dulwich.repo import Repo
+from dulwich.repo import (
+    check_ref_format,
+    DiskRefsContainer,
+    Repo,
+    read_packed_refs,
+    read_packed_refs_with_peeled,
+    write_packed_refs,
+    _split_ref_line,
+    )
 
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 
+
+def open_repo(name):
+    """Open a copy of a repo in a temporary directory.
+
+    Use this function for accessing repos in dulwich/tests/data/repos to avoid
+    accidentally or intentionally modifying those repos in place. Use
+    tear_down_repo to delete any temp files created.
+
+    :param name: The name of the repository, relative to
+        dulwich/tests/data/repos
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    temp_dir = tempfile.mkdtemp()
+    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
+    return Repo(temp_repo_dir)
+
+def tear_down_repo(repo):
+    """Tear down a test repository."""
+    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
+    shutil.rmtree(temp_dir)
+
+
 class RepositoryTests(unittest.TestCase):
-  
-    def open_repo(self, name):
-        return Repo(os.path.join(os.path.dirname(__file__),
-                          'data', 'repos', name))
+
+    def setUp(self):
+        self._repo = None
+
+    def tearDown(self):
+        if self._repo is not None:
+            tear_down_repo(self._repo)
   
     def test_simple_props(self):
-        r = self.open_repo('a.git')
-        basedir = os.path.join(os.path.dirname(__file__), 
-                os.path.join('data', 'repos', 'a.git'))
-        self.assertEqual(r.controldir(), basedir)
+        r = self._repo = open_repo('a.git')
+        self.assertEqual(r.controldir(), r.path)
   
     def test_ref(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
     def test_get_refs(self):
-        r = self.open_repo('a.git')
-        self.assertEquals({
+        r = self._repo = open_repo('a.git')
+        self.assertEqual({
             'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
             'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
             }, r.get_refs())
   
     def test_head(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
     def test_get_object(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
         self.assertEqual(obj._type, 'commit')
   
     def test_get_object_non_existant(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
   
     def test_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.commit(r.head())
         self.assertEqual(obj._type, 'commit')
   
     def test_commit_not_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotCommitError,
                           r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
   
     def test_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         tree = r.tree(commit.tree)
         self.assertEqual(tree._type, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
   
     def test_tree_not_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotTreeError, r.tree, r.head())
   
     def test_get_blob(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         tree = r.tree(commit.tree)
         blob_sha = tree.entries()[0][2]
@@ -97,18 +132,18 @@ class RepositoryTests(unittest.TestCase):
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
   
     def test_get_blob_notblob(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
     
     def test_linear_history(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
   
     def test_merge_history(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
@@ -118,13 +153,13 @@ class RepositoryTests(unittest.TestCase):
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
   
     def test_revision_history_missing_commit(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
   
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
-        r = self.open_repo('ooo_merge.git')
+        r = self._repo = open_repo('ooo_merge.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
@@ -133,9 +168,250 @@ class RepositoryTests(unittest.TestCase):
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
   
     def test_get_tags_empty(self):
-        r = self.open_repo('ooo_merge.git')
-        self.assertEquals({}, r.refs.as_dict('refs/tags'))
+        r = self._repo = open_repo('ooo_merge.git')
+        self.assertEqual({}, r.refs.as_dict('refs/tags'))
 
     def test_get_config(self):
-        r = self.open_repo('ooo_merge.git')
+        r = self._repo = open_repo('ooo_merge.git')
         self.assertEquals({}, r.get_config())
+
+
+class CheckRefFormatTests(unittest.TestCase):
+    """Tests for the check_ref_format function.
+
+    These are the same tests as in the git test suite.
+    """
+
+    def test_valid(self):
+        self.assertTrue(check_ref_format('heads/foo'))
+        self.assertTrue(check_ref_format('foo/bar/baz'))
+        self.assertTrue(check_ref_format('refs///heads/foo'))
+        self.assertTrue(check_ref_format('foo./bar'))
+        self.assertTrue(check_ref_format('heads/foo@bar'))
+        self.assertTrue(check_ref_format('heads/fix.lock.error'))
+
+    def test_invalid(self):
+        self.assertFalse(check_ref_format('foo'))
+        self.assertFalse(check_ref_format('heads/foo/'))
+        self.assertFalse(check_ref_format('./foo'))
+        self.assertFalse(check_ref_format('.refs/foo'))
+        self.assertFalse(check_ref_format('heads/foo..bar'))
+        self.assertFalse(check_ref_format('heads/foo?bar'))
+        self.assertFalse(check_ref_format('heads/foo.lock'))
+        self.assertFalse(check_ref_format('heads/v@{ation'))
+        self.assertFalse(check_ref_format('heads/foo\bar'))
+
+
+ONES = "1" * 40
+TWOS = "2" * 40
+THREES = "3" * 40
+FOURS = "4" * 40
+
+class PackedRefsFileTests(unittest.TestCase):
+    def test_split_ref_line_errors(self):
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'singlefield')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'badsha name')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          '%s bad/../refname' % ONES)
+
+    def test_read_without_peeled(self):
+        f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
+        self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
+                         list(read_packed_refs(f)))
+
+    def test_read_without_peeled_errors(self):
+        f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+    def test_read_with_peeled(self):
+        f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
+            ONES, TWOS, THREES, FOURS))
+        self.assertEqual([
+            (ONES, 'ref/1', None),
+            (TWOS, 'ref/2', THREES),
+            (FOURS, 'ref/4', None),
+            ], list(read_packed_refs_with_peeled(f)))
+
+    def test_read_with_peeled_errors(self):
+        """Misplaced peeled lines are rejected by the peeled-aware reader."""
+        # a peeled line that is not preceded by any ref line
+        f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
+        self.assertRaises(errors.PackedRefsException, list,
+                          read_packed_refs_with_peeled(f))
+
+        # two peeled lines in a row following a single ref line
+        f = StringIO('%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES))
+        self.assertRaises(errors.PackedRefsException, list,
+                          read_packed_refs_with_peeled(f))
+
+    def test_write_with_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
+                          {'ref/1': THREES})
+        self.assertEqual(
+            "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
+            ONES, THREES, TWOS), f.getvalue())
+
+    def test_write_without_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
+        self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
+
+
+class RefsContainerTests(unittest.TestCase):
+    def setUp(self):
+        self._repo = open_repo('refs.git')
+        self._refs = self._repo.refs
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_get_packed_refs(self):
+        self.assertEqual(
+            {'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c'},
+            self._refs.get_packed_refs())
+
+    def test_keys(self):
+        self.assertEqual([
+            'HEAD',
+            'refs/heads/loop',
+            'refs/heads/master',
+            'refs/tags/refs-0.1',
+            ], sorted(list(self._refs.keys())))
+        self.assertEqual(['loop', 'master'],
+                         sorted(self._refs.keys('refs/heads')))
+        self.assertEqual(['refs-0.1'], list(self._refs.keys('refs/tags')))
+
+    def test_as_dict(self):
+        # refs/heads/loop does not show up
+        self.assertEqual({
+            'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+            }, self._refs.as_dict())
+
+    def test_setitem(self):
+        self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/some/ref'])
+        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          f.read()[:40])
+        f.close()
+
+    def test_setitem_symbolic(self):
+        ones = '1' * 40
+        self._refs['HEAD'] = ones
+        self.assertEqual(ones, self._refs['HEAD'])
+
+        # ensure HEAD was not modified
+        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
+        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
+        f.close()
+
+        # ensure the symbolic link was written through
+        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
+        self.assertEqual(ones, f.read()[:40])
+        f.close()
+
+    def test_set_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        self.assertTrue(self._refs.set_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
+        self.assertEqual(nines, self._refs['HEAD'])
+
+        # ensure symref was followed
+        self.assertEqual(nines, self._refs['refs/heads/master'])
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+    def test_add_if_new(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+
+        self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
+        self.assertEqual(nines, self._refs['refs/some/ref'])
+
+        # don't overwrite packed ref
+        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+
+    def test_check_refname(self):
+        try:
+            self._refs._check_refname('HEAD')
+        except KeyError:
+            self.fail()
+
+        try:
+            self._refs._check_refname('refs/heads/foo')
+        except KeyError:
+            self.fail()
+
+        self.assertRaises(KeyError, self._refs._check_refname, 'refs')
+        self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
+
+    def test_follow(self):
+        self.assertEquals(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('HEAD'))
+        self.assertEquals(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('refs/heads/master'))
+        self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
+        self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
+
+    def test_delitem(self):
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          self._refs['refs/heads/master'])
+        del self._refs['refs/heads/master']
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
+        self.assertFalse(os.path.exists(ref_file))
+        self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
+
+    def test_delitem_symbolic(self):
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs._read_ref_file('HEAD'))
+        del self._refs['HEAD']
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
+
+    def test_remove_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        # HEAD is a symref, so shouldn't equal its dereferenced value
+        self.assertFalse(self._refs.remove_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertTrue(self._refs.remove_if_equals(
+            'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+
+        # HEAD is now a broken symref
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs._read_ref_file('HEAD'))
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+        # test removing ref that is only packed
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+        self.assertTrue(
+            self._refs.remove_if_equals('refs/tags/refs-0.1',
+            'df6800012397fb85c56e7418dd4eb9405dee075c'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])

+ 519 - 0
dulwich/tests/test_server.py

@@ -0,0 +1,519 @@
+# test_server.py -- Tests for the git server
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) any later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+
+"""Tests for the smart protocol server."""
+
+
+from cStringIO import StringIO
+from unittest import TestCase
+
+from dulwich.errors import (
+    GitProtocolError,
+    )
+from dulwich.server import (
+    UploadPackHandler,
+    ProtocolGraphWalker,
+    SingleAckGraphWalkerImpl,
+    MultiAckGraphWalkerImpl,
+    MultiAckDetailedGraphWalkerImpl,
+    )
+
+from dulwich.protocol import (
+    SINGLE_ACK,
+    MULTI_ACK,
+    )
+
+ONE = '1' * 40
+TWO = '2' * 40
+THREE = '3' * 40
+FOUR = '4' * 40
+FIVE = '5' * 40
+
+class TestProto(object):
+    """In-memory stand-in for the protocol object used by the walker tests.
+
+    Lines queued with set_output() are handed back one at a time by
+    read_pkt_line(); everything written is recorded per band, with band 0
+    also receiving plain pkt-lines.
+    """
+
+    def __init__(self):
+        self._output = []
+        self._received = dict((band, []) for band in (0, 1, 2, 3))
+
+    def set_output(self, output_lines):
+        # Normalise every line to end in exactly one newline.
+        queued = []
+        for line in output_lines:
+            queued.append('%s\n' % line.rstrip())
+        self._output = queued
+
+    def read_pkt_line(self):
+        # An exhausted queue reads as None.
+        if not self._output:
+            return None
+        return self._output.pop(0)
+
+    def write_sideband(self, band, data):
+        self._received[band].append(data)
+
+    def write_pkt_line(self, data):
+        # A None pkt-line is recorded as the string 'None'.
+        recorded = 'None' if data is None else data
+        self._received[0].append(recorded)
+
+    def get_received_line(self, band=0):
+        pending = self._received[band]
+        if not pending:
+            return None
+        return pending.pop(0)
+
+
+class UploadPackHandlerTestCase(TestCase):
+    """Tests for UploadPackHandler.set_client_capabilities()."""
+
+    def setUp(self):
+        self._handler = UploadPackHandler(None, None, None)
+
+    def _assert_caps_accepted(self, caps):
+        # A GitProtocolError is reported as a test failure, not an error.
+        try:
+            self._handler.set_client_capabilities(caps)
+        except GitProtocolError:
+            self.fail()
+
+    def test_set_client_capabilities(self):
+        self._assert_caps_accepted([])
+        self._assert_caps_accepted([
+            'multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta'])
+
+    def test_set_client_capabilities_error(self):
+        rejected = ['weird_ack_level', 'ofs-delta']
+        self.assertRaises(GitProtocolError,
+                          self._handler.set_client_capabilities,
+                          rejected)
+        self._assert_caps_accepted(['include-tag'])
+
+
+class TestCommit(object):
+    """Minimal commit stand-in exposing only what the graph walker reads.
+
+    :ivar id: The commit SHA.
+    :ivar commit_time: The commit timestamp.
+    """
+
+    def __init__(self, sha, parents, commit_time):
+        self.id = sha
+        self._parents = parents
+        self.commit_time = commit_time
+
+    def get_parents(self):
+        """Return the list of parent SHAs given at construction."""
+        return self._parents
+
+    def __repr__(self):
+        # The SHA is stored as ``id``; there is no ``_sha`` attribute.
+        return '%s(%s)' % (self.__class__.__name__, self.id)
+
+
+# Backend stand-in that exposes the given objects as its object_store.
+class TestBackend(object):
+    def __init__(self, objects):
+        self.object_store = objects
+
+
+# Handler stand-in passed to ProtocolGraphWalker in the tests below; it
+# bundles a TestBackend, a protocol object and the two mode flags.
+class TestHandler(object):
+    def __init__(self, objects, proto):
+        self.backend = TestBackend(objects)
+        self.proto = proto
+        self.stateless_rpc = False
+        self.advertise_refs = False
+
+    def capabilities(self):
+        # Only multi_ack is advertised by this stand-in.
+        return 'multi_ack'
+
+
+# Tests ProtocolGraphWalker against a small fixed commit graph served from a
+# plain dict through TestHandler/TestProto.
+class ProtocolGraphWalkerTestCase(TestCase):
+    def setUp(self):
+        # Create the following commit tree:
+        #   3---5
+        #  /
+        # 1---2---4
+        self._objects = {
+            ONE: TestCommit(ONE, [], 111),
+            TWO: TestCommit(TWO, [ONE], 222),
+            THREE: TestCommit(THREE, [ONE], 333),
+            FOUR: TestCommit(FOUR, [TWO], 444),
+            FIVE: TestCommit(FIVE, [THREE], 555),
+            }
+        self._walker = ProtocolGraphWalker(
+            TestHandler(self._objects, TestProto()))
+
+    # With no haves, no want can be satisfied.
+    def test_is_satisfied_no_haves(self):
+        self.assertFalse(self._walker._is_satisfied([], ONE, 0))
+        self.assertFalse(self._walker._is_satisfied([], TWO, 0))
+        self.assertFalse(self._walker._is_satisfied([], THREE, 0))
+
+    # Having the root commit satisfies the root and both of its children.
+    def test_is_satisfied_have_root(self):
+        self.assertTrue(self._walker._is_satisfied([ONE], ONE, 0))
+        self.assertTrue(self._walker._is_satisfied([ONE], TWO, 0))
+        self.assertTrue(self._walker._is_satisfied([ONE], THREE, 0))
+
+    def test_is_satisfied_have_branch(self):
+        self.assertTrue(self._walker._is_satisfied([TWO], TWO, 0))
+        # wrong branch
+        self.assertFalse(self._walker._is_satisfied([TWO], THREE, 0))
+
+    # all_wants_satisfied() must hold for every want at once (FOUR and FIVE
+    # sit on different branches).
+    def test_all_wants_satisfied(self):
+        self._walker.set_wants([FOUR, FIVE])
+        # trivial case: wants == haves
+        self.assertTrue(self._walker.all_wants_satisfied([FOUR, FIVE]))
+        # cases that require walking the commit tree
+        self.assertTrue(self._walker.all_wants_satisfied([ONE]))
+        self.assertFalse(self._walker.all_wants_satisfied([TWO]))
+        self.assertFalse(self._walker.all_wants_satisfied([THREE]))
+        self.assertTrue(self._walker.all_wants_satisfied([TWO, THREE]))
+
+    # read_proto_line() returns (command, sha) pairs, raises on the unknown
+    # 'foo' and 'bar' lines, and returns (None, None) once input runs out.
+    def test_read_proto_line(self):
+        self._walker.proto.set_output([
+            'want %s' % ONE,
+            'want %s' % TWO,
+            'have %s' % THREE,
+            'foo %s' % FOUR,
+            'bar',
+            'done',
+            ])
+        self.assertEquals(('want', ONE), self._walker.read_proto_line())
+        self.assertEquals(('want', TWO), self._walker.read_proto_line())
+        self.assertEquals(('have', THREE), self._walker.read_proto_line())
+        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
+        self.assertRaises(GitProtocolError, self._walker.read_proto_line)
+        self.assertEquals(('done', None), self._walker.read_proto_line())
+        self.assertEquals((None, None), self._walker.read_proto_line())
+
+    def test_determine_wants(self):
+        # An empty set of heads is rejected outright.
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, {})
+
+        self._walker.proto.set_output([
+            'want %s multi_ack' % ONE,
+            'want %s' % TWO,
+            ])
+        heads = {'ref1': ONE, 'ref2': TWO, 'ref3': THREE}
+        self.assertEquals([ONE, TWO], self._walker.determine_wants(heads))
+
+        # FOUR is not among the advertised heads.
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+        # No input at all yields no wants.
+        self._walker.proto.set_output([])
+        self.assertEquals([], self._walker.determine_wants(heads))
+
+        # A non-want line following a want is a protocol error.
+        self._walker.proto.set_output(['want %s multi_ack' % ONE, 'foo'])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+        self._walker.proto.set_output(['want %s multi_ack' % FOUR])
+        self.assertRaises(GitProtocolError, self._walker.determine_wants, heads)
+
+    # TODO: test commit time cutoff
+
+
+class TestProtocolGraphWalker(object):
+    """Scripted walker stand-in used to drive the ack implementations.
+
+    ``lines`` holds the (command, sha) tuples that read_proto_line() hands
+    out in order; every ack or nak sent is queued in ``acks`` so tests can
+    inspect it with pop_ack().
+    """
+
+    def __init__(self):
+        self.acks = []
+        self.lines = []
+        self.done = False
+        self.stateless_rpc = False
+        self.advertise_refs = False
+
+    def read_proto_line(self):
+        next_line = self.lines.pop(0)
+        return next_line
+
+    def send_ack(self, sha, ack_type=''):
+        entry = (sha, ack_type)
+        self.acks.append(entry)
+
+    def send_nak(self):
+        # A nak is recorded as an ack entry with no SHA.
+        self.acks.append((None, 'nak'))
+
+    def all_wants_satisfied(self, haves):
+        # Controlled directly by the test through the ``done`` flag.
+        return self.done
+
+    def pop_ack(self):
+        if self.acks:
+            return self.acks.pop(0)
+        return None
+
+
+class AckGraphWalkerImplTestCase(TestCase):
+    """Shared fixture and assertion helpers for the ack implementation tests.
+
+    Subclasses set ``impl_cls`` to the implementation under test.
+    """
+
+    def setUp(self):
+        walker = TestProtocolGraphWalker()
+        # Default script: three haves followed by 'done'.
+        walker.lines = [('have', sha) for sha in (TWO, ONE, THREE)]
+        walker.lines.append(('done', None))
+        self._walker = walker
+        self._impl = self.impl_cls(self._walker)
+
+    def assertNoAck(self):
+        pending = self._walker.pop_ack()
+        self.assertEquals(None, pending)
+
+    def assertAcks(self, acks):
+        # Each expected ack must come off the queue in order, and nothing
+        # may be left over afterwards.
+        for sha, ack_type in acks:
+            expected = (sha, ack_type)
+            self.assertEquals(expected, self._walker.pop_ack())
+        self.assertNoAck()
+
+    def assertAck(self, sha, ack_type=''):
+        self.assertAcks([(sha, ack_type)])
+
+    def assertNak(self):
+        self.assertAck(None, 'nak')
+
+    def assertNextEquals(self, sha):
+        actual = self._impl.next()
+        self.assertEquals(sha, actual)
+
+
+# Tests SingleAckGraphWalkerImpl using the scripted walker from the base
+# class (haves TWO, ONE, THREE, then 'done').
+class SingleAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+    impl_cls = SingleAckGraphWalkerImpl
+
+    # The first ack is sent; a later ack (THREE) produces no output.
+    def test_single_ack(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE)
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNoAck()
+
+    def test_single_ack_flush(self):
+        # same as ack test but ends with a flush-pkt instead of done
+        self._walker.lines[-1] = (None, None)
+
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE)
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNoAck()
+
+    # Nothing is ever acked, so a single nak is expected at the end.
+    def test_single_ack_nak(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_single_ack_nak_flush(self):
+        # same as nak test but ends with a flush-pkt instead of done
+        self._walker.lines[-1] = (None, None)
+
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+# Tests MultiAckGraphWalkerImpl using the scripted walker from the base
+# class (haves TWO, ONE, THREE, then 'done').
+class MultiAckGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+    impl_cls = MultiAckGraphWalkerImpl
+
+    # Every ack is sent as 'continue'; the last acked SHA is acked again,
+    # without a type, once the walk ends.
+    def test_multi_ack(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'continue')
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'continue')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    # Only ONE is acked and the walker never reports the wants as satisfied.
+    def test_multi_ack_partial(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'continue')
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        # done, re-send ack of last common
+        self.assertAck(ONE)
+
+    # Same as the ack test, but with a flush-pkt in the middle of the haves.
+    def test_multi_ack_flush(self):
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNak() # nak the flush-pkt
+
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'continue')
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'continue')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    # Nothing is ever acked, so a single nak is expected at the end.
+    def test_multi_ack_nak(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+# Tests MultiAckDetailedGraphWalkerImpl using the scripted walker from the
+# base class (haves TWO, ONE, THREE, then 'done').
+class MultiAckDetailedGraphWalkerImplTestCase(AckGraphWalkerImplTestCase):
+    impl_cls = MultiAckDetailedGraphWalkerImpl
+
+    # The first ack yields 'common' then 'ready' (the walker reports done);
+    # a later ack yields only 'ready'; the final ack carries no type.
+    def test_multi_ack(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'ready')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    # The walker never reports done, so only 'common' is sent for ONE.
+    def test_multi_ack_partial(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self._impl.ack(ONE)
+        self.assertAck(ONE, 'common')
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        # done, re-send ack of last common
+        self.assertAck(ONE)
+
+    def test_multi_ack_flush(self):
+        # same as ack test but contains a flush-pkt in the middle
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNak() # nak the flush-pkt
+
+        self._walker.done = True
+        self._impl.ack(ONE)
+        self.assertAcks([(ONE, 'common'), (ONE, 'ready')])
+
+        self.assertNextEquals(THREE)
+        self._impl.ack(THREE)
+        self.assertAck(THREE, 'ready')
+
+        self.assertNextEquals(None)
+        self.assertAck(THREE)
+
+    # Nothing is ever acked, so a single nak is expected at the end.
+    def test_multi_ack_nak(self):
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_multi_ack_nak_flush(self):
+        # same as nak test but contains a flush-pkt in the middle
+        self._walker.lines = [
+            ('have', TWO),
+            (None, None),
+            ('have', ONE),
+            ('have', THREE),
+            ('done', None),
+            ]
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        # The flush-pkt is nak'd here, and the end of the walk again below.
+        self.assertNextEquals(ONE)
+        self.assertNak()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()
+
+    def test_multi_ack_stateless(self):
+        # transmission ends with a flush-pkt
+        self._walker.lines[-1] = (None, None)
+        self._walker.stateless_rpc = True
+
+        self.assertNextEquals(TWO)
+        self.assertNoAck()
+
+        self.assertNextEquals(ONE)
+        self.assertNoAck()
+
+        self.assertNextEquals(THREE)
+        self.assertNoAck()
+
+        self.assertNextEquals(None)
+        self.assertNak()

+ 289 - 0
dulwich/tests/test_web.py

@@ -0,0 +1,289 @@
+# test_web.py -- Tests for the git HTTP server
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) any later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""Tests for the Git HTTP server."""
+
+from cStringIO import StringIO
+import re
+from unittest import TestCase
+
+from dulwich.objects import (
+    type_map,
+    Tag,
+    Blob,
+    )
+from dulwich.web import (
+    HTTP_OK,
+    HTTP_NOT_FOUND,
+    HTTP_FORBIDDEN,
+    send_file,
+    get_info_refs,
+    handle_service_request,
+    _LengthLimitedFile,
+    HTTPGitRequest,
+    HTTPGitApplication,
+    )
+
+
+class WebTestCase(TestCase):
+    """Common fixture for the web tests.
+
+    Provides an empty WSGI environ, an HTTPGitRequest built on it, and a
+    start_response callback that records the status and headers it is given.
+    """
+
+    def setUp(self):
+        self._status = None
+        self._headers = []
+        self._environ = {}
+        self._req = HTTPGitRequest(self._environ, self._start_response)
+
+    def _start_response(self, status, headers):
+        # Copy the headers so later changes by the caller are not observed.
+        self._status, self._headers = status, list(headers)
+
+
+# Tests for the dumb-protocol handlers: send_file() and the non-smart
+# branch of get_info_refs().
+class DumbHandlersTestCase(WebTestCase):
+
+    # Passing None as the file must produce a 404.
+    def test_send_file_not_found(self):
+        list(send_file(self._req, None, 'text/plain'))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+
+    def test_send_file(self):
+        f = StringIO('foobar')
+        output = ''.join(send_file(self._req, f, 'text/plain'))
+        self.assertEquals('foobar', output)
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertTrue(f.closed)
+
+    # Two buffers' worth of data must come back as exactly two chunks.
+    def test_send_file_buffered(self):
+        bufsize = 10240
+        xs = 'x' * bufsize
+        f = StringIO(2 * xs)
+        self.assertEquals([xs, xs],
+                          list(send_file(self._req, f, 'text/plain')))
+        self.assertEquals(HTTP_OK, self._status)
+        self.assertTrue(('Content-Type', 'text/plain') in self._headers)
+        self.assertTrue(f.closed)
+
+    # An IOError while reading must yield a 404 and still close the file.
+    def test_send_file_error(self):
+        class TestFile(object):
+            def __init__(self):
+                self.closed = False
+
+            def read(self, size=-1):
+                raise IOError
+
+            def close(self):
+                self.closed = True
+
+        f = TestFile()
+        list(send_file(self._req, f, 'text/plain'))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        self.assertTrue(f.closed)
+
+    def test_get_info_refs(self):
+        # An empty query string selects the non-smart branch.
+        self._environ['QUERY_STRING'] = ''
+
+        class TestTag(object):
+            type = Tag().type
+
+            def __init__(self, sha, obj_type, obj_sha):
+                self.sha = lambda: sha
+                self.object = (obj_type, obj_sha)
+
+        class TestBlob(object):
+            type = Blob().type
+
+            def __init__(self, sha):
+                self.sha = lambda: sha
+
+        blob1 = TestBlob('111')
+        blob2 = TestBlob('222')
+        blob3 = TestBlob('333')
+
+        # tag1 points at tag2, which points at blob2, so peeling takes two
+        # steps.
+        tag1 = TestTag('aaa', TestTag.type, 'bbb')
+        tag2 = TestTag('bbb', TestBlob.type, '222')
+
+        class TestBackend(object):
+            def __init__(self):
+                objects = [blob1, blob2, blob3, tag1, tag2]
+                self.repo = dict((o.sha(), o) for o in objects)
+
+            def get_refs(self):
+                return {
+                    'HEAD': '000',
+                    'refs/heads/master': blob1.sha(),
+                    'refs/tags/tag-tag': tag1.sha(),
+                    'refs/tags/blob-tag': blob3.sha(),
+                    }
+
+        # HEAD is absent from the output; the tag is followed by its peeled
+        # '^{}' entry.
+        self.assertEquals(['111\trefs/heads/master\n',
+                           '333\trefs/tags/blob-tag\n',
+                           'aaa\trefs/tags/tag-tag\n',
+                           '222\trefs/tags/tag-tag^{}\n'],
+                          list(get_info_refs(self._req, TestBackend(), None)))
+
+
+# Tests for the smart-protocol entry points, using a fake upload-pack
+# handler that echoes what it reads and writes.
+class SmartHandlersTestCase(WebTestCase):
+
+    # Protocol stand-in that renders pkt-lines as readable text through the
+    # owning handler's write callable.
+    class TestProtocol(object):
+        def __init__(self, handler):
+            self._handler = handler
+
+        def write_pkt_line(self, line):
+            if line is None:
+                self._handler.write('flush-pkt\n')
+            else:
+                self._handler.write('pkt-line: %s' % line)
+
+    # Handler stand-in that records its constructor arguments and, when
+    # handled, writes back whatever it reads.
+    class _TestUploadPackHandler(object):
+        def __init__(self, backend, read, write, stateless_rpc=False,
+                     advertise_refs=False):
+            self.read = read
+            self.write = write
+            self.proto = SmartHandlersTestCase.TestProtocol(self)
+            self.stateless_rpc = stateless_rpc
+            self.advertise_refs = advertise_refs
+
+        def handle(self):
+            self.write('handled input: %s' % self.read())
+
+    # Factory used as the service entry; keeps the created handler so tests
+    # can inspect its flags afterwards.
+    def _MakeHandler(self, *args, **kwargs):
+        self._handler = self._TestUploadPackHandler(*args, **kwargs)
+        return self._handler
+
+    def services(self):
+        return {'git-upload-pack': self._MakeHandler}
+
+    # A service name that is not registered must be answered with 403.
+    def test_handle_service_request_unknown(self):
+        mat = re.search('.*', '/git-evil-handler')
+        list(handle_service_request(self._req, 'backend', mat))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+
+    def test_handle_service_request(self):
+        self._environ['wsgi.input'] = StringIO('foo')
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(handle_service_request(self._req, 'backend', mat,
+                                                services=self.services()))
+        self.assertEqual('handled input: foo', output)
+        response_type = 'application/x-git-upload-pack-response'
+        self.assertTrue(('Content-Type', response_type) in self._headers)
+        self.assertFalse(self._handler.advertise_refs)
+        self.assertTrue(self._handler.stateless_rpc)
+
+    # With CONTENT_LENGTH set, only that many bytes of input reach the
+    # handler.
+    def test_handle_service_request_with_length(self):
+        self._environ['wsgi.input'] = StringIO('foobar')
+        self._environ['CONTENT_LENGTH'] = 3
+        mat = re.search('.*', '/git-upload-pack')
+        output = ''.join(handle_service_request(self._req, 'backend', mat,
+                                                services=self.services()))
+        self.assertEqual('handled input: foo', output)
+        response_type = 'application/x-git-upload-pack-response'
+        self.assertTrue(('Content-Type', response_type) in self._headers)
+
+    def test_get_info_refs_unknown(self):
+        self._environ['QUERY_STRING'] = 'service=git-evil-handler'
+        list(get_info_refs(self._req, 'backend', None,
+                           services=self.services()))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+
+    def test_get_info_refs(self):
+        self._environ['wsgi.input'] = StringIO('foo')
+        self._environ['QUERY_STRING'] = 'service=git-upload-pack'
+
+        output = ''.join(get_info_refs(self._req, 'backend', None,
+                                       services=self.services()))
+        self.assertEquals(('pkt-line: # service=git-upload-pack\n'
+                           'flush-pkt\n'
+                           # input is ignored by the handler
+                           'handled input: '), output)
+        self.assertTrue(self._handler.advertise_refs)
+        self.assertTrue(self._handler.stateless_rpc)
+
+
+class LengthLimitedFileTestCase(TestCase):
+    """Tests for the length-limiting wrapper around file-like objects."""
+
+    def _limited(self, max_bytes):
+        # Every test wraps the same six-byte payload.
+        return _LengthLimitedFile(StringIO('foobar'), max_bytes)
+
+    def test_no_cutoff(self):
+        wrapped = self._limited(1024)
+        self.assertEquals('foobar', wrapped.read())
+
+    def test_cutoff(self):
+        wrapped = self._limited(3)
+        self.assertEquals('foo', wrapped.read())
+        self.assertEquals('', wrapped.read())
+
+    def test_multiple_reads(self):
+        wrapped = self._limited(3)
+        self.assertEquals('fo', wrapped.read(2))
+        self.assertEquals('o', wrapped.read(2))
+        self.assertEquals('', wrapped.read())
+
+
+# Tests for the response helpers on HTTPGitRequest.
+class HTTPGitRequestTestCase(WebTestCase):
+    def test_not_found(self):
+        self._req.cache_forever()  # cache headers should be discarded
+        message = 'Something not found'
+        self.assertEquals(message, self._req.not_found(message))
+        self.assertEquals(HTTP_NOT_FOUND, self._status)
+        self.assertEquals(set([('Content-Type', 'text/plain')]),
+                          set(self._headers))
+
+    def test_forbidden(self):
+        self._req.cache_forever()  # cache headers should be discarded
+        message = 'Something not found'
+        self.assertEquals(message, self._req.forbidden(message))
+        self.assertEquals(HTTP_FORBIDDEN, self._status)
+        self.assertEquals(set([('Content-Type', 'text/plain')]),
+                          set(self._headers))
+
+    # With no arguments, respond() sends 200 and no headers.
+    def test_respond_ok(self):
+        self._req.respond()
+        self.assertEquals([], self._headers)
+        self.assertEquals(HTTP_OK, self._status)
+
+    # Explicit headers, the content type and the no-cache headers must all
+    # be present in the response.
+    def test_respond(self):
+        self._req.nocache()
+        self._req.respond(status=402, content_type='some/type',
+                          headers=[('X-Foo', 'foo'), ('X-Bar', 'bar')])
+        self.assertEquals(set([
+            ('X-Foo', 'foo'),
+            ('X-Bar', 'bar'),
+            ('Content-Type', 'some/type'),
+            ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT'),
+            ('Pragma', 'no-cache'),
+            ('Cache-Control', 'no-cache, max-age=0, must-revalidate'),
+            ]), set(self._headers))
+        self.assertEquals(402, self._status)
+
+
+# Tests request dispatch in HTTPGitApplication.
+class HTTPGitApplicationTestCase(TestCase):
+    def setUp(self):
+        self._app = HTTPGitApplication('backend')
+
+    def test_call(self):
+        # ``environ`` is defined further down; the closure only reads it
+        # when the application invokes the handler.
+        def test_handler(req, backend, mat):
+            # tests interface used by all handlers
+            self.assertEquals(environ, req.environ)
+            self.assertEquals('backend', backend)
+            self.assertEquals('/foo', mat.group(0))
+            return 'output'
+
+        # Replace the routing table with a single GET route for /foo.
+        self._app.services = {
+            ('GET', re.compile('/foo$')): test_handler,
+        }
+        environ = {
+            'PATH_INFO': '/foo',
+            'REQUEST_METHOD': 'GET',
+            }
+        self.assertEquals('output', self._app(environ, None))

+ 311 - 0
dulwich/web.py

@@ -0,0 +1,311 @@
+# web.py -- WSGI smart-http server
+# Copyright (C) 2010 Google, Inc.
+#
+# This program is free software; you can redistribute it and/or
+# modify it under the terms of the GNU General Public License
+# as published by the Free Software Foundation; version 2
+# or (at your option) any later version of the License.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
+# MA  02110-1301, USA.
+
+"""HTTP server for dulwich that implements the git smart HTTP protocol."""
+
+from cStringIO import StringIO
+import cgi
+import os
+import re
+import time
+
+from dulwich.objects import (
+    Tag,
+    num_type_map,
+    )
+from dulwich.repo import (
+    Repo,
+    )
+from dulwich.server import (
+    GitBackend,
+    ReceivePackHandler,
+    UploadPackHandler,
+    )
+
+HTTP_OK = '200 OK'
+HTTP_NOT_FOUND = '404 Not Found'
+HTTP_FORBIDDEN = '403 Forbidden'
+
+
+def date_time_string(self, timestamp=None):
+    """Format a timestamp as an HTTP date string in GMT.
+
+    :param self: Unused. NOTE(review): looks like a leftover from the
+        BaseHTTPServer method this was adapted from; kept so the signature
+        stays compatible -- confirm against the callers before removing.
+    :param timestamp: Seconds since the epoch; defaults to the current time.
+    :return: A string such as 'Thu, 01 Jan 1970 00:00:00 GMT'.
+    """
+    # Based on BaseHTTPServer.py in python2.5
+    weekdays = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
+    # Months are 1-based in struct_time, hence the None placeholder.
+    months = [None,
+              'Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
+              'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
+    if timestamp is None:
+        timestamp = time.time()
+    year, month, day, hh, mm, ss, wd, y, z = time.gmtime(timestamp)
+    # The zone designator must be 'GMT'; it was previously misspelled 'GMD'.
+    return '%s, %02d %3s %4d %02d:%02d:%02d GMT' % (
+            weekdays[wd], day, months[month], year, hh, mm, ss)
+
+
+def send_file(req, f, content_type):
+    """Stream a file-like object to the client in fixed-size chunks.
+
+    :param req: The HTTPGitRequest object used to start the response.
+    :param f: An open file-like object to send, or None when there is no
+        file; a file that is given is always closed before the generator
+        finishes.
+    :param content_type: The MIME type to report for the file.
+    :yield: Successive chunks of the file, or an error body when the file is
+        missing or cannot be read.
+    """
+    if f is None:
+        yield req.not_found('File not found')
+        return
+    try:
+        req.respond(HTTP_OK, content_type)
+        chunk = f.read(10240)
+        while chunk:
+            yield chunk
+            chunk = f.read(10240)
+    except IOError:
+        yield req.not_found('Error reading file')
+    finally:
+        f.close()
+
+
+def get_text_file(req, backend, mat):
+    """Send the repository file named by the matched path as plain text.
+
+    The response is marked as not cacheable.
+    """
+    req.nocache()
+    named_file = backend.repo.get_named_file(mat.group())
+    return send_file(req, named_file, 'text/plain')
+
+
+def get_loose_object(req, backend, mat):
+    """Send a single loose object in its legacy on-disk form.
+
+    :param req: The HTTPGitRequest object to respond through.
+    :param backend: Backend whose object_store holds the object.
+    :param mat: Regex match whose groups 1 and 2 concatenate to the SHA.
+    :yield: The object data, or an error body when the object is not a
+        loose object or cannot be read.
+    """
+    sha = mat.group(1) + mat.group(2)
+    object_store = backend.object_store
+    if not object_store.contains_loose(sha):
+        yield req.not_found('Object not found')
+        return
+    try:
+        data = object_store[sha].as_legacy_object()
+    except IOError:
+        yield req.not_found('Error reading object')
+        # Stop here: ``data`` is unbound and a 404 has already been sent.
+        return
+    req.cache_forever()
+    req.respond(HTTP_OK, 'application/x-git-loose-object')
+    yield data
+
+
+def get_pack_file(req, backend, mat):
+    """Send the pack file named by the matched path, cacheable forever."""
+    req.cache_forever()
+    # send_file() takes exactly (req, f, content_type); the extra positional
+    # argument previously passed here made every call raise TypeError.
+    return send_file(req, backend.repo.get_named_file(mat.group()),
+                     'application/x-git-packed-objects')
+
+
+def get_idx_file(req, backend, mat):
+    """Send the pack index named by the matched path, cacheable forever."""
+    req.cache_forever()
+    # send_file() takes exactly (req, f, content_type); the extra positional
+    # argument previously passed here made every call raise TypeError.
+    return send_file(req, backend.repo.get_named_file(mat.group()),
+                     'application/x-git-packed-objects-toc')
+
+
+DEFAULT_SERVICES = {'git-upload-pack': UploadPackHandler,
+                    'git-receive-pack': ReceivePackHandler}
+# Kept under its original name for code that already refers to it.
+services = DEFAULT_SERVICES
+
+
+def get_info_refs(req, backend, mat, services=None):
+    """Handle a request for info/refs in either smart or dumb mode.
+
+    :param req: The HTTPGitRequest object to respond through.
+    :param backend: Backend providing get_refs() and repo.
+    :param mat: The regex match for the request path (unused here).
+    :param services: Optional dict mapping service names to handler classes;
+        defaults to DEFAULT_SERVICES.
+    :yield: The smart ref advertisement when a ``service`` query parameter
+        is given, otherwise one text line per ref (plus a peeled '^{}' line
+        for tags).
+    """
+    if services is None:
+        # The parameter shadows the module-level name, so the default has to
+        # be reached through a differently named global.
+        services = DEFAULT_SERVICES
+    params = cgi.parse_qs(req.environ['QUERY_STRING'])
+    service = params.get('service', [None])[0]
+    if service:
+        handler_cls = services.get(service, None)
+        if handler_cls is None:
+            yield req.forbidden('Unsupported service %s' % service)
+            return
+        req.nocache()
+        req.respond(HTTP_OK, 'application/x-%s-advertisement' % service)
+        output = StringIO()
+        dummy_input = StringIO()  # GET request, handler doesn't need to read
+        handler = handler_cls(backend, dummy_input.read, output.write,
+                              stateless_rpc=True, advertise_refs=True)
+        handler.proto.write_pkt_line('# service=%s\n' % service)
+        handler.proto.write_pkt_line(None)
+        handler.handle()
+        yield output.getvalue()
+    else:
+        # non-smart fallback
+        # TODO: select_getanyfile() (see http-backend.c)
+        req.nocache()
+        req.respond(HTTP_OK, 'text/plain')
+        refs = backend.get_refs()
+        for name in sorted(refs.iterkeys()):
+            # get_refs() includes HEAD as a special case, but we don't want to
+            # advertise it
+            if name == 'HEAD':
+                continue
+            sha = refs[name]
+            o = backend.repo[sha]
+            if not o:
+                continue
+            yield '%s\t%s\n' % (sha, name)
+            obj_type = num_type_map[o.type]
+            if obj_type == Tag:
+                # Follow the chain of tags down to the first non-tag object.
+                while obj_type == Tag:
+                    num_type, sha = o.object
+                    obj_type = num_type_map[num_type]
+                    o = backend.repo[sha]
+                if not o:
+                    continue
+                yield '%s\t%s^{}\n' % (o.sha(), name)
+
+
+def get_info_packs(req, backend, mat):
+    """Yield one 'P pack-<name>.pack' line per pack in the object store.
+
+    The response is plain text and marked as not cacheable.
+    """
+    req.nocache()
+    req.respond(HTTP_OK, 'text/plain')
+    line_format = 'P pack-%s.pack\n'
+    for pack in backend.object_store.packs:
+        pack_name = pack.name()
+        yield line_format % pack_name
+
+
+class _LengthLimitedFile(object):
+    """Wrapper class to limit the length of reads from a file-like object.
+
+    This is used to ensure EOF is read from the wsgi.input object once
+    Content-Length bytes are read. This behavior is required by the WSGI spec
+    but not implemented in wsgiref as of 2.5.
+    """
+    def __init__(self, input, max_bytes):
+        self._input = input
+        self._bytes_avail = max_bytes
+
+    def read(self, size=-1):
+        """Read up to ``size`` bytes without exceeding the byte limit.
+
+        :param size: Maximum number of bytes to read; None or any negative
+            value means "everything still allowed".
+        :return: The data read, or '' once the limit has been used up.
+        """
+        if self._bytes_avail <= 0:
+            return ''
+        if size is None or size < 0 or size > self._bytes_avail:
+            size = self._bytes_avail
+        data = self._input.read(size)
+        # Charge only what was actually returned, so a short read from the
+        # underlying file does not cut the remaining input off early.
+        self._bytes_avail -= len(data)
+        return data
+
+    # TODO: support more methods as necessary
+
+def handle_service_request(req, backend, mat, services=services):
+    """Run a smart-protocol service handler for a POST request.
+
+    The service name is taken from the matched path with leading slashes
+    stripped. An unknown service produces a 403 response. Otherwise the
+    handler class is instantiated in stateless-RPC mode, fed from the
+    request body, and its buffered output is yielded as the response.
+
+    :param req: request object providing environ and the response helpers
+    :param backend: backend passed through to the handler class
+    :param mat: regular expression match for the request path
+    :param services: mapping of service name to handler class
+    """
+    # NOTE(review): the previous "if services is None: services = services"
+    # fallback was a no-op and has been dropped; passing None explicitly is
+    # still unsupported. The intended default mapping is not visible from
+    # here -- confirm and restore a real fallback.
+    service = mat.group().lstrip('/')
+    handler_cls = services.get(service, None)
+    if handler_cls is None:
+        yield req.forbidden('Unsupported service %s' % service)
+        return
+    req.nocache()
+    req.respond(HTTP_OK, 'application/x-%s-response' % service)
+
+    output = StringIO()
+    body = req.environ['wsgi.input']
+    # This is not necessary if this app is run from a conforming WSGI server.
+    # Unfortunately, there's no way to tell that at this point.
+    # TODO: git may use HTTP/1.1 chunked encoding instead of specifying
+    # content-length
+    # CONTENT_LENGTH may be absent *or* empty under WSGI; only limit the
+    # stream when a value is actually supplied.
+    content_length = req.environ.get('CONTENT_LENGTH')
+    if content_length:
+        body = _LengthLimitedFile(body, int(content_length))
+    handler = handler_cls(backend, body.read, output.write, stateless_rpc=True)
+    handler.handle()
+    yield output.getvalue()
+
+
+class HTTPGitRequest(object):
+    """Holder for the state of one git HTTP request and its response.
+
+    :ivar environ: the WSGI environment for the request.
+    """
+
+    def __init__(self, environ, start_response):
+        self._headers = []
+        self._cache_headers = []
+        self._start_response = start_response
+        self.environ = environ
+
+    def add_header(self, name, value):
+        """Queue a single (name, value) header for the response."""
+        header = (name, value)
+        self._headers.append(header)
+
+    def respond(self, status=HTTP_OK, content_type=None, headers=None):
+        """Start the response with the given status and accumulated headers.
+
+        Extra headers, then the Content-Type (when given), then the current
+        cache headers are appended to the queued headers before
+        start_response is called.
+        """
+        pending = self._headers
+        if headers:
+            pending.extend(headers)
+        if content_type:
+            pending.append(('Content-Type', content_type))
+        pending.extend(self._cache_headers)
+
+        self._start_response(status, pending)
+
+    def not_found(self, message):
+        """Start a 404 text/plain response and hand back the message."""
+        # Error responses carry no cache headers.
+        self._cache_headers = []
+        self.respond(status=HTTP_NOT_FOUND, content_type='text/plain')
+        return message
+
+    def forbidden(self, message):
+        """Start a 403 text/plain response and hand back the message."""
+        # Error responses carry no cache headers.
+        self._cache_headers = []
+        self.respond(status=HTTP_FORBIDDEN, content_type='text/plain')
+        return message
+
+    def nocache(self):
+        """Select cache headers telling the client never to cache."""
+        expires = ('Expires', 'Fri, 01 Jan 1980 00:00:00 GMT')
+        pragma = ('Pragma', 'no-cache')
+        control = ('Cache-Control', 'no-cache, max-age=0, must-revalidate')
+        self._cache_headers = [expires, pragma, control]
+
+    def cache_forever(self):
+        """Select cache headers telling the client to cache for a year."""
+        issued = time.time()
+        date_value = date_time_string(issued)
+        expires_value = date_time_string(issued + 31536000)
+        self._cache_headers = [
+            ('Date', date_value),
+            ('Expires', expires_value),
+            ('Cache-Control', 'public, max-age=31536000'),
+            ]
+
+
+class HTTPGitApplication(object):
+    """Class encapsulating the state of a git WSGI application.
+
+    :ivar backend: the Backend object backing this application
+    """
+
+    # Maps (HTTP method, compiled path pattern) to the handler callable that
+    # serves matching requests. Handlers are called as
+    # handler(req, backend, match).
+    services = {
+        ('GET', re.compile('/HEAD$')): get_text_file,
+        ('GET', re.compile('/info/refs$')): get_info_refs,
+        ('GET', re.compile('/objects/info/alternates$')): get_text_file,
+        ('GET', re.compile('/objects/info/http-alternates$')): get_text_file,
+        ('GET', re.compile('/objects/info/packs$')): get_info_packs,
+        ('GET', re.compile('/objects/([0-9a-f]{2})/([0-9a-f]{38})$')): get_loose_object,
+        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.pack$')): get_pack_file,
+        ('GET', re.compile('/objects/pack/pack-([0-9a-f]{40})\\.idx$')): get_idx_file,
+
+        ('POST', re.compile('/git-upload-pack$')): handle_service_request,
+        ('POST', re.compile('/git-receive-pack$')): handle_service_request,
+    }
+
+    def __init__(self, backend):
+        self.backend = backend
+
+    def __call__(self, environ, start_response):
+        """WSGI entry point: dispatch the request to a matching handler.
+
+        :param environ: the WSGI environment; PATH_INFO and REQUEST_METHOD
+            are used to select the handler
+        :param start_response: the WSGI start_response callable
+        :return: an iterable of response body strings
+        """
+        path = environ['PATH_INFO']
+        method = environ['REQUEST_METHOD']
+        req = HTTPGitRequest(environ, start_response)
+        # environ['QUERY_STRING'] has qs args
+        handler = None
+        for smethod, spath in self.services.iterkeys():
+            if smethod != method:
+                continue
+            mat = spath.search(path)
+            if mat:
+                handler = self.services[smethod, spath]
+                break
+        if handler is None:
+            # Wrap the message in a list: a bare string is a valid WSGI
+            # iterable but is sent to the client one character at a time.
+            return [req.not_found('Sorry, that method is not supported')]
+        return handler(req, self.backend, mat)

+ 1 - 1
setup.py

@@ -51,7 +51,7 @@ setup(name='dulwich',
       in one of the Monty Python sketches.
       """,
       packages=['dulwich', 'dulwich.tests'],
-      scripts=['bin/dulwich', 'bin/dul-daemon'],
+      scripts=['bin/dulwich', 'bin/dul-daemon', 'bin/dul-web'],
       features = {'speedups': speedups},
       ext_modules = mandatory_ext_modules,
       )