Ver código fonte

Improve ref handling.

DiskRefsConatiner now handles packed-refs transparently, and
understands both the peeled and unpeeled packed-refs formats
correctly. Ref cycles are handled by giving up after reaching an
arbitrary recurion depth.

Includes tests for all new functionality.

Change-Id: I742117d7a2b99cbb52ee2a9d3a625037844c55b6
Dave Borowitz 15 anos atrás
pai
commit
afad608e26

+ 8 - 0
dulwich/errors.py

@@ -108,3 +108,11 @@ class HangupException(GitProtocolError):
     def __init__(self):
         Exception.__init__(self,
             "The remote server unexpectedly closed the connection.")
+
+
+class FileFormatException(Exception):
+    """Base class for exceptions relating to reading git file formats."""
+
+
+class PackedRefsException(FileFormatException):
+    """Indicates an error parsing a packed-refs file."""

+ 13 - 0
dulwich/file.py

@@ -23,6 +23,13 @@
 import errno
 import os
 
+def ensure_dir_exists(dirname):
+    """Ensure a directory exists, creating if necessary."""
+    try:
+        os.makedirs(dirname)
+    except OSError, e:
+        if e.errno != errno.EEXIST:
+            raise
 
 def GitFile(filename, mode='r', bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
@@ -84,6 +91,7 @@ class _GitFile(object):
         self._lockfilename = '%s.lock' % self._filename
         fd = os.open(self._lockfilename, os.O_RDWR | os.O_CREAT | os.O_EXCL)
         self._file = os.fdopen(fd, mode, bufsize)
+        self._closed = False
 
         for method in self.PROXY_METHODS:
             setattr(self, method, getattr(self._file, method))
@@ -93,9 +101,12 @@ class _GitFile(object):
 
         If the file is already closed, this is a no-op.
         """
+        if self._closed:
+            return
         self._file.close()
         try:
             os.remove(self._lockfilename)
+            self._closed = True
         except OSError, e:
             # The file may have been removed already, which is ok.
             if e.errno != errno.ENOENT:
@@ -112,6 +123,8 @@ class _GitFile(object):
             file is still closed, so further attempts to write to the same file
             object will raise ValueError.
         """
+        if self._closed:
+            return
         self._file.close()
         try:
             os.rename(self._lockfilename, self._filename)

+ 374 - 109
dulwich/repo.py

@@ -22,6 +22,7 @@
 """Repository access."""
 
 
+import errno
 import os
 import stat
 
@@ -31,8 +32,12 @@ from dulwich.errors import (
     NotCommitError, 
     NotGitRepository,
     NotTreeError, 
+    PackedRefsException,
+    )
+from dulwich.file import (
+    ensure_dir_exists,
+    GitFile,
     )
-from dulwich.file import GitFile
 from dulwich.object_store import (
     DiskObjectStore,
     )
@@ -42,6 +47,7 @@ from dulwich.objects import (
     ShaFile,
     Tag,
     Tree,
+    hex_to_sha,
     )
 
 OBJECTDIR = 'objects'
@@ -52,20 +58,36 @@ REFSDIR_HEADS = 'heads'
 INDEX_FILENAME = "index"
 
 
-def follow_ref(container, name):
-    """Follow a ref back to a SHA1.
-    
-    :param container: Ref container to use for looking up refs.
-    :param name: Name of the original ref.
+def check_ref_format(refname):
+    """Check if a refname is correctly formatted.
+
+    Implements all the same rules as git-check-ref-format[1].
+
+    [1] http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html
+
+    :param refname: The refname to check
+    :return: True if refname is valid, False otherwise
     """
-    contents = container[name]
-    if contents.startswith(SYMREF):
-        ref = contents[len(SYMREF):]
-        if ref[-1] == '\n':
-            ref = ref[:-1]
-        return follow_ref(container, ref)
-    assert len(contents) == 40, 'Invalid ref in %s' % name
-    return contents
+    # These could be combined into one big expression, but are listed separately
+    # to parallel [1].
+    if '/.' in refname or refname.startswith('.'):
+        return False
+    if '/' not in refname:
+        return False
+    if '..' in refname:
+        return False
+    for c in refname:
+        if ord(c) < 040 or c in '\177 ~^:?*[':
+            return False
+    if refname[-1] in '/.':
+        return False
+    if refname.endswith('.lock'):
+        return False
+    if '@{' in refname:
+        return False
+    if '\\' in refname:
+        return False
+    return True
 
 
 class RefsContainer(object):
@@ -75,13 +97,6 @@ class RefsContainer(object):
         """Return the contents of this ref container under base as a dict."""
         raise NotImplementedError(self.as_dict)
 
-    def follow(self, name):
-        """Follow a ref name back to a SHA1.
-        
-        :param name: Name of the ref
-        """
-        return follow_ref(self, name)
-
     def set_ref(self, name, other):
         """Make a ref point at another ref.
 
@@ -100,54 +115,68 @@ class DiskRefsContainer(RefsContainer):
 
     def __init__(self, path):
         self.path = path
+        self._packed_refs = None
+        self._peeled_refs = {}
 
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.path)
 
     def keys(self, base=None):
