Kaynağa Gözat

Move some functions back from DiskRefsContainer to RefsContainer.

Jelmer Vernooij 15 yıl önce
ebeveyn
işleme
4eedd56dd1
2 değiştirilmiş dosya ile 100 ekleme ve 86 silme
  1. 98 83
      dulwich/repo.py
  2. 2 3
      dulwich/tests/test_repository.py

+ 98 - 83
dulwich/repo.py

@@ -105,33 +105,27 @@ def check_ref_format(refname):
 class RefsContainer(object):
     """A container for refs."""
 
-    def as_dict(self, base):
-        """Return the contents of this ref container under base as a dict."""
-        raise NotImplementedError(self.as_dict)
-
     def set_ref(self, name, other):
         """Make a ref point at another ref.
 
         :param name: Name of the ref to set
         :param other: Name of the ref to point at
         """
-        self[name] = "ref: %s\n" % other
-
-    def import_refs(self, base, other):
-        for name, value in other.iteritems():
-            self["%s/%s" % (base, name)] = value
+        self[name] = SYMREF + other + '\n'
 
+    def get_packed_refs(self):
+        """Get contents of the packed-refs file.
 
-class DiskRefsContainer(RefsContainer):
-    """Refs container that reads refs from disk."""
+        :return: Dictionary mapping ref names to SHA1s
 
-    def __init__(self, path):
-        self.path = path
-        self._packed_refs = None
-        self._peeled_refs = {}
+        :note: Will return an empty dictionary when no packed-refs file is
+            present.
+        """
+        raise NotImplementedError(self.get_packed_refs)
 
-    def __repr__(self):
-        return "%s(%r)" % (self.__class__.__name__, self.path)
+    def import_refs(self, base, other):
+        for name, value in other.iteritems():
+            self["%s/%s" % (base, name)] = value
 
     def keys(self, base=None):
         """Refs present in this container.
@@ -145,6 +139,90 @@ class DiskRefsContainer(RefsContainer):
         else:
             return self.allkeys()
 
+    def as_dict(self, base=None):
+        """Return the contents of this container as a dictionary.
+
+        """
+        ret = {}
+        keys = self.keys(base)
+        if base is None:
+            base = ""
+        for key in keys:
+            try:
+                ret[key] = self[("%s/%s" % (base, key)).strip("/")]
+            except KeyError:
+                continue # Unable to resolve
+
+        return ret
+
+    def _check_refname(self, name):
+        """Ensure a refname is valid and lives in refs or is HEAD.
+
+        HEAD is not a valid refname according to git-check-ref-format, but this
+        class needs to be able to touch HEAD. Also, check_ref_format expects
+        refnames without the leading 'refs/', but this class requires that
+        so it cannot touch anything outside the refs dir (or HEAD).
+
+        :param name: The name of the reference.
+        :raises KeyError: if a refname is not HEAD or is otherwise not valid.
+        """
+        if name == 'HEAD':
+            return
+        if not name.startswith('refs/') or not check_ref_format(name[5:]):
+            raise KeyError(name)
+
+    def read_loose_ref(self, name):
+        """Read a loose reference and return its contents.
+
+        :param name: the refname to read
+        :return: The contents of the ref file, or None if it does 
+            not exist.
+        """
+        raise NotImplementedError(self.read_loose_ref)
+
+    def _follow(self, name):
+        """Follow a reference name.
+
+        :return: a tuple of (refname, sha), where refname is the name of the
+            last reference in the symbolic reference chain
+        """
+        self._check_refname(name)
+        contents = SYMREF + name
+        depth = 0
+        while contents.startswith(SYMREF):
+            refname = contents[len(SYMREF):]
+            contents = self.read_loose_ref(refname)
+            if not contents:
+                contents = self.get_packed_refs().get(refname, None)
+                if not contents:
+                    break
+            depth += 1
+            if depth > 5:
+                raise KeyError(name)
+        return refname, contents
+
+    def __getitem__(self, name):
+        """Get the SHA1 for a reference name.
+
+        This method follows all symbolic references.
+        """
+        _, sha = self._follow(name)
+        if sha is None:
+            raise KeyError(name)
+        return sha
+
+
+class DiskRefsContainer(RefsContainer):
+    """Refs container that reads refs from disk."""
+
+    def __init__(self, path):
+        self.path = path
+        self._packed_refs = None
+        self._peeled_refs = {}
+
+    def __repr__(self):
+        return "%s(%r)" % (self.__class__.__name__, self.path)
+
     def subkeys(self, base):
         keys = set()
         path = self.refpath(base)
@@ -175,22 +253,6 @@ class DiskRefsContainer(RefsContainer):
         keys.update(self.get_packed_refs())
         return keys
 
-    def as_dict(self, base=None):
-        """Return the contents of this container as a dictionary.
-
-        """
-        ret = {}
-        keys = self.keys(base)
-        if base is None:
-            base = ""
-        for key in keys:
-            try:
-                ret[key] = self[("%s/%s" % (base, key)).strip("/")]
-            except KeyError:
-                continue # Unable to resolve
-
-        return ret
-
     def refpath(self, name):
         """Return the disk path of a ref.
 
