瀏覽代碼

Primarily use encoded paths in DiskRefsContainer.

Jelmer Vernooij 7 年之前
父節點
當前提交
1dc8015185
共有 3 個文件被更改,包括 33 次插入29 次删除
  1. 20 18
      dulwich/refs.py
  2. 1 1
      dulwich/tests/test_porcelain.py
  3. 12 10
      dulwich/tests/test_refs.py

+ 20 - 18
dulwich/refs.py

@@ -478,7 +478,11 @@ class DiskRefsContainer(RefsContainer):
         if getattr(path, 'encode', None) is not None:
             path = path.encode(sys.getfilesystemencoding())
         self.path = path
-        self.worktree_path = worktree_path or path
+        if worktree_path is None:
+            worktree_path = path
+        if getattr(worktree_path, 'encode', None) is not None:
+            worktree_path = worktree_path.encode(sys.getfilesystemencoding())
+        self.worktree_path = worktree_path
         self._packed_refs = None
         self._peeled_refs = None
 
@@ -487,7 +491,7 @@ class DiskRefsContainer(RefsContainer):
 
     def subkeys(self, base):
         subkeys = set()
-        path = self.refpath(base).encode(sys.getfilesystemencoding())
+        path = self.refpath(base)
         for root, unused_dirs, files in os.walk(path):
             dir = root[len(path):]
             if os.path.sep != '/':
@@ -510,7 +514,7 @@ class DiskRefsContainer(RefsContainer):
         if os.path.exists(self.refpath(b'HEAD')):
             allkeys.add(b'HEAD')
         path = self.refpath(b'')
-        refspath = self.refpath('refs').encode(sys.getfilesystemencoding())
+        refspath = self.refpath(b'refs')
         for root, unused_dirs, files in os.walk(refspath):
             dir = root[len(path):]
             if os.path.sep != '/':
@@ -531,7 +535,7 @@ class DiskRefsContainer(RefsContainer):
             name = name.replace("/", os.path.sep)
         # TODO: as the 'HEAD' reference is working tree specific, it
         # should actually not be a part of RefsContainer
-        if name == 'HEAD':
+        if name == b'HEAD':
             return os.path.join(self.worktree_path, name)
         else:
             return os.path.join(self.path, name)
@@ -550,7 +554,7 @@ class DiskRefsContainer(RefsContainer):
             # None if and only if _packed_refs is also None.
             self._packed_refs = {}
             self._peeled_refs = {}
-            path = os.path.join(self.path, 'packed-refs')
+            path = os.path.join(self.path, b'packed-refs')
             try:
                 f = GitFile(path, 'rb')
             except IOError as e:
@@ -618,7 +622,7 @@ class DiskRefsContainer(RefsContainer):
     def _remove_packed_ref(self, name):
         if self._packed_refs is None:
             return
-        filename = os.path.join(self.path, 'packed-refs')
+        filename = os.path.join(self.path, b'packed-refs')
         # reread cached refs from disk, while holding the lock
         f = GitFile(filename, 'wb')
         try:
@@ -647,19 +651,17 @@ class DiskRefsContainer(RefsContainer):
         self._check_refname(name)
         self._check_refname(other)
         filename = self.refpath(name)
+        f = GitFile(filename, 'wb')
         try:
-            f = GitFile(filename, 'wb')
-            try:
-                f.write(SYMREF + other + b'\n')
-            except (IOError, OSError):
-                f.abort()
-                raise
-            else:
-                sha = self.follow(name)[-1]
-                self._log(name, sha, sha, committer=committer,
-                          timestamp=timestamp, timezone=timezone,
-                          message=message)
-        finally:
+            f.write(SYMREF + other + b'\n')
+            sha = self.follow(name)[-1]
+            self._log(name, sha, sha, committer=committer,
+                      timestamp=timestamp, timezone=timezone,
+                      message=message)
+        except:
+            f.abort()
+            raise
+        else:
             f.close()
 
     def set_if_equals(self, name, old_ref, new_ref, committer=None,

+ 1 - 1
dulwich/tests/test_porcelain.py

@@ -1251,7 +1251,7 @@ class UpdateHeadTests(PorcelainTestCase):
         porcelain.update_head(self.repo, "blah")
         self.assertEqual(c1.id, self.repo.head())
         self.assertEqual(b'ref: refs/heads/blah',
-                         self.repo.refs.read_ref('HEAD'))
+                         self.repo.refs.read_ref(b'HEAD'))
 
     def test_set_to_branch_detached(self):
         [c1] = build_commit_graph(self.repo.object_store, [[1]])