-        """Refs present in this container."""
-        return list(self.iterkeys(base))
+        """Refs present in this container.
 
-    def iterkeys(self, base=None):
+        :param base: An optional base to return refs under
+        :return: An unsorted set of valid refs in this container, including
+            packed refs.
+        """
         if base is not None:
-            return self.itersubkeys(base)
+            return self.subkeys(base)
         else:
-            return self.iterallkeys()
+            return self.allkeys()
 
-    def itersubkeys(self, base):
+    def subkeys(self, base):
+        keys = set()
         path = self.refpath(base)
         for root, dirs, files in os.walk(path):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
-
-    def iterallkeys(self):
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                # check_ref_format requires at least one /, so we prepend the
+                # base before calling it.
+                if check_ref_format("%s/%s" % (base, refname)):
+                    keys.add(refname)
+        for key in self.get_packed_refs():
+            if key.startswith(base):
+                keys.add(key[len(base):].strip("/"))
+        return keys
+
+    def allkeys(self):
+        keys = set()
         if os.path.exists(self.refpath("HEAD")):
-            yield "HEAD"
+            keys.add("HEAD")
         path = self.refpath("")
         for root, dirs, files in os.walk(self.refpath("refs")):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                if check_ref_format(refname):
+                    keys.add(refname)
+        keys.update(self.get_packed_refs())
+        return keys
 
-    def as_dict(self, base=None, follow=True):
+    def as_dict(self, base=None):
         """Return the contents of this container as a dictionary.
 
         """
         ret = {}
+        keys = self.keys(base)
         if base is None:
-            keys = self.iterkeys()
             base = ""
-        else:
-            keys = self.itersubkeys(base)
         for key in keys:
-                if follow:
-                    try:
-                        ret[key] = self.follow(("%s/%s" % (base, key)).strip("/"))
-                    except KeyError:
-                        continue # Unable to resolve
-                else:
-                    ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            try:
+                ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            except KeyError:
+                continue # Unable to resolve
+
         return ret
 
     def refpath(self, name):
@@ -158,51 +187,328 @@ class DiskRefsContainer(RefsContainer):
             name = name.replace("/", os.path.sep)
         return os.path.join(self.path, name)
 
