refs.py 64 KB


  1. # refs.py -- For dealing with git refs
  2. # Copyright (C) 2008-2013 Jelmer Vernooij <jelmer@jelmer.uk>
  3. #
  4. # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
  5. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  6. # General Public License as published by the Free Software Foundation; version 2.0
  7. # or (at your option) any later version. You can redistribute it and/or
  8. # modify it under the terms of either of these two licenses.
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # You should have received a copy of the licenses; if not, see
  17. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  18. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  19. # License, Version 2.0.
  20. #
  21. """Ref handling."""
  22. import os
  23. import types
  24. from collections.abc import Callable, Iterable, Iterator, Mapping
  25. from contextlib import suppress
  26. from typing import (
  27. IO,
  28. TYPE_CHECKING,
  29. Any,
  30. BinaryIO,
  31. NewType,
  32. TypeVar,
  33. )
  34. if TYPE_CHECKING:
  35. from .file import _GitFile
  36. from .errors import PackedRefsException, RefFormatError
  37. from .file import GitFile, ensure_dir_exists
  38. from .objects import ZERO_SHA, ObjectID, Tag, git_line, valid_hexsha
  39. Ref = NewType("Ref", bytes)
  40. T = TypeVar("T", dict[Ref, ObjectID], dict[Ref, ObjectID | None])
  41. HEADREF = Ref(b"HEAD")
  42. SYMREF = b"ref: "
  43. LOCAL_BRANCH_PREFIX = b"refs/heads/"
  44. LOCAL_TAG_PREFIX = b"refs/tags/"
  45. LOCAL_REMOTE_PREFIX = b"refs/remotes/"
  46. LOCAL_NOTES_PREFIX = b"refs/notes/"
  47. LOCAL_REPLACE_PREFIX = b"refs/replace/"
  48. BAD_REF_CHARS: set[int] = set(b"\177 ~^:?*[")
  49. class SymrefLoop(Exception):
  50. """There is a loop between one or more symrefs."""
  51. def __init__(self, ref: bytes, depth: int) -> None:
  52. """Initialize SymrefLoop exception."""
  53. self.ref = ref
  54. self.depth = depth
  55. def parse_symref_value(contents: bytes) -> bytes:
  56. """Parse a symref value.
  57. Args:
  58. contents: Contents to parse
  59. Returns: Destination
  60. """
  61. if contents.startswith(SYMREF):
  62. return contents[len(SYMREF) :].rstrip(b"\r\n")
  63. raise ValueError(contents)
  64. def check_ref_format(refname: Ref) -> bool:
  65. """Check if a refname is correctly formatted.
  66. Implements all the same rules as git-check-ref-format[1].
  67. [1]
  68. http://www.kernel.org/pub/software/scm/git/docs/git-check-ref-format.html
  69. Args:
  70. refname: The refname to check
  71. Returns: True if refname is valid, False otherwise
  72. """
  73. # These could be combined into one big expression, but are listed
  74. # separately to parallel [1].
  75. if b"/." in refname or refname.startswith(b"."): # type: ignore[comparison-overlap]
  76. return False
  77. if b"/" not in refname: # type: ignore[comparison-overlap]
  78. return False
  79. if b".." in refname: # type: ignore[comparison-overlap]
  80. return False
  81. for i, c in enumerate(refname):
  82. if ord(refname[i : i + 1]) < 0o40 or c in BAD_REF_CHARS:
  83. return False
  84. if refname[-1] in b"/.":
  85. return False
  86. if refname.endswith(b".lock"):
  87. return False
  88. if b"@{" in refname: # type: ignore[comparison-overlap]
  89. return False
  90. if b"\\" in refname: # type: ignore[comparison-overlap]
  91. return False
  92. return True
  93. def parse_remote_ref(ref: bytes) -> tuple[bytes, bytes]:
  94. """Parse a remote ref into remote name and branch name.
  95. Args:
  96. ref: Remote ref like b"refs/remotes/origin/main"
  97. Returns:
  98. Tuple of (remote_name, branch_name)
  99. Raises:
  100. ValueError: If ref is not a valid remote ref
  101. """
  102. if not ref.startswith(LOCAL_REMOTE_PREFIX):
  103. raise ValueError(f"Not a remote ref: {ref!r}")
  104. # Remove the prefix
  105. remainder = ref[len(LOCAL_REMOTE_PREFIX) :]
  106. # Split into remote name and branch name
  107. parts = remainder.split(b"/", 1)
  108. if len(parts) != 2:
  109. raise ValueError(f"Invalid remote ref format: {ref!r}")
  110. remote_name, branch_name = parts
  111. return (remote_name, branch_name)
  112. def set_ref_from_raw(refs: "RefsContainer", name: Ref, raw_ref: bytes) -> None:
  113. """Set a reference from a raw ref value.
  114. This handles both symbolic refs (starting with 'ref: ') and direct ObjectID refs.
  115. Args:
  116. refs: The RefsContainer to set the ref in
  117. name: The ref name to set
  118. raw_ref: The raw ref value (either a symbolic ref or an ObjectID)
  119. """
  120. if raw_ref.startswith(SYMREF):
  121. # It's a symbolic ref
  122. target = Ref(raw_ref[len(SYMREF) :])
  123. refs.set_symbolic_ref(name, target)
  124. else:
  125. # It's a direct ObjectID
  126. refs[name] = ObjectID(raw_ref)
  127. class RefsContainer:
  128. """A container for refs."""
  129. def __init__(
  130. self,
  131. logger: Callable[
  132. [bytes, bytes, bytes, bytes | None, int | None, int | None, bytes], None
  133. ]
  134. | None = None,
  135. ) -> None:
  136. """Initialize RefsContainer with optional logger function."""
  137. self._logger = logger
  138. def _log(
  139. self,
  140. ref: bytes,
  141. old_sha: bytes | None,
  142. new_sha: bytes | None,
  143. committer: bytes | None = None,
  144. timestamp: int | None = None,
  145. timezone: int | None = None,
  146. message: bytes | None = None,
  147. ) -> None:
  148. if self._logger is None:
  149. return
  150. if message is None:
  151. return
  152. # Use ZERO_SHA for None values, matching git behavior
  153. if old_sha is None:
  154. old_sha = ZERO_SHA
  155. if new_sha is None:
  156. new_sha = ZERO_SHA
  157. self._logger(ref, old_sha, new_sha, committer, timestamp, timezone, message)
  158. def set_symbolic_ref(
  159. self,
  160. name: Ref,
  161. other: Ref,
  162. committer: bytes | None = None,
  163. timestamp: int | None = None,
  164. timezone: int | None = None,
  165. message: bytes | None = None,
  166. ) -> None:
  167. """Make a ref point at another ref.
  168. Args:
  169. name: Name of the ref to set
  170. other: Name of the ref to point at
  171. committer: Optional committer name/email
  172. timestamp: Optional timestamp
  173. timezone: Optional timezone
  174. message: Optional message
  175. """
  176. raise NotImplementedError(self.set_symbolic_ref)
  177. def get_packed_refs(self) -> dict[Ref, ObjectID]:
  178. """Get contents of the packed-refs file.
  179. Returns: Dictionary mapping ref names to SHA1s
  180. Note: Will return an empty dictionary when no packed-refs file is
  181. present.
  182. """
  183. raise NotImplementedError(self.get_packed_refs)
  184. def add_packed_refs(self, new_refs: Mapping[Ref, ObjectID | None]) -> None:
  185. """Add the given refs as packed refs.
  186. Args:
  187. new_refs: A mapping of ref names to targets; if a target is None that
  188. means remove the ref
  189. """
  190. raise NotImplementedError(self.add_packed_refs)
  191. def get_peeled(self, name: Ref) -> ObjectID | None:
  192. """Return the cached peeled value of a ref, if available.
  193. Args:
  194. name: Name of the ref to peel
  195. Returns: The peeled value of the ref. If the ref is known not point to
  196. a tag, this will be the SHA the ref refers to. If the ref may point
  197. to a tag, but no cached information is available, None is returned.
  198. """
  199. return None
  200. def import_refs(
  201. self,
  202. base: Ref,
  203. other: Mapping[Ref, ObjectID | None],
  204. committer: bytes | None = None,
  205. timestamp: bytes | None = None,
  206. timezone: bytes | None = None,
  207. message: bytes | None = None,
  208. prune: bool = False,
  209. ) -> None:
  210. """Import refs from another repository.
  211. Args:
  212. base: Base ref to import into (e.g., b'refs/remotes/origin')
  213. other: Dictionary of refs to import
  214. committer: Optional committer for reflog
  215. timestamp: Optional timestamp for reflog
  216. timezone: Optional timezone for reflog
  217. message: Optional message for reflog
  218. prune: If True, remove refs not in other
  219. """
  220. if prune:
  221. to_delete = set(self.subkeys(base))
  222. else:
  223. to_delete = set()
  224. for name, value in other.items():
  225. if value is None:
  226. to_delete.add(name)
  227. else:
  228. self.set_if_equals(
  229. Ref(b"/".join((base, name))), None, value, message=message
  230. )
  231. if to_delete:
  232. try:
  233. to_delete.remove(name)
  234. except KeyError:
  235. pass
  236. for ref in to_delete:
  237. self.remove_if_equals(Ref(b"/".join((base, ref))), None, message=message)
  238. def allkeys(self) -> set[Ref]:
  239. """All refs present in this container."""
  240. raise NotImplementedError(self.allkeys)
  241. def __iter__(self) -> Iterator[Ref]:
  242. """Iterate over all reference keys."""
  243. return iter(self.allkeys())
  244. def keys(self, base: Ref | None = None) -> set[Ref]:
  245. """Refs present in this container.
  246. Args:
  247. base: An optional base to return refs under.
  248. Returns: An unsorted set of valid refs in this container, including
  249. packed refs.
  250. """
  251. if base is not None:
  252. return self.subkeys(base)
  253. else:
  254. return self.allkeys()
  255. def subkeys(self, base: Ref) -> set[Ref]:
  256. """Refs present in this container under a base.
  257. Args:
  258. base: The base to return refs under.
  259. Returns: A set of valid refs in this container under the base; the base
  260. prefix is stripped from the ref names returned.
  261. """
  262. keys: set[Ref] = set()
  263. base_len = len(base) + 1
  264. for refname in self.allkeys():
  265. if refname.startswith(base):
  266. keys.add(Ref(refname[base_len:]))
  267. return keys
  268. def as_dict(self, base: Ref | None = None) -> dict[Ref, ObjectID]:
  269. """Return the contents of this container as a dictionary."""
  270. ret: dict[Ref, ObjectID] = {}
  271. keys = self.keys(base)
  272. base_bytes: bytes
  273. if base is None:
  274. base_bytes = b""
  275. else:
  276. base_bytes = base.rstrip(b"/")
  277. for key in keys:
  278. try:
  279. ret[key] = self[Ref((base_bytes + b"/" + key).strip(b"/"))]
  280. except (SymrefLoop, KeyError):
  281. continue # Unable to resolve
  282. return ret
  283. def _check_refname(self, name: Ref) -> None:
  284. """Ensure a refname is valid and lives in refs or is HEAD.
  285. HEAD is not a valid refname according to git-check-ref-format, but this
  286. class needs to be able to touch HEAD. Also, check_ref_format expects
  287. refnames without the leading 'refs/', but this class requires that
  288. so it cannot touch anything outside the refs dir (or HEAD).
  289. Args:
  290. name: The name of the reference.
  291. Raises:
  292. KeyError: if a refname is not HEAD or is otherwise not valid.
  293. """
  294. if name in (HEADREF, Ref(b"refs/stash")):
  295. return
  296. if not name.startswith(b"refs/") or not check_ref_format(Ref(name[5:])):
  297. raise RefFormatError(name)
  298. def read_ref(self, refname: Ref) -> bytes | None:
  299. """Read a reference without following any references.
  300. Args:
  301. refname: The name of the reference
  302. Returns: The contents of the ref file, or None if it does
  303. not exist.
  304. """
  305. contents = self.read_loose_ref(refname)
  306. if not contents:
  307. contents = self.get_packed_refs().get(refname, None)
  308. return contents
  309. def read_loose_ref(self, name: Ref) -> bytes | None:
  310. """Read a loose reference and return its contents.
  311. Args:
  312. name: the refname to read
  313. Returns: The contents of the ref file, or None if it does
  314. not exist.
  315. """
  316. raise NotImplementedError(self.read_loose_ref)
  317. def follow(self, name: Ref) -> tuple[list[Ref], ObjectID | None]:
  318. """Follow a reference name.
  319. Returns: a tuple of (refnames, sha), wheres refnames are the names of
  320. references in the chain
  321. """
  322. contents: bytes | None = SYMREF + name
  323. depth = 0
  324. refnames: list[Ref] = []
  325. while contents and contents.startswith(SYMREF):
  326. refname = Ref(contents[len(SYMREF) :])
  327. refnames.append(refname)
  328. contents = self.read_ref(refname)
  329. if not contents:
  330. break
  331. depth += 1
  332. if depth > 5:
  333. raise SymrefLoop(name, depth)
  334. return refnames, ObjectID(contents) if contents else None
  335. def __contains__(self, refname: Ref) -> bool:
  336. """Check if a reference exists."""
  337. if self.read_ref(refname):
  338. return True
  339. return False
  340. def __getitem__(self, name: Ref) -> ObjectID:
  341. """Get the SHA1 for a reference name.
  342. This method follows all symbolic references.
  343. """
  344. _, sha = self.follow(name)
  345. if sha is None:
  346. raise KeyError(name)
  347. return sha
  348. def set_if_equals(
  349. self,
  350. name: Ref,
  351. old_ref: ObjectID | None,
  352. new_ref: ObjectID,
  353. committer: bytes | None = None,
  354. timestamp: int | None = None,
  355. timezone: int | None = None,
  356. message: bytes | None = None,
  357. ) -> bool:
  358. """Set a refname to new_ref only if it currently equals old_ref.
  359. This method follows all symbolic references if applicable for the
  360. subclass, and can be used to perform an atomic compare-and-swap
  361. operation.
  362. Args:
  363. name: The refname to set.
  364. old_ref: The old sha the refname must refer to, or None to set
  365. unconditionally.
  366. new_ref: The new sha the refname will refer to.
  367. committer: Optional committer name/email
  368. timestamp: Optional timestamp
  369. timezone: Optional timezone
  370. message: Message for reflog
  371. Returns: True if the set was successful, False otherwise.
  372. """
  373. raise NotImplementedError(self.set_if_equals)
  374. def add_if_new(
  375. self,
  376. name: Ref,
  377. ref: ObjectID,
  378. committer: bytes | None = None,
  379. timestamp: int | None = None,
  380. timezone: int | None = None,
  381. message: bytes | None = None,
  382. ) -> bool:
  383. """Add a new reference only if it does not already exist.
  384. Args:
  385. name: Ref name
  386. ref: Ref value
  387. committer: Optional committer name/email
  388. timestamp: Optional timestamp
  389. timezone: Optional timezone
  390. message: Optional message for reflog
  391. """
  392. raise NotImplementedError(self.add_if_new)
  393. def __setitem__(self, name: Ref, ref: ObjectID) -> None:
  394. """Set a reference name to point to the given SHA1.
  395. This method follows all symbolic references if applicable for the
  396. subclass.
  397. Note: This method unconditionally overwrites the contents of a
  398. reference. To update atomically only if the reference has not
  399. changed, use set_if_equals().
  400. Args:
  401. name: The refname to set.
  402. ref: The new sha the refname will refer to.
  403. """
  404. if not (valid_hexsha(ref) or ref.startswith(SYMREF)):
  405. raise ValueError(f"{ref!r} must be a valid sha (40 chars) or a symref")
  406. self.set_if_equals(name, None, ref)
  407. def remove_if_equals(
  408. self,
  409. name: Ref,
  410. old_ref: ObjectID | None,
  411. committer: bytes | None = None,
  412. timestamp: int | None = None,
  413. timezone: int | None = None,
  414. message: bytes | None = None,
  415. ) -> bool:
  416. """Remove a refname only if it currently equals old_ref.
  417. This method does not follow symbolic references, even if applicable for
  418. the subclass. It can be used to perform an atomic compare-and-delete
  419. operation.
  420. Args:
  421. name: The refname to delete.
  422. old_ref: The old sha the refname must refer to, or None to
  423. delete unconditionally.
  424. committer: Optional committer name/email
  425. timestamp: Optional timestamp
  426. timezone: Optional timezone
  427. message: Message for reflog
  428. Returns: True if the delete was successful, False otherwise.
  429. """
  430. raise NotImplementedError(self.remove_if_equals)
  431. def __delitem__(self, name: Ref) -> None:
  432. """Remove a refname.
  433. This method does not follow symbolic references, even if applicable for
  434. the subclass.
  435. Note: This method unconditionally deletes the contents of a reference.
  436. To delete atomically only if the reference has not changed, use
  437. remove_if_equals().
  438. Args:
  439. name: The refname to delete.
  440. """
  441. self.remove_if_equals(name, None)
  442. def get_symrefs(self) -> dict[Ref, Ref]:
  443. """Get a dict with all symrefs in this container.
  444. Returns: Dictionary mapping source ref to target ref
  445. """
  446. ret: dict[Ref, Ref] = {}
  447. for src in self.allkeys():
  448. try:
  449. ref_value = self.read_ref(src)
  450. assert ref_value is not None
  451. dst = parse_symref_value(ref_value)
  452. except ValueError:
  453. pass
  454. else:
  455. ret[src] = Ref(dst)
  456. return ret
  457. def pack_refs(self, all: bool = False) -> None:
  458. """Pack loose refs into packed-refs file.
  459. Args:
  460. all: If True, pack all refs. If False, only pack tags.
  461. """
  462. raise NotImplementedError(self.pack_refs)
  463. class DictRefsContainer(RefsContainer):
  464. """RefsContainer backed by a simple dict.
  465. This container does not support symbolic or packed references and is not
  466. threadsafe.
  467. """
  468. def __init__(
  469. self,
  470. refs: dict[Ref, bytes],
  471. logger: Callable[
  472. [
  473. bytes,
  474. bytes | None,
  475. bytes | None,
  476. bytes | None,
  477. int | None,
  478. int | None,
  479. bytes | None,
  480. ],
  481. None,
  482. ]
  483. | None = None,
  484. ) -> None:
  485. """Initialize DictRefsContainer with refs dictionary and optional logger."""
  486. super().__init__(logger=logger)
  487. self._refs = refs
  488. self._peeled: dict[Ref, ObjectID] = {}
  489. self._watchers: set[Any] = set()
  490. def allkeys(self) -> set[Ref]:
  491. """Return all reference keys."""
  492. return set(self._refs.keys())
  493. def read_loose_ref(self, name: Ref) -> bytes | None:
  494. """Read a loose reference."""
  495. return self._refs.get(name, None)
  496. def get_packed_refs(self) -> dict[Ref, ObjectID]:
  497. """Get packed references."""
  498. return {}
  499. def _notify(self, ref: bytes, newsha: bytes | None) -> None:
  500. for watcher in self._watchers:
  501. watcher._notify((ref, newsha))
  502. def set_symbolic_ref(
  503. self,
  504. name: Ref,
  505. other: Ref,
  506. committer: bytes | None = None,
  507. timestamp: int | None = None,
  508. timezone: int | None = None,
  509. message: bytes | None = None,
  510. ) -> None:
  511. """Make a ref point at another ref.
  512. Args:
  513. name: Name of the ref to set
  514. other: Name of the ref to point at
  515. committer: Optional committer name for reflog
  516. timestamp: Optional timestamp for reflog
  517. timezone: Optional timezone for reflog
  518. message: Optional message for reflog
  519. """
  520. old = self.follow(name)[-1]
  521. new = SYMREF + other
  522. self._refs[name] = new
  523. self._notify(name, new)
  524. self._log(
  525. name,
  526. old,
  527. new,
  528. committer=committer,
  529. timestamp=timestamp,
  530. timezone=timezone,
  531. message=message,
  532. )
  533. def set_if_equals(
  534. self,
  535. name: Ref,
  536. old_ref: ObjectID | None,
  537. new_ref: ObjectID,
  538. committer: bytes | None = None,
  539. timestamp: int | None = None,
  540. timezone: int | None = None,
  541. message: bytes | None = None,
  542. ) -> bool:
  543. """Set a refname to new_ref only if it currently equals old_ref.
  544. This method follows all symbolic references, and can be used to perform
  545. an atomic compare-and-swap operation.
  546. Args:
  547. name: The refname to set.
  548. old_ref: The old sha the refname must refer to, or None to set
  549. unconditionally.
  550. new_ref: The new sha the refname will refer to.
  551. committer: Optional committer name for reflog
  552. timestamp: Optional timestamp for reflog
  553. timezone: Optional timezone for reflog
  554. message: Optional message for reflog
  555. Returns:
  556. True if the set was successful, False otherwise.
  557. """
  558. if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
  559. return False
  560. # Only update the specific ref requested, not the whole chain
  561. self._check_refname(name)
  562. old = self._refs.get(name)
  563. self._refs[name] = new_ref
  564. self._notify(name, new_ref)
  565. self._log(
  566. name,
  567. old,
  568. new_ref,
  569. committer=committer,
  570. timestamp=timestamp,
  571. timezone=timezone,
  572. message=message,
  573. )
  574. return True
  575. def add_if_new(
  576. self,
  577. name: Ref,
  578. ref: ObjectID,
  579. committer: bytes | None = None,
  580. timestamp: int | None = None,
  581. timezone: int | None = None,
  582. message: bytes | None = None,
  583. ) -> bool:
  584. """Add a new reference only if it does not already exist.
  585. Args:
  586. name: Ref name
  587. ref: Ref value
  588. committer: Optional committer name for reflog
  589. timestamp: Optional timestamp for reflog
  590. timezone: Optional timezone for reflog
  591. message: Optional message for reflog
  592. Returns:
  593. True if the add was successful, False otherwise.
  594. """
  595. if name in self._refs:
  596. return False
  597. self._refs[name] = ref
  598. self._notify(name, ref)
  599. self._log(
  600. name,
  601. None,
  602. ref,
  603. committer=committer,
  604. timestamp=timestamp,
  605. timezone=timezone,
  606. message=message,
  607. )
  608. return True
  609. def remove_if_equals(
  610. self,
  611. name: Ref,
  612. old_ref: ObjectID | None,
  613. committer: bytes | None = None,
  614. timestamp: int | None = None,
  615. timezone: int | None = None,
  616. message: bytes | None = None,
  617. ) -> bool:
  618. """Remove a refname only if it currently equals old_ref.
  619. This method does not follow symbolic references. It can be used to
  620. perform an atomic compare-and-delete operation.
  621. Args:
  622. name: The refname to delete.
  623. old_ref: The old sha the refname must refer to, or None to
  624. delete unconditionally.
  625. committer: Optional committer name for reflog
  626. timestamp: Optional timestamp for reflog
  627. timezone: Optional timezone for reflog
  628. message: Optional message for reflog
  629. Returns:
  630. True if the delete was successful, False otherwise.
  631. """
  632. if old_ref is not None and self._refs.get(name, ZERO_SHA) != old_ref:
  633. return False
  634. try:
  635. old = self._refs.pop(name)
  636. except KeyError:
  637. pass
  638. else:
  639. self._notify(name, None)
  640. self._log(
  641. name,
  642. old,
  643. None,
  644. committer=committer,
  645. timestamp=timestamp,
  646. timezone=timezone,
  647. message=message,
  648. )
  649. return True
  650. def get_peeled(self, name: Ref) -> ObjectID | None:
  651. """Get peeled version of a reference."""
  652. return self._peeled.get(name)
  653. def _update(self, refs: Mapping[Ref, ObjectID]) -> None:
  654. """Update multiple refs; intended only for testing."""
  655. # TODO(dborowitz): replace this with a public function that uses
  656. # set_if_equal.
  657. for ref, sha in refs.items():
  658. self.set_if_equals(ref, None, sha)
  659. def _update_peeled(self, peeled: Mapping[Ref, ObjectID]) -> None:
  660. """Update cached peeled refs; intended only for testing."""
  661. self._peeled.update(peeled)
  662. class DiskRefsContainer(RefsContainer):
  663. """Refs container that reads refs from disk."""
  664. def __init__(
  665. self,
  666. path: str | bytes | os.PathLike[str],
  667. worktree_path: str | bytes | os.PathLike[str] | None = None,
  668. logger: Callable[
  669. [bytes, bytes, bytes, bytes | None, int | None, int | None, bytes], None
  670. ]
  671. | None = None,
  672. ) -> None:
  673. """Initialize DiskRefsContainer."""
  674. super().__init__(logger=logger)
  675. # Convert path-like objects to strings, then to bytes for Git compatibility
  676. self.path = os.fsencode(os.fspath(path))
  677. if worktree_path is None:
  678. self.worktree_path = self.path
  679. else:
  680. self.worktree_path = os.fsencode(os.fspath(worktree_path))
  681. self._packed_refs: dict[Ref, ObjectID] | None = None
  682. self._peeled_refs: dict[Ref, ObjectID] | None = None
  683. def __repr__(self) -> str:
  684. """Return string representation of DiskRefsContainer."""
  685. return f"{self.__class__.__name__}({self.path!r})"
  686. def _iter_dir(
  687. self,
  688. path: bytes,
  689. base: bytes,
  690. dir_filter: Callable[[bytes], bool] | None = None,
  691. ) -> Iterator[Ref]:
  692. refspath = os.path.join(path, base.rstrip(b"/"))
  693. prefix_len = len(os.path.join(path, b""))
  694. for root, dirs, files in os.walk(refspath):
  695. directory = root[prefix_len:]
  696. if os.path.sep != "/":
  697. directory = directory.replace(os.fsencode(os.path.sep), b"/")
  698. if dir_filter is not None:
  699. dirs[:] = [
  700. d for d in dirs if dir_filter(b"/".join([directory, d, b""]))
  701. ]
  702. for filename in files:
  703. refname = b"/".join([directory, filename])
  704. if check_ref_format(Ref(refname)):
  705. yield Ref(refname)
  706. def _iter_loose_refs(self, base: bytes = b"refs/") -> Iterator[Ref]:
  707. base = base.rstrip(b"/") + b"/"
  708. search_paths: list[tuple[bytes, Callable[[bytes], bool] | None]] = []
  709. if base != b"refs/":
  710. path = self.worktree_path if is_per_worktree_ref(base) else self.path
  711. search_paths.append((path, None))
  712. elif self.worktree_path == self.path:
  713. # Iterate through all the refs from the main worktree
  714. search_paths.append((self.path, None))
  715. else:
  716. # Iterate through all the shared refs from the commondir, excluding per-worktree refs
  717. search_paths.append((self.path, lambda r: not is_per_worktree_ref(r)))
  718. # Iterate through all the per-worktree refs from the worktree's gitdir
  719. search_paths.append((self.worktree_path, is_per_worktree_ref))
  720. for path, dir_filter in search_paths:
  721. yield from self._iter_dir(path, base, dir_filter=dir_filter)
  722. def subkeys(self, base: Ref) -> set[Ref]:
  723. """Return subkeys under a given base reference path."""
  724. subkeys: set[Ref] = set()
  725. for key in self._iter_loose_refs(base):
  726. if key.startswith(base):
  727. subkeys.add(Ref(key[len(base) :].strip(b"/")))
  728. for key in self.get_packed_refs():
  729. if key.startswith(base):
  730. subkeys.add(Ref(key[len(base) :].strip(b"/")))
  731. return subkeys
  732. def allkeys(self) -> set[Ref]:
  733. """Return all reference keys."""
  734. allkeys: set[Ref] = set()
  735. if os.path.exists(self.refpath(HEADREF)):
  736. allkeys.add(Ref(HEADREF))
  737. allkeys.update(self._iter_loose_refs())
  738. allkeys.update(self.get_packed_refs())
  739. return allkeys
  740. def refpath(self, name: bytes) -> bytes:
  741. """Return the disk path of a ref."""
  742. path = name
  743. if os.path.sep != "/":
  744. path = path.replace(b"/", os.fsencode(os.path.sep))
  745. root_dir = self.worktree_path if is_per_worktree_ref(name) else self.path
  746. return os.path.join(root_dir, path)
  747. def get_packed_refs(self) -> dict[Ref, ObjectID]:
  748. """Get contents of the packed-refs file.
  749. Returns: Dictionary mapping ref names to SHA1s
  750. Note: Will return an empty dictionary when no packed-refs file is
  751. present.
  752. """
  753. # TODO: invalidate the cache on repacking
  754. if self._packed_refs is None:
  755. # set both to empty because we want _peeled_refs to be
  756. # None if and only if _packed_refs is also None.
  757. self._packed_refs = {}
  758. self._peeled_refs = {}
  759. path = os.path.join(self.path, b"packed-refs")
  760. try:
  761. f = GitFile(path, "rb")
  762. except FileNotFoundError:
  763. return {}
  764. with f:
  765. first_line = next(iter(f)).rstrip()
  766. if first_line.startswith(b"# pack-refs") and b" peeled" in first_line:
  767. for sha, name, peeled in read_packed_refs_with_peeled(f):
  768. self._packed_refs[name] = sha
  769. if peeled:
  770. self._peeled_refs[name] = peeled
  771. else:
  772. f.seek(0)
  773. for sha, name in read_packed_refs(f):
  774. self._packed_refs[name] = sha
  775. return self._packed_refs
  776. def add_packed_refs(self, new_refs: Mapping[Ref, ObjectID | None]) -> None:
  777. """Add the given refs as packed refs.
  778. Args:
  779. new_refs: A mapping of ref names to targets; if a target is None that
  780. means remove the ref
  781. """
  782. if not new_refs:
  783. return
  784. path = os.path.join(self.path, b"packed-refs")
  785. with GitFile(path, "wb") as f:
  786. # reread cached refs from disk, while holding the lock
  787. packed_refs = self.get_packed_refs().copy()
  788. for ref, target in new_refs.items():
  789. # sanity check
  790. if ref == HEADREF:
  791. raise ValueError("cannot pack HEAD")
  792. # remove any loose refs pointing to this one -- please
  793. # note that this bypasses remove_if_equals as we don't
  794. # want to affect packed refs in here
  795. with suppress(OSError):
  796. os.remove(self.refpath(ref))
  797. if target is not None:
  798. packed_refs[ref] = target
  799. else:
  800. packed_refs.pop(ref, None)
  801. write_packed_refs(f, packed_refs, self._peeled_refs)
  802. self._packed_refs = packed_refs
  803. def get_peeled(self, name: Ref) -> ObjectID | None:
  804. """Return the cached peeled value of a ref, if available.
  805. Args:
  806. name: Name of the ref to peel
  807. Returns: The peeled value of the ref. If the ref is known not point to
  808. a tag, this will be the SHA the ref refers to. If the ref may point
  809. to a tag, but no cached information is available, None is returned.
  810. """
  811. self.get_packed_refs()
  812. if (
  813. self._peeled_refs is None
  814. or self._packed_refs is None
  815. or name not in self._packed_refs
  816. ):
  817. # No cache: no peeled refs were read, or this ref is loose
  818. return None
  819. if name in self._peeled_refs:
  820. return self._peeled_refs[name]
  821. else:
  822. # Known not peelable
  823. return self[name]
  824. def read_loose_ref(self, name: Ref) -> bytes | None:
  825. """Read a reference file and return its contents.
  826. If the reference file a symbolic reference, only read the first line of
  827. the file. Otherwise, only read the first 40 bytes.
  828. Args:
  829. name: the refname to read, relative to refpath
  830. Returns: The contents of the ref file, or None if the file does not
  831. exist.
  832. Raises:
  833. IOError: if any other error occurs
  834. """
  835. filename = self.refpath(name)
  836. try:
  837. with GitFile(filename, "rb") as f:
  838. header = f.read(len(SYMREF))
  839. if header == SYMREF:
  840. # Read only the first line
  841. return header + next(iter(f)).rstrip(b"\r\n")
  842. else:
  843. # Read only the first 40 bytes
  844. return header + f.read(40 - len(SYMREF))
  845. except (OSError, UnicodeError):
  846. # don't assume anything specific about the error; in
  847. # particular, invalid or forbidden paths can raise weird
  848. # errors depending on the specific operating system
  849. return None
  850. def _remove_packed_ref(self, name: Ref) -> None:
  851. if self._packed_refs is None:
  852. return
  853. filename = os.path.join(self.path, b"packed-refs")
  854. # reread cached refs from disk, while holding the lock
  855. f = GitFile(filename, "wb")
  856. try:
  857. self._packed_refs = None
  858. self.get_packed_refs()
  859. if self._packed_refs is None or name not in self._packed_refs:
  860. f.abort()
  861. return
  862. del self._packed_refs[name]
  863. if self._peeled_refs is not None:
  864. with suppress(KeyError):
  865. del self._peeled_refs[name]
  866. write_packed_refs(f, self._packed_refs, self._peeled_refs)
  867. f.close()
  868. except BaseException:
  869. f.abort()
  870. raise
  871. def set_symbolic_ref(
  872. self,
  873. name: Ref,
  874. other: Ref,
  875. committer: bytes | None = None,
  876. timestamp: int | None = None,
  877. timezone: int | None = None,
  878. message: bytes | None = None,
  879. ) -> None:
  880. """Make a ref point at another ref.
  881. Args:
  882. name: Name of the ref to set
  883. other: Name of the ref to point at
  884. committer: Optional committer name
  885. timestamp: Optional timestamp
  886. timezone: Optional timezone
  887. message: Optional message to describe the change
  888. """
  889. self._check_refname(name)
  890. self._check_refname(other)
  891. filename = self.refpath(name)
  892. f = GitFile(filename, "wb")
  893. try:
  894. f.write(SYMREF + other + b"\n")
  895. sha = self.follow(name)[-1]
  896. self._log(
  897. name,
  898. sha,
  899. sha,
  900. committer=committer,
  901. timestamp=timestamp,
  902. timezone=timezone,
  903. message=message,
  904. )
  905. except BaseException:
  906. f.abort()
  907. raise
  908. else:
  909. f.close()
