Browse Source

Improve ref handling.

DiskRefsContainer now handles packed-refs transparently, and
understands both the peeled and unpeeled packed-refs formats
correctly. Ref cycles are handled by giving up after reaching an
arbitrary recursion depth.

Includes tests for all new functionality.

Change-Id: I742117d7a2b99cbb52ee2a9d3a625037844c55b6
Dave Borowitz 15 years ago
parent
commit
afad608e26

+ 8 - 0
dulwich/errors.py

@@ -108,3 +108,11 @@ class HangupException(GitProtocolError):
     def __init__(self):
     def __init__(self):
         Exception.__init__(self,
         Exception.__init__(self,
             "The remote server unexpectedly closed the connection.")
             "The remote server unexpectedly closed the connection.")
+
+
+class FileFormatException(Exception):
+    """Base class for exceptions relating to reading git file formats."""
+
+
+class PackedRefsException(FileFormatException):
+    """Indicates an error parsing a packed-refs file."""

+ 13 - 0
dulwich/file.py

@@ -23,6 +23,13 @@
 import errno
 import errno
 import os
 import os
 
 
+def ensure_dir_exists(dirname):
+    """Ensure a directory exists, creating if necessary."""
+    try:
+        os.makedirs(dirname)
+    except OSError, e:
+        if e.errno != errno.EEXIST:
+            raise
 
 
 def GitFile(filename, mode='r', bufsize=-1):
 def GitFile(filename, mode='r', bufsize=-1):
     """Create a file object that obeys the git file locking protocol.
     """Create a file object that obeys the git file locking protocol.
@@ -84,6 +91,7 @@ class _GitFile(object):
         self._lockfilename = '%s.lock' % self._filename
         self._lockfilename = '%s.lock' % self._filename
         fd = os.open(self._lockfilename, os.O_RDWR | os.O_CREAT | os.O_EXCL)
         fd = os.open(self._lockfilename, os.O_RDWR | os.O_CREAT | os.O_EXCL)
         self._file = os.fdopen(fd, mode, bufsize)
         self._file = os.fdopen(fd, mode, bufsize)
+        self._closed = False
 
 
         for method in self.PROXY_METHODS:
         for method in self.PROXY_METHODS:
             setattr(self, method, getattr(self._file, method))
             setattr(self, method, getattr(self._file, method))
@@ -93,9 +101,12 @@ class _GitFile(object):
 
 
         If the file is already closed, this is a no-op.
         If the file is already closed, this is a no-op.
         """
         """
+        if self._closed:
+            return
         self._file.close()
         self._file.close()
         try:
         try:
             os.remove(self._lockfilename)
             os.remove(self._lockfilename)
+            self._closed = True
         except OSError, e:
         except OSError, e:
             # The file may have been removed already, which is ok.
             # The file may have been removed already, which is ok.
             if e.errno != errno.ENOENT:
             if e.errno != errno.ENOENT:
@@ -112,6 +123,8 @@ class _GitFile(object):
             file is still closed, so further attempts to write to the same file
             file is still closed, so further attempts to write to the same file
             object will raise ValueError.
             object will raise ValueError.
         """
         """
+        if self._closed:
+            return
         self._file.close()
         self._file.close()
         try:
         try:
             os.rename(self._lockfilename, self._filename)
             os.rename(self._lockfilename, self._filename)

+ 374 - 109
dulwich/repo.py

@@ -22,6 +22,7 @@
 """Repository access."""
 """Repository access."""
 
 
 
 
+import errno
 import os
 import os
 import stat
 import stat
 
 
@@ -31,8 +32,12 @@ from dulwich.errors import (
     NotCommitError, 
     NotCommitError, 
     NotGitRepository,
     NotGitRepository,
     NotTreeError, 
     NotTreeError, 
+    PackedRefsException,
+    )
+from dulwich.file import (
+    ensure_dir_exists,
+    GitFile,
     )
     )
-from dulwich.file import GitFile
 from dulwich.object_store import (
 from dulwich.object_store import (
     DiskObjectStore,
     DiskObjectStore,
     )
     )
@@ -42,6 +47,7 @@ from dulwich.objects import (
     ShaFile,
     ShaFile,
     Tag,
     Tag,
     Tree,
     Tree,
+    hex_to_sha,
     )
     )
 
 
 OBJECTDIR = 'objects'
 OBJECTDIR = 'objects'
@@ -52,20 +58,36 @@ REFSDIR_HEADS = 'heads'
 INDEX_FILENAME = "index"
 INDEX_FILENAME = "index"
 
 
 
 
-def follow_ref(container, name):
-    """Follow a ref back to a SHA1.
-    
-    :param container: Ref container to use for looking up refs.
-    :param name: Name of the original ref.
+def check_ref_format(refname):
+    """Check if a refname is correctly formatted.
+
+    Implements all the same rules as git-check-ref-format[1].
+
+    [1] http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html
+
+    :param refname: The refname to check
+    :return: True if refname is valid, False otherwise
     """
     """
