stash.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. # stash.py
  2. # Copyright (C) 2018 Jelmer Vernooij <jelmer@samba.org>
  3. #
  4. # SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
  5. # Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
  6. # General Public License as published by the Free Software Foundation; version 2.0
  7. # or (at your option) any later version. You can redistribute it and/or
  8. # modify it under the terms of either of these two licenses.
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. #
  16. # You should have received a copy of the licenses; if not, see
  17. # <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
  18. # and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
  19. # License, Version 2.0.
  20. #
  21. """Stash handling."""
  22. import os
  23. from .file import GitFile
  24. from .index import commit_tree, iter_fresh_objects
  25. from .reflog import drop_reflog_entry, read_reflog
  26. DEFAULT_STASH_REF = b"refs/stash"
  27. class Stash:
  28. """A Git stash.
  29. Note that this doesn't currently update the working tree.
  30. """
  31. def __init__(self, repo, ref=DEFAULT_STASH_REF) -> None:
  32. self._ref = ref
  33. self._repo = repo
  34. @property
  35. def _reflog_path(self):
  36. return os.path.join(self._repo.commondir(), "logs", os.fsdecode(self._ref))
  37. def stashes(self):
  38. try:
  39. with GitFile(self._reflog_path, "rb") as f:
  40. return reversed(list(read_reflog(f)))
  41. except FileNotFoundError:
  42. return []
  43. @classmethod
  44. def from_repo(cls, repo):
  45. """Create a new stash from a Repo object."""
  46. return cls(repo)
  47. def drop(self, index) -> None:
  48. """Drop entry with specified index."""
  49. with open(self._reflog_path, "rb+") as f:
  50. drop_reflog_entry(f, index, rewrite=True)
  51. if len(self) == 0:
  52. os.remove(self._reflog_path)
  53. del self._repo.refs[self._ref]
  54. return
  55. if index == 0:
  56. self._repo.refs[self._ref] = self[0].new_sha
  57. def pop(self, index):
  58. raise NotImplementedError(self.pop)
  59. def push(self, committer=None, author=None, message=None):
  60. """Create a new stash.
  61. Args:
  62. committer: Optional committer name to use
  63. author: Optional author name to use
  64. message: Optional commit message
  65. """
  66. # First, create the index commit.
  67. commit_kwargs = {}
  68. if committer is not None:
  69. commit_kwargs["committer"] = committer
  70. if author is not None:
  71. commit_kwargs["author"] = author
  72. index = self._repo.open_index()
  73. index_tree_id = index.commit(self._repo.object_store)
  74. index_commit_id = self._repo.do_commit(
  75. ref=None,
  76. tree=index_tree_id,
  77. message=b"Index stash",
  78. merge_heads=[self._repo.head()],
  79. no_verify=True,
  80. **commit_kwargs,
  81. )
  82. # Then, the working tree one.
  83. stash_tree_id = commit_tree(
  84. self._repo.object_store,
  85. iter_fresh_objects(
  86. index,
  87. os.fsencode(self._repo.path),
  88. object_store=self._repo.object_store,
  89. ),
  90. )
  91. if message is None:
  92. message = b"A stash on " + self._repo.head()
  93. # TODO(jelmer): Just pass parents into do_commit()?
  94. self._repo.refs[self._ref] = self._repo.head()
  95. cid = self._repo.do_commit(
  96. ref=self._ref,
  97. tree=stash_tree_id,
  98. message=message,
  99. merge_heads=[index_commit_id],
  100. no_verify=True,
  101. **commit_kwargs,
  102. )
  103. return cid
  104. def __getitem__(self, index):
  105. return list(self.stashes())[index]
  106. def __len__(self) -> int:
  107. return len(list(self.stashes()))