+ 12 - 10
dulwich/tests/test_refs.py

@@ -335,7 +335,7 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
 
     def test_setitem(self):
         RefsContainerTests.test_setitem(self)
-        f = open(os.path.join(self._refs.path, 'refs', 'some', 'ref'), 'rb')
+        f = open(os.path.join(self._refs.path, b'refs', b'some', b'ref'), 'rb')
         self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
                          f.read()[:40])
         f.close()
@@ -346,13 +346,13 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
         self.assertEqual(ones, self._refs[b'HEAD'])
 
         # ensure HEAD was not modified
-        f = open(os.path.join(self._refs.path, 'HEAD'), 'rb')
+        f = open(os.path.join(self._refs.path, b'HEAD'), 'rb')
         v = next(iter(f)).rstrip(b'\n\r')
         f.close()
         self.assertEqual(b'ref: refs/heads/master', v)
 
         # ensure the symbolic link was written through
-        f = open(os.path.join(self._refs.path, 'refs', 'heads', 'master'),
+        f = open(os.path.join(self._refs.path, b'refs', b'heads', b'master'),
                  'rb')
         self.assertEqual(ones, f.read()[:40])
         f.close()
@@ -365,9 +365,9 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
 
         # ensure lockfile was deleted
         self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+            os.path.join(self._refs.path, b'refs', b'heads', b'master.lock')))
         self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'HEAD.lock')))
+            os.path.join(self._refs.path, b'HEAD.lock')))
 
     def test_add_if_new_packed(self):
         # don't overwrite packed ref
@@ -406,7 +406,7 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
 
     def test_delitem(self):
         RefsContainerTests.test_delitem(self)
-        ref_file = os.path.join(self._refs.path, 'refs', 'heads', 'master')
+        ref_file = os.path.join(self._refs.path, b'refs', b'heads', b'master')
         self.assertFalse(os.path.exists(ref_file))
         self.assertFalse(b'refs/heads/master' in self._refs.get_packed_refs())
 
@@ -417,7 +417,7 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
         self.assertRaises(KeyError, lambda: self._refs[b'HEAD'])
         self.assertEqual(b'42d06bd4b77fed026b154d16493e5deab78f02ec',
                          self._refs[b'refs/heads/master'])
-        self.assertFalse(os.path.exists(os.path.join(self._refs.path, 'HEAD')))
+        self.assertFalse(os.path.exists(os.path.join(self._refs.path, b'HEAD')))
 
     def test_remove_if_equals_symref(self):
         # HEAD is a symref, so shouldn't equal its dereferenced value
@@ -433,9 +433,9 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
                          self._refs.read_loose_ref(b'HEAD'))
 
         self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'refs', 'heads', 'master.lock')))
+            os.path.join(self._refs.path, b'refs', b'heads', b'master.lock')))
         self.assertFalse(os.path.exists(
-            os.path.join(self._refs.path, 'HEAD.lock')))
+            os.path.join(self._refs.path, b'HEAD.lock')))
 
     def test_remove_packed_without_peeled(self):
         refs_file = os.path.join(self._repo.path, 'packed-refs')
@@ -475,7 +475,9 @@ class DiskRefsContainerTests(RefsContainerTests, TestCase):
         except UnicodeEncodeError:
             raise SkipTest(
                     "filesystem encoding doesn't support special character")
-        p = os.path.join(self._repo.path, 'refs', 'tags', u'schön')
+        p = os.path.join(
+                self._repo.path.encode(sys.getfilesystemencoding()),
+                encoded_ref)
         with open(p, 'w') as f:
             f.write('00' * 20)