-    contents = container[name]
-    if contents.startswith(SYMREF):
-        ref = contents[len(SYMREF):]
-        if ref[-1] == '\n':
-            ref = ref[:-1]
-        return follow_ref(container, ref)
-    assert len(contents) == 40, 'Invalid ref in %s' % name
-    return contents
+    # These could be combined into one big expression, but are listed separately
+    # to parallel [1].
+    if '/.' in refname or refname.startswith('.'):
+        return False
+    if '/' not in refname:
+        return False
+    if '..' in refname:
+        return False
+    for c in refname:
+        if ord(c) < 040 or c in '\177 ~^:?*[':
+            return False
+    if refname[-1] in '/.':
+        return False
+    if refname.endswith('.lock'):
+        return False
+    if '@{' in refname:
+        return False
+    if '\\' in refname:
+        return False
+    return True
 
 
 
 
 class RefsContainer(object):
 class RefsContainer(object):
@@ -75,13 +97,6 @@ class RefsContainer(object):
         """Return the contents of this ref container under base as a dict."""
         """Return the contents of this ref container under base as a dict."""
         raise NotImplementedError(self.as_dict)
         raise NotImplementedError(self.as_dict)
 
 
-    def follow(self, name):
-        """Follow a ref name back to a SHA1.
-        
-        :param name: Name of the ref
-        """
-        return follow_ref(self, name)
-
     def set_ref(self, name, other):
     def set_ref(self, name, other):
         """Make a ref point at another ref.
         """Make a ref point at another ref.
 
 
@@ -100,54 +115,68 @@ class DiskRefsContainer(RefsContainer):
 
 
     def __init__(self, path):
     def __init__(self, path):
         self.path = path
         self.path = path
+        self._packed_refs = None
+        self._peeled_refs = {}
 
 
     def __repr__(self):
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.path)
         return "%s(%r)" % (self.__class__.__name__, self.path)
 
 
     def keys(self, base=None):
     def keys(self, base=None):
-        """Refs present in this container."""
-        return list(self.iterkeys(base))
+        """Refs present in this container.
 
 
-    def iterkeys(self, base=None):
+        :param base: An optional base to return refs under
+        :return: An unsorted set of valid refs in this container, including
+            packed refs.
+        """
         if base is not None:
         if base is not None:
-            return self.itersubkeys(base)
+            return self.subkeys(base)
         else:
         else:
-            return self.iterallkeys()
+            return self.allkeys()
 
 
-    def itersubkeys(self, base):
+    def subkeys(self, base):
+        keys = set()
         path = self.refpath(base)
         path = self.refpath(base)
         for root, dirs, files in os.walk(path):
         for root, dirs, files in os.walk(path):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
-
-    def iterallkeys(self):
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                # check_ref_format requires at least one /, so we prepend the
+                # base before calling it.
+                if check_ref_format("%s/%s" % (base, refname)):
+                    keys.add(refname)
+        for key in self.get_packed_refs():
+            if key.startswith(base):
+                keys.add(key[len(base):].strip("/"))
+        return keys
+
+    def allkeys(self):
+        keys = set()
         if os.path.exists(self.refpath("HEAD")):
         if os.path.exists(self.refpath("HEAD")):
-            yield "HEAD"
+            keys.add("HEAD")
         path = self.refpath("")
         path = self.refpath("")
         for root, dirs, files in os.walk(self.refpath("refs")):
         for root, dirs, files in os.walk(self.refpath("refs")):
-            dir = root[len(path):].strip("/").replace(os.path.sep, "/")
+            dir = root[len(path):].strip(os.path.sep).replace(os.path.sep, "/")
             for filename in files:
             for filename in files:
-                yield ("%s/%s" % (dir, filename)).strip("/")
+                refname = ("%s/%s" % (dir, filename)).strip("/")
+                if check_ref_format(refname):
+                    keys.add(refname)
+        keys.update(self.get_packed_refs())
+        return keys
 
 
-    def as_dict(self, base=None, follow=True):
+    def as_dict(self, base=None):
         """Return the contents of this container as a dictionary.
         """Return the contents of this container as a dictionary.
 
 
         """
         """
         ret = {}
         ret = {}
+        keys = self.keys(base)
         if base is None:
         if base is None:
-            keys = self.iterkeys()
             base = ""
             base = ""
-        else:
-            keys = self.itersubkeys(base)
         for key in keys:
         for key in keys:
-                if follow:
-                    try:
-                        ret[key] = self.follow(("%s/%s" % (base, key)).strip("/"))
-                    except KeyError:
-                        continue # Unable to resolve
-                else:
-                    ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            try:
+                ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            except KeyError:
+                continue # Unable to resolve
+
         return ret
         return ret
 
 
     def refpath(self, name):
     def refpath(self, name):
@@ -158,51 +187,328 @@ class DiskRefsContainer(RefsContainer):
             name = name.replace("/", os.path.sep)
             name = name.replace("/", os.path.sep)
         return os.path.join(self.path, name)
         return os.path.join(self.path, name)
 
 
