浏览代码

Raise KeyError when RefsContainer can't find a ref, remove broad search for refs.

Jelmer Vernooij 16 年之前
父节点
当前提交
1f7aa28252
共有 2 个文件被更改,包括 8 次插入23 次删除
  1. 7 22
      dulwich/repo.py
  2. 1 1
      dulwich/tests/test_repository.py

+ 7 - 22
dulwich/repo.py

@@ -86,6 +86,8 @@ class RefsContainer(object):
 
     def __getitem__(self, name):
         file = self.refpath(name)
+        if not os.path.exists(file):
+            raise KeyError(name)
         f = open(file, 'rb')
         try:
             return f.read().strip("\n")
@@ -211,36 +213,19 @@ class Repo(object):
             heads = self.heads().values()
         return self.object_store.get_graph_walker(heads)
 
-    def _get_ref(self, file):
-        f = open(file, 'rb')
-        try:
-            contents = f.read()
-            if contents.startswith(SYMREF):
-                ref = contents[len(SYMREF):]
-                if ref[-1] == '\n':
-                    ref = ref[:-1]
-                return follow_ref(self.refs, ref)
-            assert len(contents) == 41, 'Invalid ref in %s' % file
-            return contents[:-1]
-        finally:
-            f.close()
-
     def ref(self, name):
         """Return the SHA1 a ref is pointing to."""
-        for dir in self.ref_locs:
-            file = os.path.join(self.controldir(), dir, name)
-            if os.path.exists(file):
-                return self._get_ref(file)
-        packed_refs = self.get_packed_refs()
-        if name in packed_refs:
-            return packed_refs[name]
+        try:
+            return self.refs.follow(name)
+        except KeyError:
+            return self.get_packed_refs()[name]
 
     def get_refs(self):
         """Get dictionary with all refs."""
         ret = {}
         if self.head():
             ret['HEAD'] = self.head()
-        ret.update(refs.as_dict(REFSDIR))
+        ret.update(self.refs.as_dict(REFSDIR))
         ret.update(self.get_packed_refs())
         return ret
 

+ 1 - 1
dulwich/tests/test_repository.py

@@ -38,7 +38,7 @@ class RepositoryTests(unittest.TestCase):
   
     def test_ref(self):
         r = self.open_repo('a')
-        self.assertEqual(r.ref('master'),
+        self.assertEqual(r.ref('refs/heads/master'),
                          'a90fa2d900a17e99b433217e988c4eb4a2e9a097')
   
     def test_get_refs(self):