Преглед изворни кода

Allow ZERO_SHA to mean 'does not exist' when setting/removing refs.

Jelmer Vernooij пре 8 година
родитељ
комит
c0ed5019cd
3 измењених фајлова са 20 додато и 7 уклоњено
  1. 7 3
      dulwich/client.py
  2. 6 3
      dulwich/refs.py
  3. 7 1
      dulwich/tests/test_refs.py

+ 7 - 3
dulwich/client.py

@@ -779,6 +779,8 @@ class LocalGitClient(GitClient):
         :return: new_refs dictionary containing the changes that were made
             {refname: new_ref}, including deleted refs.
         """
+        if not progress:
+            progress = lambda x: None
         from dulwich.repo import Repo
 
         with closing(Repo(path)) as target:
@@ -791,7 +793,7 @@ class LocalGitClient(GitClient):
                 if new_sha1 not in have and not new_sha1 in want and new_sha1 != ZERO_SHA:
                     want.append(new_sha1)
 
-            if not want and set(old_refs.items()).issubset(set(new_refs.items())):
+            if not want and set(new_refs.items()).issubset(set(old_refs.items())):
                 return new_refs
 
             target.object_store.add_objects(generate_pack_contents(have, want))
@@ -799,9 +801,11 @@ class LocalGitClient(GitClient):
             for refname, new_sha1 in new_refs.items():
                 old_sha1 = old_refs.get(refname, ZERO_SHA)
                 if new_sha1 != ZERO_SHA:
-                    target.refs.set_if_equals(refname, old_sha1, new_sha1)
+                    if not target.refs.set_if_equals(refname, old_sha1, new_sha1):
+                        progress('unable to set %s to %s' % (refname, new_sha1))
                 else:
-                    target.refs.remove_if_equals(refname, old_sha1)
+                    if not target.refs.remove_if_equals(refname, old_sha1):
+                        progress('unable to remove %s' % refname)
 
         return new_refs
 

+ 6 - 3
dulwich/refs.py

@@ -345,7 +345,10 @@ class DictRefsContainer(RefsContainer):
     def remove_if_equals(self, name, old_ref):
         if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
             return False
-        del self._refs[name]
+        try:
+            del self._refs[name]
+        except KeyError:
+            pass
         return True
 
     def get_peeled(self, name):
@@ -593,7 +596,7 @@ class DiskRefsContainer(RefsContainer):
                     # read again while holding the lock
                     orig_ref = self.read_loose_ref(realname)
                     if orig_ref is None:
-                        orig_ref = self.get_packed_refs().get(realname, None)
+                        orig_ref = self.get_packed_refs().get(realname, ZERO_SHA)
                     if orig_ref != old_ref:
                         f.abort()
                         return False
@@ -657,7 +660,7 @@ class DiskRefsContainer(RefsContainer):
             if old_ref is not None:
                 orig_ref = self.read_loose_ref(name)
                 if orig_ref is None:
-                    orig_ref = self.get_packed_refs().get(name, None)
+                    orig_ref = self.get_packed_refs().get(name, ZERO_SHA)
                 if orig_ref != old_ref:
                     return False
             # may only be packed

+ 7 - 1
dulwich/tests/test_refs.py

@@ -29,6 +29,7 @@ from dulwich import errors
 from dulwich.file import (
     GitFile,
     )
+from dulwich.objects import ZERO_SHA
 from dulwich.refs import (
     DictRefsContainer,
     InfoRefsContainer,
@@ -203,6 +204,10 @@ class RefsContainerTests(object):
                                                  nines))
         self.assertEqual(nines, self._refs[b'refs/heads/master'])
 
+        self.assertTrue(self._refs.set_if_equals(
+            b'refs/heads/nonexistant', ZERO_SHA, nines))
+        self.assertEqual(nines, self._refs[b'refs/heads/nonexistant'])
+
     def test_add_if_new(self):
         nines = b'9' * 40
         self.assertFalse(self._refs.add_if_new(b'refs/heads/master', nines))
@@ -259,10 +264,11 @@ class RefsContainerTests(object):
                          self._refs[b'HEAD'])
         self.assertTrue(self._refs.remove_if_equals(
             b'refs/tags/refs-0.2', b'3ec9c43c84ff242e3ef4a9fc5bc111fd780a76a8'))
+        self.assertTrue(self._refs.remove_if_equals(
+            b'refs/tags/refs-0.2', ZERO_SHA))
         self.assertFalse(b'refs/tags/refs-0.2' in self._refs)
 
 
-
 class DictRefsContainerTests(RefsContainerTests, TestCase):
 
     def setUp(self):