+    def get_packed_refs(self):
+        """Get contents of the packed-refs file.
+
+        :return: Dictionary mapping ref names to SHA1s
+
+        :note: Will return an empty dictionary when no packed-refs file is
+            present.
+        """
+        # TODO: invalidate the cache on repacking
+        if self._packed_refs is None:
+            self._packed_refs = {}
+            path = os.path.join(self.path, 'packed-refs')
+            try:
+                f = GitFile(path, 'rb')
+            except IOError, e:
+                if e.errno == errno.ENOENT:
+                    return {}
+                raise
+            try:
+                first_line = iter(f).next().rstrip()
+                if (first_line.startswith("# pack-refs") and " peeled" in
+                        first_line):
+                    for sha, name, peeled in read_packed_refs_with_peeled(f):
+                        self._packed_refs[name] = sha
+                        if peeled:
+                            self._peeled_refs[name] = peeled
+                else:
+                    f.seek(0)
+                    for sha, name in read_packed_refs(f):
+                        self._packed_refs[name] = sha
+            finally:
+                f.close()
+        return self._packed_refs
+
+    def _check_refname(self, name):
+        """Ensure a refname is valid and lives in refs or is HEAD.
+
+        HEAD is not a valid refname according to git-check-ref-format, but this
+        class needs to be able to touch HEAD. Also, check_ref_format expects
+        refnames without the leading 'refs/', but this class requires that
+        so it cannot touch anything outside the refs dir (or HEAD).
+
+        :param name: The name of the reference.
+        :raises KeyError: if a refname is not HEAD or is otherwise not valid.
+        """
+        if name == 'HEAD':
+            return
+        if not name.startswith('refs/') or not check_ref_format(name[5:]):
+            raise KeyError(name)
+
+    def _read_ref_file(self, name):
+        """Read a reference file and return its contents.
+
+        If the reference file is a symbolic reference, only read the first
+        line of the file. Otherwise, only read the first 40 bytes.
+
+        :param name: the refname to read, relative to refpath
+        :return: The contents of the ref file, or None if the file does not
+            exist.
+        :raises IOError: if any other error occurs
+        """
+        filename = self.refpath(name)
+        try:
+            f = GitFile(filename, 'rb')
+            try:
+                header = f.read(len(SYMREF))
+                if header == SYMREF:
+                    # Read only the first line
+                    return header + iter(f).next().rstrip("\n")
+                else:
+                    # Read only the first 40 bytes
+                    return header + f.read(40-len(SYMREF))
+            finally:
+                f.close()
+        except IOError, e:
+            if e.errno == errno.ENOENT:
+                return None
+            raise
+
+    def _follow(self, name):
+        """Follow a reference name.
+
+        :return: a tuple of (refname, sha), where refname is the name of the
+            last reference in the symbolic reference chain
+        """
+        self._check_refname(name)
+        contents = SYMREF + name
+        depth = 0
+        while contents.startswith(SYMREF):
+            refname = contents[len(SYMREF):]
+            contents = self._read_ref_file(refname)
+            if not contents:
+                contents = self.get_packed_refs().get(refname, None)
+                if not contents:
+                    break
+            depth += 1
+            if depth > 5:
+                raise KeyError(name)
+        return refname, contents
+
     def __getitem__(self, name):
     def __getitem__(self, name):
-        file = self.refpath(name)
-        if not os.path.exists(file):
+        """Get the SHA1 for a reference name.
+
+        This method follows all symbolic references.
+        """
+        _, sha = self._follow(name)
+        if sha is None:
             raise KeyError(name)
             raise KeyError(name)
-        f = GitFile(file, 'rb')
+        return sha
+
+    def _remove_packed_ref(self, name):
+        if self._packed_refs is None:
+            return
+        filename = os.path.join(self.path, 'packed-refs')
+        # reread cached refs from disk, while holding the lock
+        f = GitFile(filename, 'wb')
+        try:
+            self._packed_refs = None
+            self.get_packed_refs()
+
+            if name not in self._packed_refs:
+                return
+
+            del self._packed_refs[name]
+            if name in self._peeled_refs:
+                del self._peeled_refs[name]
+            write_packed_refs(f, self._packed_refs, self._peeled_refs)
+            f.close()
+        finally:
+            f.abort()
+
+    def set_if_equals(self, name, old_ref, new_ref):
+        """Set a refname to new_ref only if it currently equals old_ref.
+
+        This method follows all symbolic references, and can be used to perform
+        an atomic compare-and-swap operation.
+
+        :param name: The refname to set.
+        :param old_ref: The old sha the refname must refer to, or None to set
+            unconditionally.
+        :param new_ref: The new sha the refname will refer to.
+        :return: True if the set was successful, False otherwise.
+        """
+        try:
+            realname, _ = self._follow(name)
+        except KeyError:
+            realname = name
+        filename = self.refpath(realname)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
         try:
         try:
