Pārlūkot izejas kodu

Add locked_ref context manager and fix packed ref deletion (#1713)

* Add locked_ref context manager for atomic ref operations with methods:
get(), ensure_equals(), set(), set_symbolic_ref(), delete()
* Fix DiskRefsContainer._remove_packed_ref to actually persist deletions
to disk instead of always aborting changes
Jelmer Vernooij 1 mēnesi atpakaļ
vecāks
revīzija
f1b19cf24c
2 mainītis faili ar 126 papildinājumiem un 2 dzēšanām
  1. 7 0
      NEWS
  2. 119 2
      dulwich/refs.py

+ 7 - 0
NEWS

@@ -27,6 +27,13 @@
 
 
  * Add ``reflog`` command in porcelain. (Jelmer Vernooij)
  * Add ``reflog`` command in porcelain. (Jelmer Vernooij)
 
 
+ * Add ``locked_ref`` context manager for atomic ref operations.
+   (Jelmer Vernooij)
+
+ * Fix bug in ``DiskRefsContainer._remove_packed_ref`` that prevented
+   packed ref deletions from being persisted to disk.
+   (Jelmer Vernooij)
+
  * Optimize writing unchanged refs by avoiding unnecessary fsync
  * Optimize writing unchanged refs by avoiding unnecessary fsync
    when ref already has the desired value. File locking behavior
    when ref already has the desired value. File locking behavior
    is preserved to ensure proper concurrency control.
    is preserved to ensure proper concurrency control.

+ 119 - 2
dulwich/refs.py

@@ -23,10 +23,14 @@
 """Ref handling."""
 """Ref handling."""
 
 
 import os
 import os
+import types
 import warnings
 import warnings
 from collections.abc import Iterator
 from collections.abc import Iterator
 from contextlib import suppress
 from contextlib import suppress
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+if TYPE_CHECKING:
+    from .file import _GitFile
 
 
 from .errors import PackedRefsException, RefFormatError
 from .errors import PackedRefsException, RefFormatError
 from .file import GitFile, ensure_dir_exists
 from .file import GitFile, ensure_dir_exists
@@ -840,6 +844,7 @@ class DiskRefsContainer(RefsContainer):
             self.get_packed_refs()
             self.get_packed_refs()
 
 
             if name not in self._packed_refs:
             if name not in self._packed_refs:
+                f.abort()
                 return
                 return
 
 
             del self._packed_refs[name]
             del self._packed_refs[name]
@@ -847,8 +852,9 @@ class DiskRefsContainer(RefsContainer):
                 del self._peeled_refs[name]
                 del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
             f.close()
-        finally:
+        except BaseException:
             f.abort()
             f.abort()
+            raise
 
 
     def set_symbolic_ref(
     def set_symbolic_ref(
         self,
         self,
@@ -1370,3 +1376,114 @@ def serialize_refs(store, refs):
                 ret[ref + PEELED_TAG_SUFFIX] = peeled.id
                 ret[ref + PEELED_TAG_SUFFIX] = peeled.id
             ret[ref] = unpeeled.id
             ret[ref] = unpeeled.id
     return ret
     return ret
+
+
+class locked_ref:
+    """Lock a ref while making modifications.
+
+    Works as a context manager.
+    """
+
+    def __init__(self, refs_container: DiskRefsContainer, refname: Ref) -> None:
+        self._refs_container = refs_container
+        self._refname = refname
+        self._file: Optional[_GitFile] = None
+        self._realname: Optional[Ref] = None
+        self._deleted = False
+
+    def __enter__(self) -> "locked_ref":
+        self._refs_container._check_refname(self._refname)
+        try:
+            realnames, _ = self._refs_container.follow(self._refname)
+            self._realname = realnames[-1]
+        except (KeyError, IndexError, SymrefLoop):
+            self._realname = self._refname
+
+        filename = self._refs_container.refpath(self._realname)
+        ensure_dir_exists(os.path.dirname(filename))
+        self._file = GitFile(filename, "wb")
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
+        if self._file:
+            if exc_type is not None or self._deleted:
+                self._file.abort()
+            else:
+                self._file.close()
+
+    def get(self) -> Optional[bytes]:
+        """Get the current value of the ref."""
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        current_ref = self._refs_container.read_loose_ref(self._realname)
+        if current_ref is None:
+            current_ref = self._refs_container.get_packed_refs().get(
+                self._realname, None
+            )
+        return current_ref
+
+    def ensure_equals(self, expected_value: Optional[bytes]) -> bool:
+        """Ensure the ref currently equals the expected value.
+
+        Args:
+            expected_value: The expected current value of the ref
+        Returns:
+            True if the ref equals the expected value, False otherwise
+        """
+        current_value = self.get()
+        return current_value == expected_value
+
+    def set(self, new_ref: bytes) -> None:
+        """Set the ref to a new value.
+
+        Args:
+            new_ref: The new SHA1 or symbolic ref value
+        """
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        if not (valid_hexsha(new_ref) or new_ref.startswith(SYMREF)):
+            raise ValueError(f"{new_ref!r} must be a valid sha (40 chars) or a symref")
+
+        self._file.seek(0)
+        self._file.truncate()
+        self._file.write(new_ref + b"\n")
+        self._deleted = False
+
+    def set_symbolic_ref(self, target: Ref) -> None:
+        """Make this ref point at another ref.
+
+        Args:
+            target: Name of the ref to point at
+        """
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        self._refs_container._check_refname(target)
+        self._file.seek(0)
+        self._file.truncate()
+        self._file.write(SYMREF + target + b"\n")
+        self._deleted = False
+
+    def delete(self) -> None:
+        """Delete the ref file while holding the lock."""
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        # Delete the actual ref file while holding the lock
+        if self._realname:
+            filename = self._refs_container.refpath(self._realname)
+            try:
+                if os.path.lexists(filename):
+                    os.remove(filename)
+            except FileNotFoundError:
+                pass
+            self._refs_container._remove_packed_ref(self._realname)
+
+        self._deleted = True