+    def get_packed_refs(self):
+        """Get contents of the packed-refs file.
+
+        :return: Dictionary mapping ref names to SHA1s
+
+        :note: Will return an empty dictionary when no packed-refs file is
+            present.
+        """
+        # TODO: invalidate the cache on repacking
+        if self._packed_refs is None:
+            self._packed_refs = {}
+            path = os.path.join(self.path, 'packed-refs')
+            try:
+                f = GitFile(path, 'rb')
+            except IOError, e:
+                if e.errno == errno.ENOENT:
+                    return {}
+                raise
+            try:
+                first_line = iter(f).next().rstrip()
+                if (first_line.startswith("# pack-refs") and " peeled" in
+                        first_line):
+                    for sha, name, peeled in read_packed_refs_with_peeled(f):
+                        self._packed_refs[name] = sha
+                        if peeled:
+                            self._peeled_refs[name] = peeled
+                else:
+                    f.seek(0)
+                    for sha, name in read_packed_refs(f):
+                        self._packed_refs[name] = sha
+            finally:
+                f.close()
+        return self._packed_refs
+
+    def _check_refname(self, name):
+        """Ensure a refname is valid and lives in refs or is HEAD.
+
+        HEAD is not a valid refname according to git-check-ref-format, but this
+        class needs to be able to touch HEAD. Also, check_ref_format expects
+        refnames without the leading 'refs/', but this class requires that
+        so it cannot touch anything outside the refs dir (or HEAD).
+
+        :param name: The name of the reference.
+        :raises KeyError: if a refname is not HEAD or is otherwise not valid.
+        """
+        if name == 'HEAD':
+            return
+        if not name.startswith('refs/') or not check_ref_format(name[5:]):
+            raise KeyError(name)
+
+    def _read_ref_file(self, name):
+        """Read a reference file and return its contents.
+
+        If the reference file a symbolic reference, only read the first line of
+        the file. Otherwise, only read the first 40 bytes.
+
+        :param name: the refname to read, relative to refpath
+        :return: The contents of the ref file, or None if the file does not
+            exist.
+        :raises IOError: if any other error occurs
+        """
+        filename = self.refpath(name)
+        try:
+            f = GitFile(filename, 'rb')
+            try:
+                header = f.read(len(SYMREF))
+                if header == SYMREF:
+                    # Read only the first line
+                    return header + iter(f).next().rstrip("\n")
+                else:
+                    # Read only the first 40 bytes
+                    return header + f.read(40-len(SYMREF))
+            finally:
+                f.close()
+        except IOError, e:
+            if e.errno == errno.ENOENT:
+                return None
+            raise
+
+    def _follow(self, name):
+        """Follow a reference name.
+
+        :return: a tuple of (refname, sha), where refname is the name of the
+            last reference in the symbolic reference chain
+        """
+        self._check_refname(name)
+        contents = SYMREF + name
+        depth = 0
+        while contents.startswith(SYMREF):
+            refname = contents[len(SYMREF):]
+            contents = self._read_ref_file(refname)
+            if not contents:
+                contents = self.get_packed_refs().get(refname, None)
+                if not contents:
+                    break
+            depth += 1
+            if depth > 5:
+                raise KeyError(name)
+        return refname, contents
+
     def __getitem__(self, name):
-        file = self.refpath(name)
-        if not os.path.exists(file):
+        """Get the SHA1 for a reference name.
+
+        This method follows all symbolic references.
+        """
+        _, sha = self._follow(name)
+        if sha is None:
             raise KeyError(name)
-        f = GitFile(file, 'rb')
+        return sha
+
+    def _remove_packed_ref(self, name):
+        if self._packed_refs is None:
+            return
+        filename = os.path.join(self.path, 'packed-refs')
+        # reread cached refs from disk, while holding the lock
+        f = GitFile(filename, 'wb')
+        try:
+            self._packed_refs = None
+            self.get_packed_refs()
+
+            if name not in self._packed_refs:
+                return
+
+            del self._packed_refs[name]
+            if name in self._peeled_refs:
+                del self._peeled_refs[name]
+            write_packed_refs(f, self._packed_refs, self._peeled_refs)
+            f.close()
+        finally:
+            f.abort()
+
+    def set_if_equals(self, name, old_ref, new_ref):
+        """Set a refname to new_ref only if it currently equals old_ref.
+
+        This method follows all symbolic references, and can be used to perform
+        an atomic compare-and-swap operation.
+
+        :param name: The refname to set.
+        :param old_ref: The old sha the refname must refer to, or None to set
+            unconditionally.
+        :param new_ref: The new sha the refname will refer to.
+        :return: True if the set was successful, False otherwise.
+        """
+        try:
+            realname, _ = self._follow(name)
+        except KeyError:
+            realname = name
+        filename = self.refpath(realname)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
         try:
