浏览代码

Fix remaining tests.

Jelmer Vernooij 16 年之前
父节点
当前提交
3d3d5cbde4
共有 1 个文件被更改,包括 37 次插入6 次删除
  1. 37 6
      dulwich/repo.py

+ 37 - 6
dulwich/repo.py

@@ -84,12 +84,43 @@ class DiskRefsContainer(RefsContainer):
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self.path)
 
-    def as_dict(self, base):
-        ret = {}
+    def keys(self, base=None):
+        return list(self.iterkeys(base))
+
+    def iterkeys(self, base=None):
+        if base is not None:
+            return self.itersubkeys(base)
+        else:
+            return self.iterallkeys()
+
+    def itersubkeys(self, base):
         path = self.refpath(base)
         for root, dirs, files in os.walk(path):
-            for name in files:
-                ret[name] = self.follow("%s/%s" % (base, name))
+            dir = root[len(path):].strip("/")
+            for filename in files:
+                yield ("%s/%s" % (dir, filename)).strip("/")
+
+    def iterallkeys(self):
+        if os.path.exists(self.refpath("HEAD")):
+            yield "HEAD"
+        path = self.refpath("")
+        for root, dirs, files in os.walk(self.refpath("refs")):
+            dir = root[len(path):].strip("/")
+            for filename in files:
+                yield ("%s/%s" % (dir, filename)).strip("/")
+
+    def as_dict(self, base=None, follow=True):
+        ret = {}
+        if base is None:
+            keys = self.iterkeys()
+            base = ""
+        else:
+            keys = self.itersubkeys(base)
+        for key in keys:
+            if follow:
+                ret[key] = self.follow(("%s/%s" % (base, key)).strip("/"))
+            else:
+                ret[key] = self[("%s/%s" % (base, key)).strip("/")]
         return ret
 
     def refpath(self, name):
@@ -114,7 +145,7 @@ class DiskRefsContainer(RefsContainer):
             os.makedirs(dirpath)
         f = open(file, 'w')
         try:
-            f.write(value+"\n")
+            f.write(ref+"\n")
         finally:
             f.close()
 
@@ -215,7 +246,7 @@ class Repo(object):
         ret = {}
         if self.head():
             ret['HEAD'] = self.head()
-        ret.update(self.refs.as_dict(REFSDIR))
+        ret.update(self.refs.as_dict())
         ret.update(self.get_packed_refs())
         return ret