|
@@ -23,10 +23,14 @@
|
|
|
"""Ref handling."""
|
|
|
|
|
|
import os
|
|
|
+import types
|
|
|
import warnings
|
|
|
from collections.abc import Iterator
|
|
|
from contextlib import suppress
|
|
|
-from typing import Any, Optional, Union
|
|
|
+from typing import TYPE_CHECKING, Any, Optional, Union
|
|
|
+
|
|
|
+if TYPE_CHECKING:
|
|
|
+ from .file import _GitFile
|
|
|
|
|
|
from .errors import PackedRefsException, RefFormatError
|
|
|
from .file import GitFile, ensure_dir_exists
|
|
@@ -840,6 +844,7 @@ class DiskRefsContainer(RefsContainer):
|
|
|
self.get_packed_refs()
|
|
|
|
|
|
if name not in self._packed_refs:
|
|
|
+ f.abort()
|
|
|
return
|
|
|
|
|
|
del self._packed_refs[name]
|
|
@@ -847,8 +852,9 @@ class DiskRefsContainer(RefsContainer):
|
|
|
del self._peeled_refs[name]
|
|
|
write_packed_refs(f, self._packed_refs, self._peeled_refs)
|
|
|
f.close()
|
|
|
- finally:
|
|
|
+ except BaseException:
|
|
|
f.abort()
|
|
|
+ raise
|
|
|
|
|
|
def set_symbolic_ref(
|
|
|
self,
|
|
@@ -1370,3 +1376,114 @@ def serialize_refs(store, refs):
|
|
|
ret[ref + PEELED_TAG_SUFFIX] = peeled.id
|
|
|
ret[ref] = unpeeled.id
|
|
|
return ret
|
|
|
+
|
|
|
+
|
|
|
+class locked_ref:
|
|
|
+ """Lock a ref while making modifications.
|
|
|
+
|
|
|
+ Works as a context manager.
|
|
|
+ """
|
|
|
+
|
|
|
+ def __init__(self, refs_container: DiskRefsContainer, refname: Ref) -> None:
|
|
|
+ self._refs_container = refs_container
|
|
|
+ self._refname = refname
|
|
|
+ self._file: Optional[_GitFile] = None
|
|
|
+ self._realname: Optional[Ref] = None
|
|
|
+ self._deleted = False
|
|
|
+
|
|
|
+ def __enter__(self) -> "locked_ref":
|
|
|
+ self._refs_container._check_refname(self._refname)
|
|
|
+ try:
|
|
|
+ realnames, _ = self._refs_container.follow(self._refname)
|
|
|
+ self._realname = realnames[-1]
|
|
|
+ except (KeyError, IndexError, SymrefLoop):
|
|
|
+ self._realname = self._refname
|
|
|
+
|
|
|
+ filename = self._refs_container.refpath(self._realname)
|
|
|
+ ensure_dir_exists(os.path.dirname(filename))
|
|
|
+ self._file = GitFile(filename, "wb")
|
|
|
+ return self
|
|
|
+
|
|
|
+ def __exit__(
|
|
|
+ self,
|
|
|
+ exc_type: Optional[type],
|
|
|
+ exc_value: Optional[BaseException],
|
|
|
+ traceback: Optional[types.TracebackType],
|
|
|
+ ) -> None:
|
|
|
+ if self._file:
|
|
|
+ if exc_type is not None or self._deleted:
|
|
|
+ self._file.abort()
|
|
|
+ else:
|
|
|
+ self._file.close()
|
|
|
+
|
|
|
+ def get(self) -> Optional[bytes]:
|
|
|
+ """Get the current value of the ref."""
|
|
|
+ if not self._file:
|
|
|
+ raise RuntimeError("locked_ref not in context")
|
|
|
+
|
|
|
+ current_ref = self._refs_container.read_loose_ref(self._realname)
|
|
|
+ if current_ref is None:
|
|
|
+ current_ref = self._refs_container.get_packed_refs().get(
|
|
|
+ self._realname, None
|
|
|
+ )
|
|
|
+ return current_ref
|
|
|
+
|
|
|
+ def ensure_equals(self, expected_value: Optional[bytes]) -> bool:
|
|
|
+ """Ensure the ref currently equals the expected value.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ expected_value: The expected current value of the ref
|
|
|
+ Returns:
|
|
|
+ True if the ref equals the expected value, False otherwise
|
|
|
+ """
|
|
|
+ current_value = self.get()
|
|
|
+ return current_value == expected_value
|
|
|
+
|
|
|
+ def set(self, new_ref: bytes) -> None:
|
|
|
+ """Set the ref to a new value.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ new_ref: The new SHA1 or symbolic ref value
|
|
|
+ """
|
|
|
+ if not self._file:
|
|
|
+ raise RuntimeError("locked_ref not in context")
|
|
|
+
|
|
|
+ if not (valid_hexsha(new_ref) or new_ref.startswith(SYMREF)):
|
|
|
+ raise ValueError(f"{new_ref!r} must be a valid sha (40 chars) or a symref")
|
|
|
+
|
|
|
+ self._file.seek(0)
|
|
|
+ self._file.truncate()
|
|
|
+ self._file.write(new_ref + b"\n")
|
|
|
+ self._deleted = False
|
|
|
+
|
|
|
+ def set_symbolic_ref(self, target: Ref) -> None:
|
|
|
+ """Make this ref point at another ref.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ target: Name of the ref to point at
|
|
|
+ """
|
|
|
+ if not self._file:
|
|
|
+ raise RuntimeError("locked_ref not in context")
|
|
|
+
|
|
|
+ self._refs_container._check_refname(target)
|
|
|
+ self._file.seek(0)
|
|
|
+ self._file.truncate()
|
|
|
+ self._file.write(SYMREF + target + b"\n")
|
|
|
+ self._deleted = False
|
|
|
+
|
|
|
+ def delete(self) -> None:
|
|
|
+ """Delete the ref file while holding the lock."""
|
|
|
+ if not self._file:
|
|
|
+ raise RuntimeError("locked_ref not in context")
|
|
|
+
|
|
|
+ # Delete the actual ref file while holding the lock
|
|
|
+ if self._realname:
|
|
|
+ filename = self._refs_container.refpath(self._realname)
|
|
|
+ try:
|
|
|
+ if os.path.lexists(filename):
|
|
|
+ os.remove(filename)
|
|
|
+ except FileNotFoundError:
|
|
|
+ pass
|
|
|
+ self._refs_container._remove_packed_ref(self._realname)
|
|
|
+
|
|
|
+ self._deleted = True
|