-            return f.read().strip("\n")
+            if old_ref is not None:
+                try:
+                    # read again while holding the lock
+                    orig_ref = self._read_ref_file(realname)
+                    if orig_ref is None:
+                        orig_ref = self.get_packed_refs().get(realname, None)
+                    if orig_ref != old_ref:
+                        f.abort()
+                        return False
+                except (OSError, IOError):
+                    f.abort()
+                    raise
+            try:
+                f.write(new_ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
         finally:
         finally:
             f.close()
             f.close()
+        return True
+
+    def add_if_new(self, name, ref):
+        """Add a new reference only if it does not already exist."""
+        self._check_refname(name)
+        filename = self.refpath(name)
+        ensure_dir_exists(os.path.dirname(filename))
+        f = GitFile(filename, 'wb')
+        try:
+            if os.path.exists(filename) or name in self.get_packed_refs():
+                f.abort()
+                return False
+            try:
+                f.write(ref+"\n")
+            except (OSError, IOError):
+                f.abort()
+                raise
+        finally:
+            f.close()
+        return True
 
 
     def __setitem__(self, name, ref):
     def __setitem__(self, name, ref):
-        file = self.refpath(name)
-        dirpath = os.path.dirname(file)
-        if not os.path.exists(dirpath):
-            os.makedirs(dirpath)
-        f = GitFile(file, 'wb')
+        """Set a reference name to point to the given SHA1.
+
+        This method follows all symbolic references.
+
+        :note: This method unconditionally overwrites the contents of a reference
+            on disk. To update atomically only if the reference has not changed
+            on disk, use set_if_equals().
+        """
+        self.set_if_equals(name, None, ref)
+
+    def remove_if_equals(self, name, old_ref):
+        """Remove a refname only if it currently equals old_ref.
+
+        This method does not follow symbolic references. It can be used to
+        perform an atomic compare-and-delete operation.
+
+        :param name: The refname to delete.
+        :param old_ref: The old sha the refname must refer to, or None to delete
+            unconditionally.
+        :return: True if the delete was successful, False otherwise.
+        """
+        self._check_refname(name)
+        filename = self.refpath(name)
+        f = GitFile(filename, 'wb')
         try:
         try:
-            f.write(ref+"\n")
+            if old_ref is not None:
+                orig_ref = self._read_ref_file(name)
+                if orig_ref is None:
+                    orig_ref = self.get_packed_refs().get(name, None)
+                if orig_ref != old_ref:
+                    return False
+            # may only be packed
+            if os.path.exists(filename):
+                os.remove(filename)
+            self._remove_packed_ref(name)
         finally:
         finally:
-            f.close()
+            # never write, we just wanted the lock
+            f.abort()
+        return True
 
 
     def __delitem__(self, name):
     def __delitem__(self, name):
-        file = self.refpath(name)
-        if os.path.exists(file):
-            os.remove(file)
+        """Remove a refname.
+
+        This method does not follow symbolic references.
+        :note: This method unconditionally deletes the contents of a reference
+            on disk. To delete atomically only if the reference has not changed
+            on disk, use remove_if_equals().
+        """
+        self.remove_if_equals(name, None)
+
+
+def _split_ref_line(line):
+    """Split a single ref line into a tuple of SHA1 and name."""
+    fields = line.rstrip("\n").split(" ")
+    if len(fields) != 2:
+        raise PackedRefsException("invalid ref line '%s'" % line)
+    sha, name = fields
+    try:
+        hex_to_sha(sha)
+    except (AssertionError, TypeError), e:
+        raise PackedRefsException(e)
+    if not check_ref_format(name):
+        raise PackedRefsException("invalid ref name '%s'" % name)
+    return (sha, name)
 
 
 
 
 def read_packed_refs(f):
 def read_packed_refs(f):
     """Read a packed refs file.
     """Read a packed refs file.
 
 
-    Yields tuples with ref names and SHA1s.
+    Yields tuples with SHA1s and ref names.
 
 
     :param f: file-like object to read from
     :param f: file-like object to read from
     """
     """
-    l = f.readline()
-    for l in f.readlines():
+    for l in f:
         if l[0] == "#":
         if l[0] == "#":
             # Comment
             # Comment
             continue
             continue
         if l[0] == "^":
         if l[0] == "^":
-            # FIXME: Return somehow
+            raise PackedRefsException(
+                "found peeled ref in packed-refs without peeled")
+        yield _split_ref_line(l)
+
+
+def read_packed_refs_with_peeled(f):
+    """Read a packed refs file including peeled refs.
+
+    Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
+    with SHA1s, ref names, and peeled SHA1s (or None).
+
+    :param f: file-like object to read from, seek'ed to the second line
+    """
+    last = None
+    for l in f:
+        if l[0] == "#":
             continue
             continue
-        yield tuple(l.rstrip("\n").split(" ", 2))
+        l = l.rstrip("\n")
+        if l[0] == "^":
+            if not last:
+                raise PackedRefsException("unexpected peeled ref line")
+            try:
+                hex_to_sha(l[1:])
+            except (AssertionError, TypeError), e:
+                raise PackedRefsException(e)
+            sha, name = _split_ref_line(last)
+            last = None
+            yield (sha, name, l[1:])
+        else:
+            if last:
+                sha, name = _split_ref_line(last)
+                yield (sha, name, None)
+            last = l
+    if last:
+        sha, name = _split_ref_line(last)
+        yield (sha, name, None)
 
 
 
 
+def write_packed_refs(f, packed_refs, peeled_refs=None):
+    """Write a packed refs file.
+
+    :param f: empty file-like object to write to
+    :param packed_refs: dict of refname to sha of packed refs to write
+    """
+    if peeled_refs is None:
+        peeled_refs = {}
+    else:
+        f.write('# pack-refs with: peeled\n')
+    for refname in sorted(packed_refs.iterkeys()):
+        f.write('%s %s\n' % (packed_refs[refname], refname))
+        if refname in peeled_refs:
+            f.write('^%s\n' % peeled_refs[refname])
 
 
 class BaseRepo(object):
 class BaseRepo(object):
     """Base class for a git repository.
     """Base class for a git repository.
@@ -255,15 +561,15 @@ class BaseRepo(object):
 
 
     def ref(self, name):
     def ref(self, name):
         """Return the SHA1 a ref is pointing to."""
         """Return the SHA1 a ref is pointing to."""
-        raise NotImplementedError(self.refs)
+        return self.refs[name]
 
 
     def get_refs(self):
     def get_refs(self):
         """Get dictionary with all refs."""
         """Get dictionary with all refs."""
-        raise NotImplementedError(self.get_refs)
+        return self.refs.as_dict()
 
 
     def head(self):
     def head(self):
         """Return the SHA1 pointed at by HEAD."""
         """Return the SHA1 pointed at by HEAD."""
-        return self.refs.follow('HEAD')
+        return self.refs['HEAD']
 
 
     def _get_object(self, sha, cls):
     def _get_object(self, sha, cls):
         assert len(sha) in (20, 40)
         assert len(sha) in (20, 40)
@@ -339,7 +645,7 @@ class BaseRepo(object):
         if len(name) in (20, 40):
         if len(name) in (20, 40):
             return self.object_store[name]
             return self.object_store[name]
         return self.object_store[self.refs[name]]
         return self.object_store[self.refs[name]]
-    
+
     def __setitem__(self, name, value):
     def __setitem__(self, name, value):
         if name.startswith("refs/") or name == "HEAD":
         if name.startswith("refs/") or name == "HEAD":
             if isinstance(value, ShaFile):
             if isinstance(value, ShaFile):
@@ -392,47 +698,6 @@ class Repo(BaseRepo):
         """Check if an index is present."""
         """Check if an index is present."""
         return os.path.exists(self.index_path())
         return os.path.exists(self.index_path())
 
 
-    def ref(self, name):
-        """Return the SHA1 a ref is pointing to."""
-        try:
-            return self.refs.follow(name)
-        except KeyError:
-            return self.get_packed_refs()[name]
-
-    def get_refs(self):
-        """Get dictionary with all refs."""
-        # TODO: move to base class after merging RefsContainer changes
-        ret = {}
-        try:
-            if self.head():
-                ret['HEAD'] = self.head()
-        except KeyError:
-            pass
-        ret.update(self.refs.as_dict())
-        ret.update(self.get_packed_refs())
-        return ret
-
-    def get_packed_refs(self):
-        """Get contents of the packed-refs file.
-
-        :return: Dictionary mapping ref names to SHA1s
-
-        :note: Will return an empty dictionary when no packed-refs file is
-            present.
-        """
-        # TODO: move to base class after merging RefsContainer changes
-        path = os.path.join(self.controldir(), 'packed-refs')
-        if not os.path.exists(path):
-            return {}
-        ret = {}
-        f = open(path, 'rb')
-        try:
-            for entry in read_packed_refs(f):
-                ret[entry[1]] = entry[0]
-            return ret
-        finally:
-            f.close()
-
     def __repr__(self):
     def __repr__(self):
         return "<Repo at %r>" % self.path
         return "<Repo at %r>" % self.path
 
 

+ 1 - 0
dulwich/tests/data/repos/refs.git/HEAD

@@ -0,0 +1 @@
+ref: refs/heads/master

BIN
dulwich/tests/data/repos/refs.git/objects/3b/9e5457140e738c2dcd39bf6d7acf88379b90d1


BIN
dulwich/tests/data/repos/refs.git/objects/42/d06bd4b77fed026b154d16493e5deab78f02ec


BIN
dulwich/tests/data/repos/refs.git/objects/a1/8114c31713746a33a2e70d9914d1ef3e781425


BIN
dulwich/tests/data/repos/refs.git/objects/df/6800012397fb85c56e7418dd4eb9405dee075c


+ 3 - 0
dulwich/tests/data/repos/refs.git/packed-refs

@@ -0,0 +1,3 @@
+# pack-refs with: peeled 
+df6800012397fb85c56e7418dd4eb9405dee075c refs/tags/refs-0.1
+^42d06bd4b77fed026b154d16493e5deab78f02ec

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/loop

@@ -0,0 +1 @@
+ref: refs/heads/loop

+ 1 - 0
dulwich/tests/data/repos/refs.git/refs/heads/master

@@ -0,0 +1 @@
+42d06bd4b77fed026b154d16493e5deab78f02ec

+ 16 - 0
dulwich/tests/test_file.py

@@ -113,3 +113,19 @@ class GitFileTests(unittest.TestCase):
         new_orig_f = open(foo, 'rb')
         new_orig_f = open(foo, 'rb')
         self.assertEquals(new_orig_f.read(), 'foo contents')
         self.assertEquals(new_orig_f.read(), 'foo contents')
         new_orig_f.close()
         new_orig_f.close()
+
+    def test_abort_close(self):
+        foo = self.path('foo')
+        f = GitFile(foo, 'wb')
+        f.abort()
+        try:
+            f.close()
+        except (IOError, OSError):
+            self.fail()
+
+        f = GitFile(foo, 'wb')
+        f.close()
+        try:
+            f.abort()
+        except (IOError, OSError):
+            self.fail()

+ 305 - 29
dulwich/tests/test_repository.py

@@ -20,75 +20,110 @@
 
 
 """Tests for the repository."""
 """Tests for the repository."""
 
 
-
+from cStringIO import StringIO
 import os
 import os
+import shutil
+import tempfile
 import unittest
 import unittest
 
 
 from dulwich import errors
 from dulwich import errors
-from dulwich.repo import Repo
+from dulwich.repo import (
+    check_ref_format,
+    DiskRefsContainer,
+    Repo,
+    read_packed_refs,
+    read_packed_refs_with_peeled,
+    write_packed_refs,
+    _split_ref_line,
+    )
 
 
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 missing_sha = 'b91fa4d900e17e99b433218e988c4eb4a3e9a097'
 
 
+
+def open_repo(name):
+    """Open a copy of a repo in a temporary directory.
+
+    Use this function for accessing repos in dulwich/tests/data/repos to avoid
+    accidentally or intentionally modifying those repos in place. Use
+    tear_down_repo to delete any temp files created.
+
+    :param name: The name of the repository, relative to
+        dulwich/tests/data/repos
+    :returns: An initialized Repo object that lives in a temporary directory.
+    """
+    temp_dir = tempfile.mkdtemp()
+    repo_dir = os.path.join(os.path.dirname(__file__), 'data', 'repos', name)
+    temp_repo_dir = os.path.join(temp_dir, name)
+    shutil.copytree(repo_dir, temp_repo_dir, symlinks=True)
+    return Repo(temp_repo_dir)
+
+def tear_down_repo(repo):
+    """Tear down a test repository."""
+    temp_dir = os.path.dirname(repo.path.rstrip(os.sep))
+    shutil.rmtree(temp_dir)
+
+
 class RepositoryTests(unittest.TestCase):
 class RepositoryTests(unittest.TestCase):
-  
-    def open_repo(self, name):
-        return Repo(os.path.join(os.path.dirname(__file__),
-                          'data', 'repos', name))
+
+    def setUp(self):
+        self._repo = None
+
+    def tearDown(self):
+        if self._repo is not None:
+            tear_down_repo(self._repo)
   
   
     def test_simple_props(self):
     def test_simple_props(self):
-        r = self.open_repo('a.git')
-        basedir = os.path.join(os.path.dirname(__file__), 
-                os.path.join('data', 'repos', 'a.git'))
-        self.assertEqual(r.controldir(), basedir)
+        r = self._repo = open_repo('a.git')
+        self.assertEqual(r.controldir(), r.path)
   
   
     def test_ref(self):
     def test_ref(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.ref('refs/heads/master'),
         self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
   
     def test_get_refs(self):
     def test_get_refs(self):
-        r = self.open_repo('a.git')
-        self.assertEquals({
+        r = self._repo = open_repo('a.git')
+        self.assertEqual({
             'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
             'HEAD': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097', 
             'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
             'refs/heads/master': 'a90fa2d900a17e99b433217e988c4eb4a2e9a097'
             }, r.get_refs())
             }, r.get_refs())
   
   
     def test_head(self):
     def test_head(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
         self.assertEqual(r.head(), 'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
   
     def test_get_object(self):
     def test_get_object(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.get_object(r.head())
         obj = r.get_object(r.head())
         self.assertEqual(obj._type, 'commit')
         self.assertEqual(obj._type, 'commit')
   
   
     def test_get_object_non_existant(self):
     def test_get_object_non_existant(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(KeyError, r.get_object, missing_sha)
         self.assertRaises(KeyError, r.get_object, missing_sha)
   
   
     def test_commit(self):
     def test_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         obj = r.commit(r.head())
         obj = r.commit(r.head())
         self.assertEqual(obj._type, 'commit')
         self.assertEqual(obj._type, 'commit')
   
   
     def test_commit_not_commit(self):
     def test_commit_not_commit(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotCommitError,
         self.assertRaises(errors.NotCommitError,
                           r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
                           r.commit, '4f2e6529203aa6d44b5af6e3292c837ceda003f9')
   
   
     def test_tree(self):
     def test_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         commit = r.commit(r.head())
         tree = r.tree(commit.tree)
         tree = r.tree(commit.tree)
         self.assertEqual(tree._type, 'tree')
         self.assertEqual(tree._type, 'tree')
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
         self.assertEqual(tree.sha().hexdigest(), commit.tree)
   
   
     def test_tree_not_tree(self):
     def test_tree_not_tree(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotTreeError, r.tree, r.head())
         self.assertRaises(errors.NotTreeError, r.tree, r.head())
   
   
     def test_get_blob(self):
     def test_get_blob(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         commit = r.commit(r.head())
         commit = r.commit(r.head())
         tree = r.tree(commit.tree)
         tree = r.tree(commit.tree)
         blob_sha = tree.entries()[0][2]
         blob_sha = tree.entries()[0][2]
@@ -97,18 +132,18 @@ class RepositoryTests(unittest.TestCase):
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
         self.assertEqual(blob.sha().hexdigest(), blob_sha)
   
   
     def test_get_blob_notblob(self):
     def test_get_blob_notblob(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
         self.assertRaises(errors.NotBlobError, r.get_blob, r.head())
     
     
     def test_linear_history(self):
     def test_linear_history(self):
-        r = self.open_repo('a.git')
+        r = self._repo = open_repo('a.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, [r.head(),
         self.assertEqual(shas, [r.head(),
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
                                 '2a72d929692c41d8554c07f6301757ba18a65d91'])
   
   
     def test_merge_history(self):
     def test_merge_history(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
         self.assertEqual(shas, ['5dac377bdded4c9aeb8dff595f0faeebcc8498cc',
@@ -118,13 +153,13 @@ class RepositoryTests(unittest.TestCase):
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
                                 '0d89f20333fbb1d2f3a94da77f4981373d8f4310'])
   
   
     def test_revision_history_missing_commit(self):
     def test_revision_history_missing_commit(self):
-        r = self.open_repo('simple_merge.git')
+        r = self._repo = open_repo('simple_merge.git')
         self.assertRaises(errors.MissingCommitError, r.revision_history,
         self.assertRaises(errors.MissingCommitError, r.revision_history,
                           missing_sha)
                           missing_sha)
   
   
     def test_out_of_order_merge(self):
     def test_out_of_order_merge(self):
         """Test that revision history is ordered by date, not parent order."""
         """Test that revision history is ordered by date, not parent order."""
-        r = self.open_repo('ooo_merge.git')
+        r = self._repo = open_repo('ooo_merge.git')
         history = r.revision_history(r.head())
         history = r.revision_history(r.head())
         shas = [c.sha().hexdigest() for c in history]
         shas = [c.sha().hexdigest() for c in history]
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
         self.assertEqual(shas, ['7601d7f6231db6a57f7bbb79ee52e4d462fd44d1',
@@ -133,9 +168,250 @@ class RepositoryTests(unittest.TestCase):
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
                                 'f9e39b120c68182a4ba35349f832d0e4e61f485c'])
   
   
     def test_get_tags_empty(self):
     def test_get_tags_empty(self):
-        r = self.open_repo('ooo_merge.git')
-        self.assertEquals({}, r.refs.as_dict('refs/tags'))
+        r = self._repo = open_repo('ooo_merge.git')
+        self.assertEqual({}, r.refs.as_dict('refs/tags'))
 
 
     def test_get_config(self):
     def test_get_config(self):
-        r = self.open_repo('ooo_merge.git')
+        r = self._repo = open_repo('ooo_merge.git')
         self.assertEquals({}, r.get_config())
         self.assertEquals({}, r.get_config())
+
+
+class CheckRefFormatTests(unittest.TestCase):
+    """Tests for the check_ref_format function.
+
+    These are the same tests as in the git test suite.
+    """
+
+    def test_valid(self):
+        self.assertTrue(check_ref_format('heads/foo'))
+        self.assertTrue(check_ref_format('foo/bar/baz'))
+        self.assertTrue(check_ref_format('refs///heads/foo'))
+        self.assertTrue(check_ref_format('foo./bar'))
+        self.assertTrue(check_ref_format('heads/foo@bar'))
+        self.assertTrue(check_ref_format('heads/fix.lock.error'))
+
+    def test_invalid(self):
+        self.assertFalse(check_ref_format('foo'))
+        self.assertFalse(check_ref_format('heads/foo/'))
+        self.assertFalse(check_ref_format('./foo'))
+        self.assertFalse(check_ref_format('.refs/foo'))
+        self.assertFalse(check_ref_format('heads/foo..bar'))
+        self.assertFalse(check_ref_format('heads/foo?bar'))
+        self.assertFalse(check_ref_format('heads/foo.lock'))
+        self.assertFalse(check_ref_format('heads/v@{ation'))
+        self.assertFalse(check_ref_format('heads/foo\\bar'))
+
+
+ONES = "1" * 40
+TWOS = "2" * 40
+THREES = "3" * 40
+FOURS = "4" * 40
+
+class PackedRefsFileTests(unittest.TestCase):
+    def test_split_ref_line_errors(self):
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'singlefield')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          'badsha name')
+        self.assertRaises(errors.PackedRefsException, _split_ref_line,
+                          '%s bad/../refname' % ONES)
+
+    def test_read_without_peeled(self):
+        f = StringIO('# comment\n%s ref/1\n%s ref/2' % (ONES, TWOS))
+        self.assertEqual([(ONES, 'ref/1'), (TWOS, 'ref/2')],
+                         list(read_packed_refs(f)))
+
+    def test_read_without_peeled_errors(self):
+        f = StringIO('%s ref/1\n^%s' % (ONES, TWOS))
+        self.assertRaises(errors.PackedRefsException, list, read_packed_refs(f))
+
+    def test_read_with_peeled(self):
+        f = StringIO('%s ref/1\n%s ref/2\n^%s\n%s ref/4' % (
+            ONES, TWOS, THREES, FOURS))
+        self.assertEqual([
+            (ONES, 'ref/1', None),
+            (TWOS, 'ref/2', THREES),
+            (FOURS, 'ref/4', None),
+            ], list(read_packed_refs_with_peeled(f)))
+
+    def test_read_with_peeled_errors(self):
+        for data in ('^%s\n%s ref/1' % (TWOS, ONES),
+                     '%s ref/1\n^%s\n^%s' % (ONES, TWOS, THREES)):
+            f = StringIO(data)
+            self.assertRaises(errors.PackedRefsException, list,
+                              read_packed_refs_with_peeled(f))
+
+    def test_write_with_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS},
+                          {'ref/1': THREES})
+        self.assertEqual(
+            "# pack-refs with: peeled\n%s ref/1\n^%s\n%s ref/2\n" % (
+            ONES, THREES, TWOS), f.getvalue())
+
+    def test_write_without_peeled(self):
+        f = StringIO()
+        write_packed_refs(f, {'ref/1': ONES, 'ref/2': TWOS})
+        self.assertEqual("%s ref/1\n%s ref/2\n" % (ONES, TWOS), f.getvalue())
+
+
+class RefsContainerTests(unittest.TestCase):
+    def setUp(self):
+        self._repo = open_repo('refs.git')
+        self._refs = self._repo.refs
+
+    def tearDown(self):
+        tear_down_repo(self._repo)
+
+    def test_get_packed_refs(self):
+        self.assertEqual(
+            {'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c'},
+            self._refs.get_packed_refs())
+
+    def test_keys(self):
+        self.assertEqual([
+            'HEAD',
+            'refs/heads/loop',
+            'refs/heads/master',
+            'refs/tags/refs-0.1',
+            ], sorted(list(self._refs.keys())))
+        self.assertEqual(['loop', 'master'],
+                         sorted(self._refs.keys('refs/heads')))
+        self.assertEqual(['refs-0.1'], list(self._refs.keys('refs/tags')))
+
+    def test_as_dict(self):
+        # refs/heads/loop does not show up
+        self.assertEqual({
+            'HEAD': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/heads/master': '42d06bd4b77fed026b154d16493e5deab78f02ec',
+            'refs/tags/refs-0.1': 'df6800012397fb85c56e7418dd4eb9405dee075c',
+            }, self._refs.as_dict())
+
+    def test_setitem(self):
+        self._refs['refs/some/ref'] = '42d06bd4b77fed026b154d16493e5deab78f02ec'
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/some/ref'])
+        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          f.read()[:40])
+        f.close()
+
+    def test_setitem_symbolic(self):
+        ones = '1' * 40
+        self._refs['HEAD'] = ones
+        self.assertEqual(ones, self._refs['HEAD'])
+
+        # ensure HEAD was not modified
+        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
+        self.assertEqual('ref: refs/heads/master', iter(f).next().rstrip('\n'))
+        f.close()
+
+        # ensure the symbolic link was written through
+        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'), 'rb')
+        self.assertEqual(ones, f.read()[:40])
+        f.close()
+
+    def test_set_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.set_if_equals('HEAD', 'c0ffee', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        self.assertTrue(self._refs.set_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec', nines))
+        self.assertEqual(nines, self._refs['HEAD'])
+
+        # ensure symref was followed
+        self.assertEqual(nines, self._refs['refs/heads/master'])
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+    def test_add_if_new(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.add_if_new('refs/heads/master', nines))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+
+        self.assertTrue(self._refs.add_if_new('refs/some/ref', nines))
+        self.assertEqual(nines, self._refs['refs/some/ref'])
+
+        # don't overwrite packed ref
+        self.assertFalse(self._refs.add_if_new('refs/tags/refs-0.1', nines))
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+
+    def test_check_refname(self):
+        try:
+            self._refs._check_refname('HEAD')
+        except KeyError:
+            self.fail()
+
+        try:
+            self._refs._check_refname('refs/heads/foo')
+        except KeyError:
+            self.fail()
+
+        self.assertRaises(KeyError, self._refs._check_refname, 'refs')
+        self.assertRaises(KeyError, self._refs._check_refname, 'notrefs/foo')
+
+    def test_follow(self):
+        self.assertEqual(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('HEAD'))
+        self.assertEqual(
+            ('refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'),
+            self._refs._follow('refs/heads/master'))
+        self.assertRaises(KeyError, self._refs._follow, 'notrefs/foo')
+        self.assertRaises(KeyError, self._refs._follow, 'refs/heads/loop')
+
+    def test_delitem(self):
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                          self._refs['refs/heads/master'])
+        del self._refs['refs/heads/master']
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
+        self.assertFalse(os.path.exists(ref_file))
+        self.assertFalse('refs/heads/master' in self._refs.get_packed_refs())
+
+    def test_delitem_symbolic(self):
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs._read_ref_file('HEAD'))
+        del self._refs['HEAD']
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['refs/heads/master'])
+        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
+
+    def test_remove_if_equals(self):
+        nines = '9' * 40
+        self.assertFalse(self._refs.remove_if_equals('HEAD', 'c0ffee'))
+        self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
+                         self._refs['HEAD'])
+
+        # HEAD is a symref, so shouldn't equal its dereferenced value
+        self.assertFalse(self._refs.remove_if_equals(
+            'HEAD', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertTrue(self._refs.remove_if_equals(
+            'refs/heads/master', '42d06bd4b77fed026b154d16493e5deab78f02ec'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/heads/master'])
+
+        # HEAD is now a broken symref
+        self.assertRaises(KeyError, lambda: self._refs['HEAD'])
+        self.assertEqual('ref: refs/heads/master',
+                          self._refs._read_ref_file('HEAD'))
+
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+        self.assertFalse(os.path.exists(
+            os.path.join(self._refs.path, 'HEAD.lock')))
+
+        # test removing ref that is only packed
+        self.assertEqual('df6800012397fb85c56e7418dd4eb9405dee075c',
+                         self._refs['refs/tags/refs-0.1'])
+        self.assertTrue(
+            self._refs.remove_if_equals('refs/tags/refs-0.1',
+            'df6800012397fb85c56e7418dd4eb9405dee075c'))
+        self.assertRaises(KeyError, lambda: self._refs['refs/tags/refs-0.1'])