@@ -233,23 +295,7 @@ class DiskRefsContainer(RefsContainer):
                 f.close()
         return self._packed_refs
 
-    def _check_refname(self, name):
-        """Ensure a refname is valid and lives in refs or is HEAD.
-
-        HEAD is not a valid refname according to git-check-ref-format, but this
-        class needs to be able to touch HEAD. Also, check_ref_format expects
-        refnames without the leading 'refs/', but this class requires that
-        so it cannot touch anything outside the refs dir (or HEAD).
-
-        :param name: The name of the reference.
-        :raises KeyError: if a refname is not HEAD or is otherwise not valid.
-        """
-        if name == 'HEAD':
-            return
-        if not name.startswith('refs/') or not check_ref_format(name[5:]):
-            raise KeyError(name)
-
-    def _read_ref_file(self, name):
+    def read_loose_ref(self, name):
         """Read a reference file and return its contents.
 
         If the reference file a symbolic reference, only read the first line of
@@ -278,37 +324,6 @@ class DiskRefsContainer(RefsContainer):
                 return None
             raise
 
-    def _follow(self, name):
-        """Follow a reference name.
-
-        :return: a tuple of (refname, sha), where refname is the name of the
-            last reference in the symbolic reference chain
-        """
-        self._check_refname(name)
-        contents = SYMREF + name
-        depth = 0
-        while contents.startswith(SYMREF):
-            refname = contents[len(SYMREF):]
-            contents = self._read_ref_file(refname)
-            if not contents:
-                contents = self.get_packed_refs().get(refname, None)
-                if not contents:
-                    break
-            depth += 1
-            if depth > 5:
-                raise KeyError(name)
-        return refname, contents
-
-    def __getitem__(self, name):
-        """Get the SHA1 for a reference name.
-
-        This method follows all symbolic references.
-        """
-        _, sha = self._follow(name)
-        if sha is None:
-            raise KeyError(name)
-        return sha
-
     def _remove_packed_ref(self, name):
         if self._packed_refs is None:
             return
@@ -353,7 +368,7 @@ class DiskRefsContainer(RefsContainer):
             if old_ref is not None:
                 try:
                     # read again while holding the lock
-                    orig_ref = self._read_ref_file(realname)
+                    orig_ref = self.read_loose_ref(realname)
                     if orig_ref is None:
                         orig_ref = self.get_packed_refs().get(realname, None)
                     if orig_ref != old_ref:
@@ -418,7 +433,7 @@ class DiskRefsContainer(RefsContainer):
         f = GitFile(filename, 'wb')
         try:
             if old_ref is not None:
-                orig_ref = self._read_ref_file(name)
+                orig_ref = self.read_loose_ref(name)
                 if orig_ref is None:
                     orig_ref = self.get_packed_refs().get(name, None)
                 if orig_ref != old_ref:

+ 2 - 3
dulwich/tests/test_repository.py

@@ -29,7 +29,6 @@ import unittest
 from dulwich import errors
 from dulwich.repo import (
     check_ref_format,
-    DiskRefsContainer,
     Repo,
     read_packed_refs,
     read_packed_refs_with_peeled,
@@ -378,7 +377,7 @@ class RefsContainerTests(unittest.TestCase):
 
     def test_delitem_symbolic(self):
         self.assertEqual('ref: refs/heads/master',
-                          self._refs._read_ref_file('HEAD'))
+                          self._refs.read_loose_ref('HEAD'))
         del self._refs['HEAD']
         self.assertRaises(KeyError, lambda: self._refs['HEAD'])
         self.assertEqual('42d06bd4b77fed026b154d16493e5deab78f02ec',
@@ -401,7 +400,7 @@ class RefsContainerTests(unittest.TestCase):
         # HEAD is now a broken symref
         self.assertRaises(KeyError, lambda: self._refs['HEAD'])
         self.assertEqual('ref: refs/heads/master',
-                          self._refs._read_ref_file('HEAD'))
+                          self._refs.read_loose_ref('HEAD'))
 
         self.assertFalse(os.path.exists(
             os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))