def set_if_equals(
    self,
    name: Ref,
    old_ref: ObjectID | None,
    new_ref: ObjectID,
    committer: bytes | None = None,
    timestamp: int | None = None,
    timezone: int | None = None,
    message: bytes | None = None,
) -> bool:
    """Set a refname to new_ref only if it currently equals old_ref.

    This method follows all symbolic references, and can be used to perform
    an atomic compare-and-swap operation: the comparison and the write both
    happen while holding the lock on the resolved ref's file.

    Args:
        name: The refname to set.
        old_ref: The old sha the refname must refer to, or None to set
            unconditionally. A ref that does not exist compares equal to
            ZERO_SHA.
        new_ref: The new sha the refname will refer to.
        committer: Optional committer name
        timestamp: Optional timestamp
        timezone: Optional timezone
        message: Set message for reflog
    Returns: True if the set was successful (including when the ref already
        holds new_ref, in which case nothing is written or logged), False if
        the current value did not match old_ref.
    Raises:
        NotADirectoryError: if an ancestor path of the resolved ref exists as
            a packed ref, so the ref's directory cannot be created.
        OSError: if reading or writing the ref file fails.
    """
    self._check_refname(name)
    try:
        realnames, _ = self.follow(name)
        realname = realnames[-1]
    except (KeyError, IndexError, SymrefLoop):
        # The chain could not be resolved (follow() raised or returned no
        # names): operate on name itself.
        realname = name
    filename = self.refpath(realname)
    # make sure none of the ancestor folders is in packed refs
    probe_ref = Ref(os.path.dirname(realname))
    packed_refs = self.get_packed_refs()
    while probe_ref:
        if packed_refs.get(probe_ref, None) is not None:
            raise NotADirectoryError(filename)
        probe_ref = Ref(os.path.dirname(probe_ref))
    ensure_dir_exists(os.path.dirname(filename))
    with GitFile(filename, "wb") as f:
        if old_ref is not None:
            try:
                # read again while holding the lock to handle race conditions
                orig_ref = self.read_loose_ref(realname)
                if orig_ref is None:
                    # Not loose: fall back to packed refs; missing means ZERO_SHA.
                    orig_ref = self.get_packed_refs().get(realname, ZERO_SHA)
                if orig_ref != old_ref:
                    f.abort()
                    return False
            except OSError:
                f.abort()
                raise
        # Check if ref already has the desired value while holding the lock
        # This avoids fsync when ref is unchanged but still detects lock conflicts
        current_ref = self.read_loose_ref(realname)
        if current_ref is None:
            # NOTE(review): uses the packed_refs snapshot taken before the lock.
            current_ref = packed_refs.get(realname, None)
        if current_ref is not None and current_ref == new_ref:
            # Ref already has desired value, abort write to avoid fsync
            f.abort()
            return True
        try:
            f.write(new_ref + b"\n")
        except OSError:
            f.abort()
            raise
        # Logged under the resolved name, before the lock file is committed.
        self._log(
            realname,
            old_ref,
            new_ref,
            committer=committer,
            timestamp=timestamp,
            timezone=timezone,
            message=message,
        )
    return True
  986. def add_if_new(
  987. self,
  988. name: Ref,
  989. ref: ObjectID,
  990. committer: bytes | None = None,
  991. timestamp: int | None = None,
  992. timezone: int | None = None,
  993. message: bytes | None = None,
  994. ) -> bool:
  995. """Add a new reference only if it does not already exist.
  996. This method follows symrefs, and only ensures that the last ref in the
  997. chain does not exist.
  998. Args:
  999. name: The refname to set.
  1000. ref: The new sha the refname will refer to.
  1001. committer: Optional committer name
  1002. timestamp: Optional timestamp
  1003. timezone: Optional timezone
  1004. message: Optional message for reflog
  1005. Returns: True if the add was successful, False otherwise.
  1006. """
  1007. try:
  1008. realnames, contents = self.follow(name)
  1009. if contents is not None:
  1010. return False
  1011. realname = realnames[-1]
  1012. except (KeyError, IndexError):
  1013. realname = name
  1014. self._check_refname(realname)
  1015. filename = self.refpath(realname)
  1016. ensure_dir_exists(os.path.dirname(filename))
  1017. with GitFile(filename, "wb") as f:
  1018. if os.path.exists(filename) or name in self.get_packed_refs():
  1019. f.abort()
  1020. return False
  1021. try:
  1022. f.write(ref + b"\n")
  1023. except OSError:
  1024. f.abort()
  1025. raise
  1026. else:
  1027. self._log(
  1028. name,
  1029. None,
  1030. ref,
  1031. committer=committer,
  1032. timestamp=timestamp,
  1033. timezone=timezone,
  1034. message=message,
  1035. )
  1036. return True
