Bladeren bron

Add locked_ref context manager and fix packed ref deletion (#1713)

* Add locked_ref context manager for atomic ref operations with methods:
get(), ensure_equals(), set(), set_symbolic_ref(), delete()
* Fix DiskRefsContainer._remove_packed_ref to actually persist deletions
to disk instead of always aborting changes
Jelmer Vernooij 1 maand geleden
bovenliggende
commit
f1b19cf24c
2 gewijzigde bestanden met toevoegingen van 126 en 2 verwijderingen
  1. 7 0
      NEWS
  2. 119 2
      dulwich/refs.py

+ 7 - 0
NEWS

@@ -27,6 +27,13 @@
 
  * Add ``reflog`` command in porcelain. (Jelmer Vernooij)
 
+ * Add ``locked_ref`` context manager for atomic ref operations.
+   (Jelmer Vernooij)
+
+ * Fix bug in ``DiskRefsContainer._remove_packed_ref`` that prevented
+   packed ref deletions from being persisted to disk.
+   (Jelmer Vernooij)
+
  * Optimize writing unchanged refs by avoiding unnecessary fsync
    when ref already has the desired value. File locking behavior
    is preserved to ensure proper concurrency control.

+ 119 - 2
dulwich/refs.py

@@ -23,10 +23,14 @@
 """Ref handling."""
 
 import os
+import types
 import warnings
 from collections.abc import Iterator
 from contextlib import suppress
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
+
+if TYPE_CHECKING:
+    from .file import _GitFile
 
 from .errors import PackedRefsException, RefFormatError
 from .file import GitFile, ensure_dir_exists
@@ -840,6 +844,7 @@ class DiskRefsContainer(RefsContainer):
             self.get_packed_refs()
 
             if name not in self._packed_refs:
+                f.abort()
                 return
 
             del self._packed_refs[name]
@@ -847,8 +852,9 @@ class DiskRefsContainer(RefsContainer):
                 del self._peeled_refs[name]
             write_packed_refs(f, self._packed_refs, self._peeled_refs)
             f.close()
-        finally:
+        except BaseException:
             f.abort()
+            raise
 
     def set_symbolic_ref(
         self,
@@ -1370,3 +1376,114 @@ def serialize_refs(store, refs):
                 ret[ref + PEELED_TAG_SUFFIX] = peeled.id
             ret[ref] = unpeeled.id
     return ret
+
+
+class locked_ref:
+    """Lock a ref while making modifications.
+
+    Works as a context manager.
+    """
+
+    def __init__(self, refs_container: DiskRefsContainer, refname: Ref) -> None:
+        self._refs_container = refs_container
+        self._refname = refname
+        self._file: Optional[_GitFile] = None
+        self._realname: Optional[Ref] = None
+        self._deleted = False
+
+    def __enter__(self) -> "locked_ref":
+        self._refs_container._check_refname(self._refname)
+        try:
+            realnames, _ = self._refs_container.follow(self._refname)
+            self._realname = realnames[-1]
+        except (KeyError, IndexError, SymrefLoop):
+            self._realname = self._refname
+
+        filename = self._refs_container.refpath(self._realname)
+        ensure_dir_exists(os.path.dirname(filename))
+        self._file = GitFile(filename, "wb")
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[type],
+        exc_value: Optional[BaseException],
+        traceback: Optional[types.TracebackType],
+    ) -> None:
+        if self._file:
+            if exc_type is not None or self._deleted:
+                self._file.abort()
+            else:
+                self._file.close()
+
+    def get(self) -> Optional[bytes]:
+        """Get the current value of the ref."""
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        current_ref = self._refs_container.read_loose_ref(self._realname)
+        if current_ref is None:
+            current_ref = self._refs_container.get_packed_refs().get(
+                self._realname, None
+            )
+        return current_ref
+
+    def ensure_equals(self, expected_value: Optional[bytes]) -> bool:
+        """Ensure the ref currently equals the expected value.
+
+        Args:
+            expected_value: The expected current value of the ref
+        Returns:
+            True if the ref equals the expected value, False otherwise
+        """
+        current_value = self.get()
+        return current_value == expected_value
+
+    def set(self, new_ref: bytes) -> None:
+        """Set the ref to a new value.
+
+        Args:
+            new_ref: The new SHA1 or symbolic ref value
+        """
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        if not (valid_hexsha(new_ref) or new_ref.startswith(SYMREF)):
+            raise ValueError(f"{new_ref!r} must be a valid sha (40 chars) or a symref")
+
+        self._file.seek(0)
+        self._file.truncate()
+        self._file.write(new_ref + b"\n")
+        self._deleted = False
+
+    def set_symbolic_ref(self, target: Ref) -> None:
+        """Make this ref point at another ref.
+
+        Args:
+            target: Name of the ref to point at
+        """
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        self._refs_container._check_refname(target)
+        self._file.seek(0)
+        self._file.truncate()
+        self._file.write(SYMREF + target + b"\n")
+        self._deleted = False
+
+    def delete(self) -> None:
+        """Delete the ref file while holding the lock."""
+        if not self._file:
+            raise RuntimeError("locked_ref not in context")
+
+        # Delete the actual ref file while holding the lock
+        if self._realname:
+            filename = self._refs_container.refpath(self._realname)
+            try:
+                if os.path.lexists(filename):
+                    os.remove(filename)
+            except FileNotFoundError:
+                pass
+            self._refs_container._remove_packed_ref(self._realname)
+
+        self._deleted = True