-            return f.read().strip("\n")
+            if old_ref is not None:
+                try:
+                    # read again while holding the lock
+                    orig_ref = self._read_ref_file(realname)
+                    if orig_ref is None:
+                        orig_ref = self.get_packed_refs().get(realname, None)
+                    if orig_ref != old_ref:
+                        f.abort()
+                        return False
+                except (OSError, IOError):
+                    f.abort()
+                    raise
+            try:
+                f.write(new_ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
         finally:
             f.close()
+        return True
+
+    def add_if_new(self, name, ref):
+        """Add a new reference only if it does not already exist."""
+        self._check_refname(name)
+        filename = self.refpath(name)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
+        try:
+            if os.path.exists(filename) or name in self.get_packed_refs():
+                f.abort()
+                return False
+            try:
+                f.write(ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
+        finally:
+            f.close()
+        return True
 
     def __setitem__(self, name, ref):
-        file = self.refpath(name)
-        dirpath = os.path.dirname(file)
-        if not os.path.exists(dirpath):
-            os.makedirs(dirpath)
-        f = GitFile(file, 'wb')
+        """Set a reference name to point to the given SHA1.
+
+        This method follows all symbolic references.
+
+        :note: This method unconditionally overwrites the contents of a reference
+            on disk. To update atomically only if the reference has not changed
+            on disk, use set_if_equals().
+        """
+        self.set_if_equals(name, None, ref)
+
+    def remove_if_equals(self, name, old_ref):
+        """Remove a refname only if it currently equals old_ref.
+
+        This method does not follow symbolic references. It can be used to
+        perform an atomic compare-and-delete operation.
+
+        :param name: The refname to delete.
+        :param old_ref: The old sha the refname must refer to, or None to delete
+            unconditionally.
+        :return: True if the delete was successful, False otherwise.
+        """
+        self._check_refname(name)
+        filename = self.refpath(name)
+        f = GitFile(filename, 'wb')
         try:
-            f.write(ref+"\n")
+            if old_ref is not None:
+                orig_ref = self._read_ref_file(name)
+                if orig_ref is None:
+                    orig_ref = self.get_packed_refs().get(name, None)
+                if orig_ref != old_ref:
+                    return False
+            # may only be packed
+            if os.path.exists(filename):
+                os.remove(filename)
+            self._remove_packed_ref(name)
         finally:
-            f.close()
+            # never write, we just wanted the lock
+            f.abort()
+        return True
 
     def __delitem__(self, name):
-        file = self.refpath(name)
-        if os.path.exists(file):
-            os.remove(file)
+        """Remove a refname.
+
+        This method does not follow symbolic references.
+        :note: This method unconditionally deletes the contents of a reference
+            on disk. To delete atomically only if the reference has not changed
+            on disk, use set_if_equals().
+        """
+        self.remove_if_equals(name, None)
+
+
+def _split_ref_line(line):
+    """Split a single ref line into a tuple of SHA1 and name."""
+    fields = line.rstrip("\n").split(" ")
+    if len(fields) != 2:
+        raise PackedRefsException("invalid ref line '%s'" % line)
+    sha, name = fields
+    try:
+        hex_to_sha(sha)
+    except (AssertionError, TypeError), e:
+        raise PackedRefsException(e)
+    if not check_ref_format(name):
+        raise PackedRefsException("invalid ref name '%s'" % name)
+    return (sha, name)
 
 
 def read_packed_refs(f):
     """Read a packed refs file.
 
-    Yields tuples with ref names and SHA1s.
+    Yields tuples with SHA1s and ref names.
 
     :param f: file-like object to read from
     """
-    l = f.readline()
-    for l in f.readlines():
+    for l in f:
         if l[0] == "#":
             # Comment
             continue
         if l[0] == "^":
-            # FIXME: Return somehow
+            raise PackedRefsException(
+                "found peeled ref in packed-refs without peeled")
+        yield _split_ref_line(l)
+
+
+def read_packed_refs_with_peeled(f):
+    """Read a packed refs file including peeled refs.
+
+    Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
+    with ref names, SHA1s, and peeled SHA1s (or None).
+
+    :param f: file-like object to read from, seek'ed to the second line
+    """
+    last = None
+    for l in f:
+        if l[0] == "#":
             continue
-        yield tuple(l.rstrip("\n").split(" ", 2))
+        l = l.rstrip("\n")
+        if l[0] == "^":
+            if not last:
+                raise PackedRefsException("unexpected peeled ref line")
+            try:
+                hex_to_sha(l[1:])
+            except (AssertionError, TypeError), e:
+                raise PackedRefsException(e)
+            sha, name = _split_ref_line(last)
+            last = None
+            yield (sha, name, l[1:])
+        else:
+            if last:
+                sha, name = _split_ref_line(last)
+                yield (sha, name, None)
+            last = l
+    if last:
+        sha, name = _split_ref_line(last)
+        yield (sha, name, None)
 
 
+def write_packed_refs(f, packed_refs, peeled_refs=None):
+    """Write a packed refs file.
+
+    :param f: empty file-like object to write to
+    :param packed_refs: dict of refname to sha of packed refs to write
+    """
+    if peeled_refs is None:
+        peeled_refs = {}
+    else:
+        f.write('# pack-refs with: peeled\n')
+    for refname in sorted(packed_refs.iterkeys()):
+        f.write('%s %s\n' % (packed_refs[refname], refname))
+        if refname in peeled_refs:
+            f.write('^%s\n' % peeled_refs[refname])
 
 class BaseRepo(object):
     """Base class for a git repository.
@@ -255,15 +561,15 @@ class BaseRepo(object):
 
     def ref(self, name):
         """Return the SHA1 a ref is pointing to."""
-        raise NotImplementedError(self.refs)
+        return self.refs[name]
 
     def get_refs(self):
         """Get dictionary with all refs."""
-        raise NotImplementedError(self.get_refs)
+        return self.refs.as_dict()
 
     def head(self):
         """Return the SHA1 pointed at by HEAD."""
-        return self.refs.follow('HEAD')
+        return self.refs['HEAD']
 
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
@@ -339,7 +645,7 @@ class BaseRepo(object):
         if len(name) in (20, 40):
             return self.object_store[name]
         return self.object_store[self.refs[name]]
-    
+
     def __setitem__(self, name, value):
         if name.startswith("refs/") or name == "HEAD":
             if isinstance(value, ShaFile):
@@ -392,47 +698,6 @@ class Repo(BaseRepo):
         """Check if an index is present."""
         return os.path.exists(self.index_path())
 
-    def ref(self, name):
-        """Return the SHA1 a ref is pointing to."""
-        try:
-            return self.refs.follow(name)
-        except KeyError:
-            return self.get_packed_refs()[name]
-
-    def get_refs(self):
-        """Get dictionary with all refs."""
-        # TODO: move to base class after merging RefsContainer changes
-        ret = {}
-        try:
-            if self.head():
-                ret['HEAD'] = self.head()
-        except KeyError:
-            pass
-        ret.update(self.refs.as_dict())
-        ret.update(self.get_packed_refs())
-        return ret
-
-    def get_packed_refs(self):
-        """Get contents of the packed-refs file.
-
-        :return: Dictionary mapping ref names to SHA1s
-
-        :note: Will return an empty dictionary when no packed-refs file is
-            present.
-        """
-        # TODO: move to base class after merging RefsContainer changes
-        path = os.path.join(self.controldir(), 'packed-refs')
-        if not os.path.exists(path):
-            return {}
-        ret = {}
-        f = open(path, 'rb')
-        try:
-            for entry in read_packed_refs(f):
-                ret[entry[1]] = entry[0]
-            return ret
-        finally:
-            f.close()
-
     def __repr__(self):
         return "<Repo at %r>" % self.path
 

+ 1 - 0
dulwich/tests/data/repos/refs.git/HEAD

@@ -0,0 +1 @@
+ref: refs/heads/master

BIN
dulwich/tests/data/repos/refs.git/objects/3b/9e5457140e738c2dcd39bf6d7acf88379b90d1


BIN
dulwich/tests/data/repos/refs.git/objects/42/d06bd4b77fed026b154d16493e5deab78f02ec


BIN
dulwich/tests/data/repos/refs.git/objects/a1/8114c31713746a33a2e70d9914d1ef3e781425


BIN
dulwich/tests/data/repos/refs.git/objects/df/6800012397fb85c56e7418dd4eb9405dee075c


+ 3 - 0
dulwich/tests/data/repos/refs.git/packed-refs

@@ -0,0 +1,3 @@
+# pack-refs with: peeled 
+df6800012397fb85c56e7418dd4eb9405dee075c refs/tags/refs-0.1
+^42d06bd4b77fed026b154d16493e5deab78f02ec

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/loop

@@ -0,0 +1 @@
+ref: refs/heads/loop

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/master

@@ -0,0 +1 @@
+42d06bd4b77fed026b154d16493e5deab78f02ec

+ 16 - 0
dulwich/tests/test_file.py

@@ -113,3 +113,19 @@ class GitFileTests(unittest.TestCase):
         new_orig_f = open(foo, 'rb')
         self.assertEquals(new_orig_f.read(), 'foo contents')
         new_orig_f.close()
+
+    def test_abort_close(self):
+        foo = self.path('foo')
+        f = GitFile(foo, 'wb')
+        f.abort()
+        try:
+            f.close()
+        except (IOError, OSError):
+            self.fail()
+
+        f = GitFile(foo, 'wb')
+        f.close()
+        try:
+            f.abort()
+        except (IOError, OSError):
+            self.fail()

+ 305 - 29
dulwich/tests/test_repository.py

@@ -20,75 +20,110 @@
 
 """Tests for the repository."""
 
-
+from cStringIO import StringIO
 import os
+import shutil
+import tempfile
 import unittest
 
 from dulwich import errors
-from dulwich.repo import Repo
+from dulwich.repo import (
+    check_ref_format,
+    DiskRefsContainer,
+    Repo,
+    read_packed_refs,
+    read_packed_refs_with_peeled,
+    write_packed_refs,
+    _split_ref_line,
+    )
 
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 
+
+def open_repo(name):
+    """Open a copy of a repo in a temporary directory.
+
+    Use this function for accessing repos in dulwich/tests/data/repos to avoid
+    accidentally or intentionally modifying those repos in place. Use
+    tear_down_repo to delete any temp files created.
+
+    :param name: The name of the repository, relative to
+        dulwich/tests/data/repos
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    temp_dir = tempfile.mkdtemp()
+    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
+    return Repo(temp_repo_dir)
+
+def tear_down_repo(repo):
+    """Tear down a test repository."""
+    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
+    shutil.rmtree(temp_dir)
+
+
 class RepositoryTests(unittest.TestCase):
-  
-    def open_repo(self, name):
-        return Repo(os.path.join(os.path.dirname(__file__),
-                          'data', 'repos', name))
+
+    def setUp(self):
+        self._repo = None
+
+    def tearDown(self):
+        if self._repo is not None:
+            tear_down_repo(self._repo)
   
     def test_simple_props(self):
-        r = self.open_repo('a.git')
-        basedir = os.path.join(os.path.dirname(__file__), 
-                os.path.join('data', 'repos', 'a.git'))
-        self.assertEqual(r.controldir(), basedir)
+        r = self._repo = open_repo('a.git')
+        self.assertEqual(r.controldir(), r.path)
   
     def test_ref(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
     def test_get_refs(self):
-        r = self.open_repo('a.git')
-        self.assertEquals({
+        r = self._repo = open_repo('a.git')
+        self.assertEqual({
             'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
             'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
             }, r.get_refs())
   
     def test_head(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
     def test_get_object(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
         self.assertEqual(obj._type, 'commit')
   
     def test_get_object_non_existant(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
   
     def test_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.commit(r.head())
         self.assertEqual(obj._type, 'commit')
   
     def test_commit_not_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotCommitError,
                           r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
   
     def test_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         tree = r.tree(commit.tree)
         self.assertEqual(tree._type, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
   
     def test_tree_not_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotTreeError, r.tree, r.head())
   
     def test_get_blob(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         tree = r.tree(commit.tree)
         blob_sha = tree.entries()[0][2]
@@ -97,18 +132,18 @@ class RepositoryTests(unittest.TestCase):
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
   
     def test_get_blob_notblob(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
     
     def test_linear_history(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
   
     def test_merge_history(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
@@ -118,13 +153,13 @@ class RepositoryTests(unittest.TestCase):
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
   
     def test_revision_history_missing_commit(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
   
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
-        r = self.open_repo('ooo_merge.git')
+        r = self._repo = open_repo('ooo_merge.git')
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
@@ -133,9 +168,250 @@ class RepositoryTests(unittest.TestCase):
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
   
     def test_get_tags_empty(self):
-        r = self.open_repo('ooo_merge.git')
-        self.assertEquals({}, r.refs.as_dict('refs/tags'))
+        r = self._repo = open_repo('ooo_merge.git')
+        self.assertEqual({}, r.refs.as_dict('refs/tags'))
 
     def test_get_config(self):
-        r = self.open_repo('ooo_merge.git')
+        r = self._repo = open_repo('ooo_merge.git')
         self.assertEquals({}, r.get_config())
+
+
+class CheckRefFormatTests(unittest.TestCase):
+    """Tests for the check_ref_format function.
+
+    These are the same tests as in the git test suite.
+    """
+
+    def test_valid(self):
+        self.assertTrue(check_ref_format('heads/foo'))
+        self.assertTrue(check_ref_format('foo/bar/baz'))
+        self.assertTrue(check_ref_format('refs///heads/foo'))
+        self.assertTrue(check_ref_format('foo./bar'))
+        self.assertTrue(check_ref_format('heads/foo@bar'))
+        self.assertTrue(check_ref_format('heads/fix.lock.error'))
+
+    def test_invalid(self):
+        self.assertFalse(check_ref_format('foo'))
+        self.assertFalse(check_ref_format('heads/foo/'))
+        self.assertFalse(check_ref_format('./foo'))
+        self.assertFalse(check_ref_format('.refs/foo'))
+        self.assertFalse(check_ref_format('heads/foo..bar'))
+        self.assertFalse(check_ref_format('heads/foo?bar'))
+        self.assertFalse(check_ref_format('heads/foo.lock'))
+        self.assertFalse(check_ref_format('heads/v@{ation'))
+        self.assertFalse(check_ref_format('heads/foo\bar'))
+
+
+ONES = "1" * 40
+TWOS = "2" * 40
+THREES = "3" * 40
+FOURS = "4" * 40
+
+class PackedRefsFileTests(unittest.TestCase):
+    def test_split_ref_line_errors(self):
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'singlefield')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'badsha name')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          '%s bad/../refname' % ONES)
+
+    def test_read_without_peeled(self):
+        f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
+        self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
+                         list(read_packed_refs(f)))
+
+    def test_read_without_peeled_errors(self):
+        f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+    def test_read_with_peeled(self):
+        f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
+            ONES, TWOS, THREES, FOURS))
+        self.assertEqual([
+            (ONES, 'ref/1', None),
+            (TWOS, 'ref/2', THREES),
+            (FOURS, 'ref/4', None),
+            ], list(read_packed_refs_with_peeled(f)))
+
+    def test_read_with_peeled_errors(self):
+        f = StringIO('^%s\n%s ref/1' % (TWOS, ONES))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+        f = StringIO('%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+    def test_write_with_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
+                          {'ref/1': THREES})
+        self.assertEqual(
+            "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
+            ONES, THREES, TWOS), f.getvalue())
+
+    def test_write_without_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
+        self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
+
+
+class RefsContainerTests(unittest.TestCase):
+    def setUp(self):
+        self._repo = open_repo('refs.git')
+        self._refs = self._repo.refs
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_get_packed_refs(self):
+        self.assertEqual(
+            {'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c'},
+            self._refs.get_packed_refs())
+
+    def test_keys(self):
+        self.assertEqual([
+            'HEAD',
+            'refs/heads/loop',
+            'refs/heads/master',
+            'refs/tags/refs-0.1',
+            ], sorted(list(self._refs.keys())))
+        self.assertEqual(['loop', 'master'],
+                         sorted(self._refs.keys('refs/heads')))
+        self.assertEqual(['refs-0.1'], list(self._refs.keys('refs/tags')))
+
+    def test_as_dict(self):
+        # refs/heads/loop does not show up
+        self.assertEqual({
+            'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+            }, self._refs.as_dict())
+
+    def test_setitem(self):
+        self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/some/ref'])
+        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          f.read()[:40])
+        f.close()
+
+    def test_setitem_symbolic(self):
+        ones = '1' * 40
+        self._refs['HEAD'] = ones
+        self.assertEqual(ones, self._refs['HEAD'])
+
+        # ensure HEAD was not modified
+        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
+        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
+        f.close()
+
+        # ensure the symbolic link was written through
+        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
+        self.assertEqual(ones, f.read()[:40])
+        f.close()
+
+    def test_set_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        self.assertTrue(self._refs.set_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
+        self.assertEqual(nines, self._refs['HEAD'])
+
+        # ensure symref was followed
+        self.assertEqual(nines, self._refs['refs/heads/master'])
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+    def test_add_if_new(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+
+        self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
+        self.assertEqual(nines, self._refs['refs/some/ref'])
+
+        # don't overwrite packed ref
+        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+
+    def test_check_refname(self):
+        try:
+            self._refs._check_refname('HEAD')
+        except KeyError:
+            self.fail()
+
+        try:
+            self._refs._check_refname('refs/heads/foo')
+        except KeyError:
+            self.fail()
+
+        self.assertRaises(KeyError, self._refs._check_refname, 'refs')
+        self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
+
+    def test_follow(self):
+        self.assertEquals(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('HEAD'))
+        self.assertEquals(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('refs/heads/master'))
+        self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
+        self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
+
+    def test_delitem(self):
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          self._refs['refs/heads/master'])
+        del self._refs['refs/heads/master']
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
+        self.assertFalse(os.path.exists(ref_file))
+        self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
+
+    def test_delitem_symbolic(self):
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs._read_ref_file('HEAD'))
+        del self._refs['HEAD']
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
+
+    def test_remove_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        # HEAD is a symref, so shouldn't equal its dereferenced value
+        self.assertFalse(self._refs.remove_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertTrue(self._refs.remove_if_equals(
+            'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+
+        # HEAD is now a broken symref
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs._read_ref_file('HEAD'))
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+        # test removing ref that is only packed
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+        self.assertTrue(
+            self._refs.remove_if_equals('refs/tags/refs-0.1',
+            'df6800012397fb85c56e7418dd4eb9405dee075c'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])