def remove_if_equals(
    self,
    name: Ref,
    old_ref: ObjectID | None,
    committer: bytes | None = None,
    timestamp: int | None = None,
    timezone: int | None = None,
    message: bytes | None = None,
) -> bool:
    """Remove a refname only if it currently equals old_ref.

    This method does not follow symbolic references. It can be used to
    perform an atomic compare-and-delete operation: the ref's lock is held
    while its value is compared and while the loose file and any packed-refs
    entry are removed. The lock file itself is never committed.

    Args:
        name: The refname to delete.
        old_ref: The old sha the refname must refer to, or None to
            delete unconditionally. A ref that does not exist compares
            equal to ZERO_SHA.
        committer: Optional committer name
        timestamp: Optional timestamp
        timezone: Optional timezone
        message: Optional message
    Returns: True if the delete was successful (also when old_ref is None and
        the ref did not exist), False if the current value did not match
        old_ref.
    """
    self._check_refname(name)
    filename = self.refpath(name)
    ensure_dir_exists(os.path.dirname(filename))
    # Taking the ref's lock; nothing is ever written to it.
    f = GitFile(filename, "wb")
    try:
        if old_ref is not None:
            orig_ref = self.read_loose_ref(name)
            if orig_ref is None:
                orig_ref = self.get_packed_refs().get(name)
                if orig_ref is None:
                    orig_ref = ZERO_SHA
            if orig_ref != old_ref:
                return False
        # remove the reference file itself
        try:
            found = os.path.lexists(filename)
        except OSError:
            # may only be packed, or otherwise unstorable
            found = False
        if found:
            os.remove(filename)
        self._remove_packed_ref(name)
        self._log(
            name,
            old_ref,
            None,
            committer=committer,
            timestamp=timestamp,
            timezone=timezone,
            message=message,
        )
    finally:
        # never write, we just wanted the lock
        f.abort()
    # outside of the lock, clean-up any parent directory that might now
    # be empty. this ensures that re-creating a reference of the same
    # name of what was previously a directory works as expected
    # Walk upwards until there is no "/" left, "refs" is reached, or a
    # directory cannot be removed.
    parent = name
    while True:
        try:
            parent_bytes, _ = parent.rsplit(b"/", 1)
            parent = Ref(parent_bytes)
        except ValueError:
            break
        if parent == b"refs":
            break
        parent_filename = self.refpath(parent)
        try:
            os.rmdir(parent_filename)
        except OSError:
            # this can be caused by the parent directory being
            # removed by another process, being not empty, etc.
            # in any case, this is non fatal because we already
            # removed the reference, just ignore it
            break
    return True
  1115. def pack_refs(self, all: bool = False) -> None:
  1116. """Pack loose refs into packed-refs file.
  1117. Args:
  1118. all: If True, pack all refs. If False, only pack tags.
  1119. """
  1120. refs_to_pack: dict[Ref, ObjectID | None] = {}
  1121. for ref in self.allkeys():
  1122. if ref == HEADREF:
  1123. # Never pack HEAD
  1124. continue
  1125. if all or ref.startswith(LOCAL_TAG_PREFIX):
  1126. try:
  1127. sha = self[ref]
  1128. if sha:
  1129. refs_to_pack[ref] = sha
  1130. except KeyError:
  1131. # Broken ref, skip it
  1132. pass
  1133. if refs_to_pack:
  1134. self.add_packed_refs(refs_to_pack)
  1135. def _split_ref_line(line: bytes) -> tuple[ObjectID, Ref]:
  1136. """Split a single ref line into a tuple of SHA1 and name."""
  1137. fields = line.rstrip(b"\n\r").split(b" ")
  1138. if len(fields) != 2:
  1139. raise PackedRefsException(f"invalid ref line {line!r}")
  1140. sha, name = fields
  1141. if not valid_hexsha(sha):
  1142. raise PackedRefsException(f"Invalid hex sha {sha!r}")
  1143. if not check_ref_format(Ref(name)):
  1144. raise PackedRefsException(f"invalid ref name {name!r}")
  1145. return (ObjectID(sha), Ref(name))
  1146. def read_packed_refs(f: IO[bytes]) -> Iterator[tuple[ObjectID, Ref]]:
  1147. """Read a packed refs file.
  1148. Args:
  1149. f: file-like object to read from
  1150. Returns: Iterator over tuples with SHA1s and ref names.
  1151. """
  1152. for line in f:
  1153. if line.startswith(b"#"):
  1154. # Comment
  1155. continue
  1156. if line.startswith(b"^"):
  1157. raise PackedRefsException("found peeled ref in packed-refs without peeled")
  1158. yield _split_ref_line(line)
  1159. def read_packed_refs_with_peeled(
  1160. f: IO[bytes],
  1161. ) -> Iterator[tuple[ObjectID, Ref, ObjectID | None]]:
  1162. """Read a packed refs file including peeled refs.
  1163. Assumes the "# pack-refs with: peeled" line was already read. Yields tuples
  1164. with ref names, SHA1s, and peeled SHA1s (or None).
  1165. Args:
  1166. f: file-like object to read from, seek'ed to the second line
  1167. """
  1168. last = None
  1169. for line in f:
  1170. if line.startswith(b"#"):
  1171. continue
  1172. line = line.rstrip(b"\r\n")
  1173. if line.startswith(b"^"):
  1174. if not last:
  1175. raise PackedRefsException("unexpected peeled ref line")
  1176. if not valid_hexsha(line[1:]):
  1177. raise PackedRefsException(f"Invalid hex sha {line[1:]!r}")
  1178. sha, name = _split_ref_line(last)
  1179. last = None
  1180. yield (sha, name, ObjectID(line[1:]))
  1181. else:
  1182. if last:
  1183. sha, name = _split_ref_line(last)
  1184. yield (sha, name, None)
  1185. last = line
  1186. if last:
  1187. sha, name = _split_ref_line(last)
  1188. yield (sha, name, None)
  1189. def write_packed_refs(
  1190. f: IO[bytes],
  1191. packed_refs: Mapping[Ref, ObjectID],
  1192. peeled_refs: Mapping[Ref, ObjectID] | None = None,
  1193. ) -> None:
  1194. """Write a packed refs file.
  1195. Args:
  1196. f: empty file-like object to write to
  1197. packed_refs: dict of refname to sha of packed refs to write
  1198. peeled_refs: dict of refname to peeled value of sha
  1199. """
  1200. if peeled_refs is None:
  1201. peeled_refs = {}
  1202. else:
  1203. f.write(b"# pack-refs with: peeled\n")
  1204. for refname in sorted(packed_refs.keys()):
  1205. f.write(git_line(packed_refs[refname], refname))
  1206. if refname in peeled_refs:
  1207. f.write(b"^" + peeled_refs[refname] + b"\n")
  1208. def read_info_refs(f: BinaryIO) -> dict[Ref, ObjectID]:
  1209. """Read info/refs file.
  1210. Args:
  1211. f: File-like object to read from
  1212. Returns:
  1213. Dictionary mapping ref names to SHA1s
  1214. """
  1215. ret: dict[Ref, ObjectID] = {}
  1216. for line in f.readlines():
  1217. (sha, name) = line.rstrip(b"\r\n").split(b"\t", 1)
  1218. ret[Ref(name)] = ObjectID(sha)
  1219. return ret
  1220. def is_local_branch(x: bytes) -> bool:
  1221. """Check if a ref name is a local branch."""
  1222. return x.startswith(LOCAL_BRANCH_PREFIX)
  1223. def local_branch_name(name: bytes) -> Ref:
  1224. """Build a full branch ref from a short name.
  1225. Args:
  1226. name: Short branch name (e.g., b"master") or full ref
  1227. Returns:
  1228. Full branch ref name (e.g., b"refs/heads/master")
  1229. Examples:
  1230. >>> local_branch_name(b"master")
  1231. b'refs/heads/master'
  1232. >>> local_branch_name(b"refs/heads/master")
  1233. b'refs/heads/master'
  1234. """
  1235. if name.startswith(LOCAL_BRANCH_PREFIX):
  1236. return Ref(name)
  1237. return Ref(LOCAL_BRANCH_PREFIX + name)
  1238. def local_tag_name(name: bytes) -> Ref:
  1239. """Build a full tag ref from a short name.
  1240. Args:
  1241. name: Short tag name (e.g., b"v1.0") or full ref
  1242. Returns:
  1243. Full tag ref name (e.g., b"refs/tags/v1.0")
  1244. Examples:
  1245. >>> local_tag_name(b"v1.0")
  1246. b'refs/tags/v1.0'
  1247. >>> local_tag_name(b"refs/tags/v1.0")
  1248. b'refs/tags/v1.0'
  1249. """
  1250. if name.startswith(LOCAL_TAG_PREFIX):
  1251. return Ref(name)
  1252. return Ref(LOCAL_TAG_PREFIX + name)
  1253. def local_replace_name(name: bytes) -> Ref:
  1254. """Build a full replace ref from a short name.
  1255. Args:
  1256. name: Short replace name (object SHA) or full ref
  1257. Returns:
  1258. Full replace ref name (e.g., b"refs/replace/<sha>")
  1259. Examples:
  1260. >>> local_replace_name(b"abc123")
  1261. b'refs/replace/abc123'
  1262. >>> local_replace_name(b"refs/replace/abc123")
  1263. b'refs/replace/abc123'
  1264. """
  1265. if name.startswith(LOCAL_REPLACE_PREFIX):
  1266. return Ref(name)
  1267. return Ref(LOCAL_REPLACE_PREFIX + name)
  1268. def extract_branch_name(ref: bytes) -> bytes:
  1269. """Extract branch name from a full branch ref.
  1270. Args:
  1271. ref: Full branch ref (e.g., b"refs/heads/master")
  1272. Returns:
  1273. Short branch name (e.g., b"master")
  1274. Raises:
  1275. ValueError: If ref is not a local branch
  1276. Examples:
  1277. >>> extract_branch_name(b"refs/heads/master")
  1278. b'master'
  1279. >>> extract_branch_name(b"refs/heads/feature/foo")
  1280. b'feature/foo'
  1281. """
  1282. if not ref.startswith(LOCAL_BRANCH_PREFIX):
  1283. raise ValueError(f"Not a local branch ref: {ref!r}")
  1284. return ref[len(LOCAL_BRANCH_PREFIX) :]
  1285. def extract_tag_name(ref: bytes) -> bytes:
  1286. """Extract tag name from a full tag ref.
  1287. Args:
  1288. ref: Full tag ref (e.g., b"refs/tags/v1.0")
  1289. Returns:
  1290. Short tag name (e.g., b"v1.0")
  1291. Raises:
  1292. ValueError: If ref is not a local tag
  1293. Examples:
  1294. >>> extract_tag_name(b"refs/tags/v1.0")
  1295. b'v1.0'
  1296. """
  1297. if not ref.startswith(LOCAL_TAG_PREFIX):
  1298. raise ValueError(f"Not a local tag ref: {ref!r}")
  1299. return ref[len(LOCAL_TAG_PREFIX) :]
  1300. def shorten_ref_name(ref: bytes) -> bytes:
  1301. """Convert a full ref name to its short form.
  1302. Args:
  1303. ref: Full ref name (e.g., b"refs/heads/master")
  1304. Returns:
  1305. Short ref name (e.g., b"master")
  1306. Examples:
  1307. >>> shorten_ref_name(b"refs/heads/master")
  1308. b'master'
  1309. >>> shorten_ref_name(b"refs/remotes/origin/main")
  1310. b'origin/main'
  1311. >>> shorten_ref_name(b"refs/tags/v1.0")
  1312. b'v1.0'
  1313. >>> shorten_ref_name(b"HEAD")
  1314. b'HEAD'
  1315. """
  1316. if ref.startswith(LOCAL_BRANCH_PREFIX):
  1317. return ref[len(LOCAL_BRANCH_PREFIX) :]
  1318. elif ref.startswith(LOCAL_REMOTE_PREFIX):
  1319. return ref[len(LOCAL_REMOTE_PREFIX) :]
  1320. elif ref.startswith(LOCAL_TAG_PREFIX):
  1321. return ref[len(LOCAL_TAG_PREFIX) :]
  1322. return ref
  1323. def _set_origin_head(
  1324. refs: RefsContainer, origin: bytes, origin_head: bytes | None
  1325. ) -> None:
  1326. # set refs/remotes/origin/HEAD
  1327. origin_base = b"refs/remotes/" + origin + b"/"
  1328. if origin_head and origin_head.startswith(LOCAL_BRANCH_PREFIX):
  1329. origin_ref = Ref(origin_base + HEADREF)
  1330. target_ref = Ref(origin_base + extract_branch_name(origin_head))
  1331. if target_ref in refs:
  1332. refs.set_symbolic_ref(origin_ref, target_ref)
  1333. def _set_default_branch(
  1334. refs: RefsContainer,
  1335. origin: bytes,
  1336. origin_head: bytes | None,
  1337. branch: bytes | None,
  1338. ref_message: bytes | None,
  1339. ) -> bytes:
  1340. """Set the default branch."""
  1341. origin_base = b"refs/remotes/" + origin + b"/"
  1342. if branch:
  1343. origin_ref = Ref(origin_base + branch)
  1344. if origin_ref in refs:
  1345. local_ref = Ref(local_branch_name(branch))
  1346. refs.add_if_new(local_ref, refs[origin_ref], ref_message)
  1347. head_ref = local_ref
  1348. elif Ref(local_tag_name(branch)) in refs:
  1349. head_ref = Ref(local_tag_name(branch))
  1350. else:
  1351. raise ValueError(f"{os.fsencode(branch)!r} is not a valid branch or tag")
  1352. elif origin_head:
  1353. head_ref = Ref(origin_head)
  1354. if origin_head.startswith(LOCAL_BRANCH_PREFIX):
  1355. origin_ref = Ref(origin_base + extract_branch_name(origin_head))
  1356. else:
  1357. origin_ref = Ref(origin_head)
  1358. try:
  1359. refs.add_if_new(head_ref, refs[origin_ref], ref_message)
  1360. except KeyError:
  1361. pass
  1362. else:
  1363. raise ValueError("neither origin_head nor branch are provided")
  1364. return head_ref
  1365. def _set_head(
  1366. refs: RefsContainer, head_ref: bytes, ref_message: bytes | None
  1367. ) -> ObjectID | None:
  1368. if head_ref.startswith(LOCAL_TAG_PREFIX):
  1369. # detach HEAD at specified tag
  1370. head = refs[Ref(head_ref)]
  1371. if isinstance(head, Tag):
  1372. _cls, obj = head.object
  1373. head = obj.get_object(obj).id
  1374. del refs[HEADREF]
  1375. refs.set_if_equals(HEADREF, None, head, message=ref_message)
  1376. else:
  1377. # set HEAD to specific branch
  1378. try:
  1379. head = refs[Ref(head_ref)]
  1380. refs.set_symbolic_ref(HEADREF, Ref(head_ref))
  1381. refs.set_if_equals(HEADREF, None, head, message=ref_message)
  1382. except KeyError:
  1383. head = None
  1384. return head
  1385. def _import_remote_refs(
  1386. refs_container: RefsContainer,
  1387. remote_name: str,
  1388. refs: Mapping[Ref, ObjectID | None],
  1389. message: bytes | None = None,
  1390. prune: bool = False,
  1391. prune_tags: bool = False,
  1392. ) -> None:
  1393. from .protocol import PEELED_TAG_SUFFIX, strip_peeled_refs
  1394. stripped_refs = strip_peeled_refs(refs)
  1395. branches: dict[Ref, ObjectID | None] = {
  1396. Ref(extract_branch_name(n)): v
  1397. for (n, v) in stripped_refs.items()
  1398. if n.startswith(LOCAL_BRANCH_PREFIX)
  1399. }
  1400. refs_container.import_refs(
  1401. Ref(b"refs/remotes/" + remote_name.encode()),
  1402. branches,
  1403. message=message,
  1404. prune=prune,
  1405. )
  1406. tags: dict[Ref, ObjectID | None] = {
  1407. Ref(extract_tag_name(n)): v
  1408. for (n, v) in stripped_refs.items()
  1409. if n.startswith(LOCAL_TAG_PREFIX) and not n.endswith(PEELED_TAG_SUFFIX)
  1410. }
  1411. refs_container.import_refs(
  1412. Ref(LOCAL_TAG_PREFIX), tags, message=message, prune=prune_tags
  1413. )
  1414. class locked_ref:
  1415. """Lock a ref while making modifications.
  1416. Works as a context manager.
  1417. """
  1418. def __init__(self, refs_container: DiskRefsContainer, refname: Ref) -> None:
  1419. """Initialize a locked ref.
  1420. Args:
  1421. refs_container: The DiskRefsContainer to lock the ref in
  1422. refname: The ref name to lock
  1423. """
  1424. self._refs_container = refs_container
  1425. self._refname = refname
  1426. self._file: _GitFile | None = None
  1427. self._realname: Ref | None = None
  1428. self._deleted = False
  1429. def __enter__(self) -> "locked_ref":
  1430. """Enter the context manager and acquire the lock.
  1431. Returns:
  1432. This locked_ref instance
  1433. Raises:
  1434. OSError: If the lock cannot be acquired
  1435. """
  1436. self._refs_container._check_refname(self._refname)
  1437. try:
  1438. realnames, _ = self._refs_container.follow(self._refname)
  1439. self._realname = realnames[-1]
  1440. except (KeyError, IndexError, SymrefLoop):
  1441. self._realname = self._refname
  1442. filename = self._refs_container.refpath(self._realname)
  1443. ensure_dir_exists(os.path.dirname(filename))
  1444. f = GitFile(filename, "wb")
  1445. self._file = f
  1446. return self
  1447. def __exit__(
  1448. self,
  1449. exc_type: type | None,
  1450. exc_value: BaseException | None,
  1451. traceback: types.TracebackType | None,
  1452. ) -> None:
  1453. """Exit the context manager and release the lock.
  1454. Args:
  1455. exc_type: Type of exception if one occurred
  1456. exc_value: Exception instance if one occurred
  1457. traceback: Traceback if an exception occurred
  1458. """
  1459. if self._file:
  1460. if exc_type is not None or self._deleted:
  1461. self._file.abort()
  1462. else:
  1463. self._file.close()
  1464. def get(self) -> bytes | None:
  1465. """Get the current value of the ref."""
  1466. if not self._file:
  1467. raise RuntimeError("locked_ref not in context")
  1468. assert self._realname is not None
  1469. current_ref = self._refs_container.read_loose_ref(self._realname)
  1470. if current_ref is None:
  1471. current_ref = self._refs_container.get_packed_refs().get(
  1472. self._realname, None
  1473. )
  1474. return current_ref
  1475. def ensure_equals(self, expected_value: bytes | None) -> bool:
  1476. """Ensure the ref currently equals the expected value.
  1477. Args:
  1478. expected_value: The expected current value of the ref
  1479. Returns:
  1480. True if the ref equals the expected value, False otherwise
  1481. """
  1482. current_value = self.get()
  1483. return current_value == expected_value
  1484. def set(self, new_ref: bytes) -> None:
  1485. """Set the ref to a new value.
  1486. Args:
  1487. new_ref: The new SHA1 or symbolic ref value
  1488. """
  1489. if not self._file:
  1490. raise RuntimeError("locked_ref not in context")
  1491. if not (valid_hexsha(new_ref) or new_ref.startswith(SYMREF)):
  1492. raise ValueError(f"{new_ref!r} must be a valid sha (40 chars) or a symref")
  1493. self._file.seek(0)
  1494. self._file.truncate()
  1495. self._file.write(new_ref + b"\n")
  1496. self._deleted = False
  1497. def set_symbolic_ref(self, target: Ref) -> None:
  1498. """Make this ref point at another ref.
  1499. Args:
  1500. target: Name of the ref to point at
  1501. """
  1502. if not self._file:
  1503. raise RuntimeError("locked_ref not in context")
  1504. self._refs_container._check_refname(target)
  1505. self._file.seek(0)
  1506. self._file.truncate()
  1507. self._file.write(SYMREF + target + b"\n")
  1508. self._deleted = False
  1509. def delete(self) -> None:
  1510. """Delete the ref file while holding the lock."""
  1511. if not self._file:
  1512. raise RuntimeError("locked_ref not in context")
  1513. # Delete the actual ref file while holding the lock
  1514. if self._realname:
  1515. filename = self._refs_container.refpath(self._realname)
  1516. try:
  1517. if os.path.lexists(filename):
  1518. os.remove(filename)
  1519. except FileNotFoundError:
  1520. pass
  1521. self._refs_container._remove_packed_ref(self._realname)
  1522. self._deleted = True
  1523. class NamespacedRefsContainer(RefsContainer):
  1524. """Wrapper that adds namespace prefix to all ref operations.
  1525. This implements Git's GIT_NAMESPACE feature, which stores refs under
  1526. refs/namespaces/<namespace>/ and filters operations to only show refs
  1527. within that namespace.
  1528. Example:
  1529. With namespace "foo", a ref "refs/heads/master" is stored as
  1530. "refs/namespaces/foo/refs/heads/master" in the underlying container.
  1531. """
  1532. def __init__(self, refs: RefsContainer, namespace: bytes) -> None:
  1533. """Initialize NamespacedRefsContainer.
  1534. Args:
  1535. refs: The underlying refs container to wrap
  1536. namespace: The namespace prefix (e.g., b"foo" or b"foo/bar")
  1537. """
  1538. super().__init__(logger=refs._logger)
  1539. self._refs = refs
  1540. # Build namespace prefix: refs/namespaces/<namespace>/
  1541. # Support nested namespaces: foo/bar -> refs/namespaces/foo/refs/namespaces/bar/
  1542. namespace_parts = namespace.split(b"/")
  1543. self._namespace_prefix = b""
  1544. for part in namespace_parts:
  1545. self._namespace_prefix += b"refs/namespaces/" + part + b"/"
  1546. def _apply_namespace(self, name: bytes) -> bytes:
  1547. """Apply namespace prefix to a ref name."""
  1548. # HEAD and other special refs are not namespaced
  1549. if name == HEADREF or not name.startswith(b"refs/"):
  1550. return name
  1551. return self._namespace_prefix + name
  1552. def _strip_namespace(self, name: bytes) -> bytes | None:
  1553. """Remove namespace prefix from a ref name.
  1554. Returns None if the ref is not in our namespace.
  1555. """
  1556. # HEAD and other special refs are not namespaced
  1557. if name == HEADREF or not name.startswith(b"refs/"):
  1558. return name
  1559. if name.startswith(self._namespace_prefix):
  1560. return name[len(self._namespace_prefix) :]
  1561. return None
  1562. def allkeys(self) -> set[Ref]:
  1563. """Return all reference keys in this namespace."""
  1564. keys: set[Ref] = set()
  1565. for key in self._refs.allkeys():
  1566. stripped = self._strip_namespace(key)
  1567. if stripped is not None:
  1568. keys.add(Ref(stripped))
  1569. return keys
  1570. def read_loose_ref(self, name: Ref) -> bytes | None:
  1571. """Read a loose reference."""
  1572. return self._refs.read_loose_ref(Ref(self._apply_namespace(name)))
  1573. def get_packed_refs(self) -> dict[Ref, ObjectID]:
  1574. """Get packed refs within this namespace."""
  1575. packed: dict[Ref, ObjectID] = {}
  1576. for name, value in self._refs.get_packed_refs().items():
  1577. stripped = self._strip_namespace(name)
  1578. if stripped is not None:
  1579. packed[Ref(stripped)] = value
  1580. return packed
  1581. def add_packed_refs(self, new_refs: Mapping[Ref, ObjectID | None]) -> None:
  1582. """Add packed refs with namespace prefix."""
  1583. namespaced_refs: dict[Ref, ObjectID | None] = {
  1584. Ref(self._apply_namespace(name)): value for name, value in new_refs.items()
  1585. }
  1586. self._refs.add_packed_refs(namespaced_refs)
  1587. def get_peeled(self, name: Ref) -> ObjectID | None:
  1588. """Return the cached peeled value of a ref."""
  1589. return self._refs.get_peeled(Ref(self._apply_namespace(name)))
  1590. def set_symbolic_ref(
  1591. self,
  1592. name: Ref,
  1593. other: Ref,
  1594. committer: bytes | None = None,
  1595. timestamp: int | None = None,
  1596. timezone: int | None = None,
  1597. message: bytes | None = None,
  1598. ) -> None:
  1599. """Make a ref point at another ref."""
  1600. self._refs.set_symbolic_ref(
  1601. Ref(self._apply_namespace(name)),
  1602. Ref(self._apply_namespace(other)),
  1603. committer=committer,
  1604. timestamp=timestamp,
  1605. timezone=timezone,
  1606. message=message,
  1607. )
  1608. def set_if_equals(
  1609. self,
  1610. name: Ref,
  1611. old_ref: ObjectID | None,
  1612. new_ref: ObjectID,
  1613. committer: bytes | None = None,
  1614. timestamp: int | None = None,
  1615. timezone: int | None = None,
  1616. message: bytes | None = None,
  1617. ) -> bool:
  1618. """Set a refname to new_ref only if it currently equals old_ref."""
  1619. return self._refs.set_if_equals(
  1620. Ref(self._apply_namespace(name)),
  1621. old_ref,
  1622. new_ref,
  1623. committer=committer,
  1624. timestamp=timestamp,
  1625. timezone=timezone,
  1626. message=message,
  1627. )
  1628. def add_if_new(
  1629. self,
  1630. name: Ref,
  1631. ref: ObjectID,
  1632. committer: bytes | None = None,
  1633. timestamp: int | None = None,
  1634. timezone: int | None = None,
  1635. message: bytes | None = None,
  1636. ) -> bool:
  1637. """Add a new reference only if it does not already exist."""
  1638. return self._refs.add_if_new(
  1639. Ref(self._apply_namespace(name)),
  1640. ref,
  1641. committer=committer,
  1642. timestamp=timestamp,
  1643. timezone=timezone,
  1644. message=message,
  1645. )
  1646. def remove_if_equals(
  1647. self,
  1648. name: Ref,
  1649. old_ref: ObjectID | None,
  1650. committer: bytes | None = None,
  1651. timestamp: int | None = None,
  1652. timezone: int | None = None,
  1653. message: bytes | None = None,
  1654. ) -> bool:
  1655. """Remove a refname only if it currently equals old_ref."""
  1656. return self._refs.remove_if_equals(
  1657. Ref(self._apply_namespace(name)),
  1658. old_ref,
  1659. committer=committer,
  1660. timestamp=timestamp,
  1661. timezone=timezone,
  1662. message=message,
  1663. )
  1664. def pack_refs(self, all: bool = False) -> None:
  1665. """Pack loose refs into packed-refs file.
  1666. Note: This packs all refs in the underlying container, not just
  1667. those in the namespace.
  1668. """
  1669. self._refs.pack_refs(all=all)
  1670. def filter_ref_prefix(refs: T, prefixes: Iterable[bytes]) -> T:
  1671. """Filter refs to only include those with a given prefix.
  1672. Args:
  1673. refs: A dictionary of refs.
  1674. prefixes: The prefixes to filter by.
  1675. """
  1676. filtered = {k: v for k, v in refs.items() if any(k.startswith(p) for p in prefixes)}
  1677. return filtered
  1678. def is_per_worktree_ref(ref: bytes) -> bool:
  1679. """Returns whether a reference is stored per worktree or not.
  1680. Per-worktree references are:
  1681. - all pseudorefs, e.g. HEAD
  1682. - all references stored inside "refs/bisect/", "refs/worktree/" and "refs/rewritten/"
  1683. All refs starting with "refs/" are shared, except for the ones listed above.
  1684. See https://git-scm.com/docs/git-worktree#_refs.
  1685. """
  1686. return not ref.startswith(b"refs/") or ref.startswith(
  1687. (b"refs/bisect/", b"refs/worktree/", b"refs/rewritten/")
  1688. )