|
|
@@ -105,7 +105,6 @@ from typing import (
|
|
|
TYPE_CHECKING,
|
|
|
Any,
|
|
|
BinaryIO,
|
|
|
- Optional,
|
|
|
TextIO,
|
|
|
TypedDict,
|
|
|
TypeVar,
|
|
|
@@ -114,10 +113,11 @@ from typing import (
|
|
|
)
|
|
|
|
|
|
if sys.version_info >= (3, 12):
|
|
|
- from collections.abc import Buffer
|
|
|
from typing import override
|
|
|
else:
|
|
|
- from typing_extensions import Buffer, override
|
|
|
+ from typing_extensions import override
|
|
|
+
|
|
|
+from ._typing import Buffer
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
import urllib3
|
|
|
@@ -198,6 +198,7 @@ from .refs import (
|
|
|
LOCAL_REMOTE_PREFIX,
|
|
|
LOCAL_REPLACE_PREFIX,
|
|
|
LOCAL_TAG_PREFIX,
|
|
|
+ DictRefsContainer,
|
|
|
Ref,
|
|
|
SymrefLoop,
|
|
|
_import_remote_refs,
|
|
|
@@ -205,6 +206,7 @@ from .refs import (
|
|
|
local_branch_name,
|
|
|
local_replace_name,
|
|
|
local_tag_name,
|
|
|
+ parse_remote_ref,
|
|
|
shorten_ref_name,
|
|
|
)
|
|
|
from .repo import BaseRepo, Repo, get_user_identity
|
|
|
@@ -244,7 +246,7 @@ class TransportKwargs(TypedDict, total=False):
|
|
|
password: str | None
|
|
|
key_filename: str | None
|
|
|
ssh_command: str | None
|
|
|
- pool_manager: Optional["urllib3.PoolManager"]
|
|
|
+ pool_manager: "urllib3.PoolManager | None"
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
@@ -281,14 +283,25 @@ class NoneStream(RawIOBase):
|
|
|
"""
|
|
|
return b""
|
|
|
|
|
|
- @override
|
|
|
- def readinto(self, b: Buffer) -> int | None:
|
|
|
- return 0
|
|
|
+ if sys.version_info >= (3, 12):
|
|
|
+
|
|
|
+ @override
|
|
|
+ def readinto(self, b: Buffer) -> int | None:
|
|
|
+ return 0
|
|
|
+
|
|
|
+ @override
|
|
|
+ def write(self, b: Buffer) -> int | None:
|
|
|
+ return len(cast(bytes, b)) if b else 0
|
|
|
+
|
|
|
+ else:
|
|
|
+
|
|
|
+ @override
|
|
|
+ def readinto(self, b: bytearray | memoryview) -> int | None: # type: ignore[override]
|
|
|
+ return 0
|
|
|
|
|
|
- @override
|
|
|
- def write(self, b: Buffer) -> int | None:
|
|
|
- # All Buffer implementations (bytes, bytearray, memoryview) support len()
|
|
|
- return len(b) if b else 0 # type: ignore[arg-type]
|
|
|
+ @override
|
|
|
+ def write(self, b: bytes | bytearray | memoryview) -> int | None: # type: ignore[override]
|
|
|
+ return len(b) if b else 0
|
|
|
|
|
|
|
|
|
default_bytes_out_stream: BinaryIO = cast(
|
|
|
@@ -630,8 +643,6 @@ def _get_variables(repo: RepoPath = ".") -> dict[str, str]:
|
|
|
Returns:
|
|
|
A dictionary of all logical variables with values
|
|
|
"""
|
|
|
- from .repo import get_user_identity
|
|
|
-
|
|
|
with open_repo_closing(repo) as repo_obj:
|
|
|
config = repo_obj.get_config_stack()
|
|
|
|
|
|
@@ -827,8 +838,6 @@ def commit(
|
|
|
if normalizer is not None:
|
|
|
|
|
|
def filter_callback(data: bytes, path: bytes) -> bytes:
|
|
|
- from dulwich.objects import Blob
|
|
|
-
|
|
|
blob = Blob()
|
|
|
blob.data = data
|
|
|
normalized_blob = normalizer.checkin_normalize(blob, path)
|
|
|
@@ -1066,7 +1075,7 @@ def stripspace(
|
|
|
>>> stripspace(b"line\\n", comment_lines=True)
|
|
|
b'# line\\n'
|
|
|
"""
|
|
|
- from dulwich.stripspace import stripspace as _stripspace
|
|
|
+ from .stripspace import stripspace as _stripspace
|
|
|
|
|
|
# Convert text to bytes
|
|
|
if isinstance(text, str):
|
|
|
@@ -1290,8 +1299,6 @@ def add(
|
|
|
if normalizer is not None:
|
|
|
|
|
|
def filter_callback(data: bytes, path: bytes) -> bytes:
|
|
|
- from dulwich.objects import Blob
|
|
|
-
|
|
|
blob = Blob()
|
|
|
blob.data = data
|
|
|
normalized_blob = normalizer.checkin_normalize(blob, path)
|
|
|
@@ -2838,42 +2845,6 @@ def reset(
|
|
|
|
|
|
elif mode == "hard":
|
|
|
# Hard reset: update HEAD, index, and working tree
|
|
|
- # Get configuration for working directory update
|
|
|
- config = r.get_config()
|
|
|
- honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
|
|
|
-
|
|
|
- if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
|
|
|
- validate_path_element = validate_path_element_ntfs
|
|
|
- elif config.get_boolean(
|
|
|
- b"core", b"core.protectHFS", sys.platform == "darwin"
|
|
|
- ):
|
|
|
- validate_path_element = validate_path_element_hfs
|
|
|
- else:
|
|
|
- validate_path_element = validate_path_element_default
|
|
|
-
|
|
|
- if config.get_boolean(b"core", b"symlinks", True):
|
|
|
-
|
|
|
- def symlink_wrapper(
|
|
|
- source: str | bytes | os.PathLike[str],
|
|
|
- target: str | bytes | os.PathLike[str],
|
|
|
- ) -> None:
|
|
|
- symlink(source, target) # type: ignore[arg-type,unused-ignore]
|
|
|
-
|
|
|
- symlink_fn = symlink_wrapper
|
|
|
- else:
|
|
|
-
|
|
|
- def symlink_fallback(
|
|
|
- source: str | bytes | os.PathLike[str],
|
|
|
- target: str | bytes | os.PathLike[str],
|
|
|
- ) -> None:
|
|
|
- mode = "w" + ("b" if isinstance(source, bytes) else "")
|
|
|
- with open(target, mode) as f:
|
|
|
- f.write(source)
|
|
|
-
|
|
|
- symlink_fn = symlink_fallback
|
|
|
-
|
|
|
- # Update working tree and index
|
|
|
- blob_normalizer = r.get_blob_normalizer()
|
|
|
# For reset --hard, use current index tree as old tree to get proper deletions
|
|
|
index = r.open_index()
|
|
|
if len(index) > 0:
|
|
|
@@ -2882,6 +2853,12 @@ def reset(
|
|
|
# Empty index
|
|
|
index_tree_id = None
|
|
|
|
|
|
+ # Get configuration for working tree updates
|
|
|
+ honor_filemode, validate_path_element, symlink_fn = (
|
|
|
+ _get_worktree_update_config(r)
|
|
|
+ )
|
|
|
+
|
|
|
+ blob_normalizer = r.get_blob_normalizer()
|
|
|
changes = tree_changes(
|
|
|
r.object_store, index_tree_id, tree.id, want_unchanged=True
|
|
|
)
|
|
|
@@ -3004,8 +2981,6 @@ def push(
|
|
|
remote_changed_refs: dict[bytes, bytes | None] = {}
|
|
|
|
|
|
def update_refs(refs: dict[bytes, bytes]) -> dict[bytes, bytes]:
|
|
|
- from .refs import DictRefsContainer
|
|
|
-
|
|
|
remote_refs = DictRefsContainer(refs)
|
|
|
selected_refs.extend(
|
|
|
parse_reftuples(r.refs, remote_refs, refspecs_bytes, force=force)
|
|
|
@@ -3074,10 +3049,14 @@ def push(
|
|
|
for ref, error in (result.ref_status or {}).items():
|
|
|
if error is not None:
|
|
|
errstream.write(
|
|
|
- b"Push of ref %s failed: %s\n" % (ref, error.encode(err_encoding))
|
|
|
+ f"Push of ref {ref.decode('utf-8', 'replace')} failed: {error}\n".encode(
|
|
|
+ err_encoding
|
|
|
+ )
|
|
|
)
|
|
|
else:
|
|
|
- errstream.write(b"Ref %s updated\n" % ref)
|
|
|
+ errstream.write(
|
|
|
+ f"Ref {ref.decode('utf-8', 'replace')} updated\n".encode()
|
|
|
+ )
|
|
|
|
|
|
if remote_name is not None:
|
|
|
_import_remote_refs(r.refs, remote_name, remote_changed_refs)
|
|
|
@@ -3148,8 +3127,6 @@ def pull(
|
|
|
def determine_wants(
|
|
|
remote_refs: dict[bytes, bytes], depth: int | None = None
|
|
|
) -> list[bytes]:
|
|
|
- from .refs import DictRefsContainer
|
|
|
-
|
|
|
remote_refs_container = DictRefsContainer(remote_refs)
|
|
|
selected_refs.extend(
|
|
|
parse_reftuples(
|
|
|
@@ -3286,18 +3263,17 @@ def status(
|
|
|
untracked - list of untracked, un-ignored & non-.git paths
|
|
|
"""
|
|
|
with open_repo_closing(repo) as r:
|
|
|
+ # Open the index once and reuse it for both staged and unstaged checks
|
|
|
+ index = r.open_index()
|
|
|
# 1. Get status of staged
|
|
|
- tracked_changes = get_tree_changes(r)
|
|
|
+ tracked_changes = get_tree_changes(r, index)
|
|
|
# 2. Get status of unstaged
|
|
|
- index = r.open_index()
|
|
|
normalizer = r.get_blob_normalizer()
|
|
|
|
|
|
# Create a wrapper that handles the bytes -> Blob conversion
|
|
|
if normalizer is not None:
|
|
|
|
|
|
def filter_callback(data: bytes, path: bytes) -> bytes:
|
|
|
- from dulwich.objects import Blob
|
|
|
-
|
|
|
blob = Blob()
|
|
|
blob.data = data
|
|
|
normalized_blob = normalizer.checkin_normalize(blob, path)
|
|
|
@@ -3684,15 +3660,19 @@ def grep(
|
|
|
outstream.write(f"{path_str}:{line_str}\n")
|
|
|
|
|
|
|
|
|
-def get_tree_changes(repo: RepoPath) -> dict[str, list[str | bytes]]:
|
|
|
+def get_tree_changes(
|
|
|
+ repo: RepoPath, index: Index | None = None
|
|
|
+) -> dict[str, list[str | bytes]]:
|
|
|
"""Return add/delete/modify changes to tree by comparing index to HEAD.
|
|
|
|
|
|
Args:
|
|
|
repo: repo path or object
|
|
|
+ index: optional Index object to reuse (avoids re-opening the index)
|
|
|
Returns: dict with lists for each type of change
|
|
|
"""
|
|
|
with open_repo_closing(repo) as r:
|
|
|
- index = r.open_index()
|
|
|
+ if index is None:
|
|
|
+ index = r.open_index()
|
|
|
|
|
|
# Compares the Index to the HEAD & determines changes
|
|
|
# Iterate through the changes and report add/delete/modify
|
|
|
@@ -4498,8 +4478,6 @@ def show_ref(
|
|
|
try:
|
|
|
obj = r.get_object(sha)
|
|
|
# Peel tag objects to get the underlying commit/object
|
|
|
- from .objects import Tag
|
|
|
-
|
|
|
while obj.type_name == b"tag":
|
|
|
assert isinstance(obj, Tag)
|
|
|
_obj_class, sha = obj.object
|
|
|
@@ -5079,6 +5057,168 @@ def check_ignore(
|
|
|
yield _quote_path(output_path) if quote_path else output_path
|
|
|
|
|
|
|
|
|
+def _get_current_head_tree(repo: Repo) -> bytes | None:
|
|
|
+ """Get the current HEAD tree ID.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ repo: Repository object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tree ID of current HEAD, or None if no HEAD exists (empty repo)
|
|
|
+ """
|
|
|
+ try:
|
|
|
+ current_head = repo.refs[b"HEAD"]
|
|
|
+ current_commit = repo[current_head]
|
|
|
+ assert isinstance(current_commit, Commit), "Expected a Commit object"
|
|
|
+ tree_id: bytes = current_commit.tree
|
|
|
+ return tree_id
|
|
|
+ except KeyError:
|
|
|
+ # No HEAD yet (empty repo)
|
|
|
+ return None
|
|
|
+
|
|
|
+
|
|
|
+def _check_uncommitted_changes(
|
|
|
+ repo: Repo, target_tree_id: bytes, force: bool = False
|
|
|
+) -> None:
|
|
|
+ """Check for uncommitted changes that would conflict with a checkout/switch.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ repo: Repository object
|
|
|
+ target_tree_id: Tree ID to check conflicts against
|
|
|
+ force: If True, skip the check
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ CheckoutError: If there are conflicting local changes
|
|
|
+ """
|
|
|
+ if force:
|
|
|
+ return
|
|
|
+
|
|
|
+ # Get current HEAD tree for comparison
|
|
|
+ current_tree_id = _get_current_head_tree(repo)
|
|
|
+ if current_tree_id is None:
|
|
|
+ # No HEAD yet (empty repo)
|
|
|
+ return
|
|
|
+
|
|
|
+ status_report = status(repo)
|
|
|
+ changes = []
|
|
|
+ # staged is a dict with 'add', 'delete', 'modify' keys
|
|
|
+ if isinstance(status_report.staged, dict):
|
|
|
+ changes.extend(status_report.staged.get("add", []))
|
|
|
+ changes.extend(status_report.staged.get("delete", []))
|
|
|
+ changes.extend(status_report.staged.get("modify", []))
|
|
|
+ # unstaged is a list
|
|
|
+ changes.extend(status_report.unstaged)
|
|
|
+
|
|
|
+ if changes:
|
|
|
+ # Check if any changes would conflict with checkout
|
|
|
+ target_tree_obj = repo[target_tree_id]
|
|
|
+ assert isinstance(target_tree_obj, Tree), "Expected a Tree object"
|
|
|
+ target_tree = target_tree_obj
|
|
|
+ for change in changes:
|
|
|
+ if isinstance(change, str):
|
|
|
+ change = change.encode(DEFAULT_ENCODING)
|
|
|
+
|
|
|
+ try:
|
|
|
+ target_tree.lookup_path(repo.object_store.__getitem__, change)
|
|
|
+ except KeyError:
|
|
|
+ # File doesn't exist in target tree - change can be preserved
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ # File exists in target tree - would overwrite local changes
|
|
|
+ raise CheckoutError(
|
|
|
+ f"Your local changes to '{change.decode()}' would be "
|
|
|
+ "overwritten. Please commit or stash before switching."
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
+def _get_worktree_update_config(
|
|
|
+ repo: Repo,
|
|
|
+) -> tuple[
|
|
|
+ bool,
|
|
|
+ Callable[[bytes], bool],
|
|
|
+ Callable[[str | bytes | os.PathLike[str], str | bytes | os.PathLike[str]], None],
|
|
|
+]:
|
|
|
+ """Get configuration for working tree updates.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ repo: Repository object
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ Tuple of (honor_filemode, validate_path_element, symlink_fn)
|
|
|
+ """
|
|
|
+ config = repo.get_config()
|
|
|
+ honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
|
|
|
+
|
|
|
+ if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
|
|
|
+ validate_path_element = validate_path_element_ntfs
|
|
|
+ elif config.get_boolean(b"core", b"core.protectHFS", sys.platform == "darwin"):
|
|
|
+ validate_path_element = validate_path_element_hfs
|
|
|
+ else:
|
|
|
+ validate_path_element = validate_path_element_default
|
|
|
+
|
|
|
+ if config.get_boolean(b"core", b"symlinks", True):
|
|
|
+
|
|
|
+ def symlink_wrapper(
|
|
|
+ source: str | bytes | os.PathLike[str],
|
|
|
+ target: str | bytes | os.PathLike[str],
|
|
|
+ ) -> None:
|
|
|
+ symlink(source, target) # type: ignore[arg-type,unused-ignore]
|
|
|
+
|
|
|
+ symlink_fn = symlink_wrapper
|
|
|
+ else:
|
|
|
+
|
|
|
+ def symlink_fallback(
|
|
|
+ source: str | bytes | os.PathLike[str],
|
|
|
+ target: str | bytes | os.PathLike[str],
|
|
|
+ ) -> None:
|
|
|
+ mode = "w" + ("b" if isinstance(source, bytes) else "")
|
|
|
+ with open(target, mode) as f:
|
|
|
+ f.write(source)
|
|
|
+
|
|
|
+ symlink_fn = symlink_fallback
|
|
|
+
|
|
|
+ return honor_filemode, validate_path_element, symlink_fn
|
|
|
+
|
|
|
+
|
|
|
+def _perform_tree_switch(
|
|
|
+ repo: Repo,
|
|
|
+ current_tree_id: bytes | None,
|
|
|
+ target_tree_id: bytes,
|
|
|
+ force: bool = False,
|
|
|
+) -> None:
|
|
|
+ """Perform the actual working tree switch.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ repo: Repository object
|
|
|
+ current_tree_id: Current tree ID (or None for empty repo)
|
|
|
+ target_tree_id: Target tree ID to switch to
|
|
|
+ force: If True, force removal of untracked files and allow overwriting modified files
|
|
|
+ """
|
|
|
+ honor_filemode, validate_path_element, symlink_fn = _get_worktree_update_config(
|
|
|
+ repo
|
|
|
+ )
|
|
|
+
|
|
|
+ # Get blob normalizer for line ending conversion
|
|
|
+ blob_normalizer = repo.get_blob_normalizer()
|
|
|
+
|
|
|
+ # Update working tree
|
|
|
+ tree_change_iterator: Iterator[TreeChange] = tree_changes(
|
|
|
+ repo.object_store, current_tree_id, target_tree_id
|
|
|
+ )
|
|
|
+ update_working_tree(
|
|
|
+ repo,
|
|
|
+ current_tree_id,
|
|
|
+ target_tree_id,
|
|
|
+ change_iterator=tree_change_iterator,
|
|
|
+ honor_filemode=honor_filemode,
|
|
|
+ validate_path_element=validate_path_element,
|
|
|
+ symlink_fn=symlink_fn,
|
|
|
+ force_remove_untracked=force,
|
|
|
+ blob_normalizer=blob_normalizer,
|
|
|
+ allow_overwrite_modified=force,
|
|
|
+ )
|
|
|
+
|
|
|
+
|
|
|
def update_head(
|
|
|
repo: RepoPath,
|
|
|
target: str | bytes,
|
|
|
@@ -5233,102 +5373,257 @@ def checkout(
|
|
|
target_tree_id = target_commit.tree
|
|
|
|
|
|
# Get current HEAD tree for comparison
|
|
|
- try:
|
|
|
- current_head = r.refs[b"HEAD"]
|
|
|
- current_commit = r[current_head]
|
|
|
- assert isinstance(current_commit, Commit), "Expected a Commit object"
|
|
|
- current_tree_id = current_commit.tree
|
|
|
- except KeyError:
|
|
|
- # No HEAD yet (empty repo)
|
|
|
- current_tree_id = None
|
|
|
+ current_tree_id = _get_current_head_tree(r)
|
|
|
|
|
|
# Check for uncommitted changes if not forcing
|
|
|
- if not force and current_tree_id is not None:
|
|
|
- status_report = status(r)
|
|
|
- changes = []
|
|
|
- # staged is a dict with 'add', 'delete', 'modify' keys
|
|
|
- if isinstance(status_report.staged, dict):
|
|
|
- changes.extend(status_report.staged.get("add", []))
|
|
|
- changes.extend(status_report.staged.get("delete", []))
|
|
|
- changes.extend(status_report.staged.get("modify", []))
|
|
|
- # unstaged is a list
|
|
|
- changes.extend(status_report.unstaged)
|
|
|
- if changes:
|
|
|
- # Check if any changes would conflict with checkout
|
|
|
- target_tree_obj = r[target_tree_id]
|
|
|
- assert isinstance(target_tree_obj, Tree), "Expected a Tree object"
|
|
|
- target_tree = target_tree_obj
|
|
|
- for change in changes:
|
|
|
- if isinstance(change, str):
|
|
|
- change = change.encode(DEFAULT_ENCODING)
|
|
|
+ if current_tree_id is not None:
|
|
|
+ _check_uncommitted_changes(r, target_tree_id, force)
|
|
|
+
|
|
|
+ # Update working tree
|
|
|
+ _perform_tree_switch(r, current_tree_id, target_tree_id, force)
|
|
|
+
|
|
|
+ # Update HEAD
|
|
|
+ if new_branch:
|
|
|
+ # Create new branch and switch to it
|
|
|
+ branch_create(r, new_branch, objectish=target_commit.id.decode("ascii"))
|
|
|
+ update_head(r, new_branch)
|
|
|
+
|
|
|
+ # Set up tracking if creating from a remote branch
|
|
|
+ if isinstance(original_target, bytes) and target_bytes.startswith(
|
|
|
+ LOCAL_REMOTE_PREFIX
|
|
|
+ ):
|
|
|
+ try:
|
|
|
+ remote_name, branch_name = parse_remote_ref(target_bytes)
|
|
|
+ # Set tracking to refs/heads/<branch> on the remote
|
|
|
+ set_branch_tracking(
|
|
|
+ r, new_branch, remote_name, local_branch_name(branch_name)
|
|
|
+ )
|
|
|
+ except ValueError:
|
|
|
+ # Invalid remote ref format, skip tracking setup
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ # Check if target is a branch name (with or without refs/heads/ prefix)
|
|
|
+ branch_ref = None
|
|
|
+ if (
|
|
|
+ isinstance(original_target, (str, bytes))
|
|
|
+ and target_bytes in r.refs.keys()
|
|
|
+ ):
|
|
|
+ if target_bytes.startswith(LOCAL_BRANCH_PREFIX):
|
|
|
+ branch_ref = target_bytes
|
|
|
+ else:
|
|
|
+ # Try adding refs/heads/ prefix
|
|
|
+ potential_branch = (
|
|
|
+ _make_branch_ref(target_bytes)
|
|
|
+ if isinstance(original_target, (str, bytes))
|
|
|
+ else None
|
|
|
+ )
|
|
|
+ if potential_branch in r.refs.keys():
|
|
|
+ branch_ref = potential_branch
|
|
|
+
|
|
|
+ if branch_ref:
|
|
|
+ # It's a branch - update HEAD symbolically
|
|
|
+ update_head(r, branch_ref)
|
|
|
+ else:
|
|
|
+ # It's a tag, other ref, or commit SHA - detached HEAD
|
|
|
+ update_head(r, target_commit.id.decode("ascii"), detached=True)
|
|
|
+
|
|
|
+
|
|
|
+def restore(
|
|
|
+ repo: str | os.PathLike[str] | Repo,
|
|
|
+ paths: list[bytes | str],
|
|
|
+ source: str | bytes | Commit | Tag | None = None,
|
|
|
+ staged: bool = False,
|
|
|
+ worktree: bool = True,
|
|
|
+) -> None:
|
|
|
+ """Restore working tree files.
|
|
|
+
|
|
|
+ This is similar to 'git restore', allowing you to restore specific files
|
|
|
+ from a commit or the index without changing HEAD.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ repo: Path to repository or repository object
|
|
|
+ paths: List of specific paths to restore
|
|
|
+ source: Branch name, tag, or commit SHA to restore from. If None, restores
|
|
|
+ staged files from HEAD, or worktree files from index
|
|
|
+ staged: Restore files in the index (--staged)
|
|
|
+ worktree: Restore files in the working tree (default: True)
|
|
|
+
|
|
|
+ Raises:
|
|
|
+ CheckoutError: If restore cannot be performed
|
|
|
+ ValueError: If neither staged nor worktree is specified
|
|
|
+ KeyError: If the source reference cannot be found
|
|
|
+ """
|
|
|
+ if not staged and not worktree:
|
|
|
+ raise ValueError("At least one of staged or worktree must be True")
|
|
|
+
|
|
|
+ with open_repo_closing(repo) as r:
|
|
|
+ from .index import _fs_to_tree_path, build_file_from_blob
|
|
|
+
|
|
|
+ # Determine the source tree
|
|
|
+ if source is None:
|
|
|
+ if staged:
|
|
|
+ # Restoring staged files from HEAD
|
|
|
+ try:
|
|
|
+ source = r.refs[b"HEAD"]
|
|
|
+ except KeyError:
|
|
|
+ raise CheckoutError("No HEAD reference found")
|
|
|
+ elif worktree:
|
|
|
+ # Restoring worktree files from index
|
|
|
+ from .index import ConflictedIndexEntry, IndexEntry
|
|
|
+
|
|
|
+ index = r.open_index()
|
|
|
+ for path in paths:
|
|
|
+ if isinstance(path, str):
|
|
|
+ tree_path = _fs_to_tree_path(path)
|
|
|
+ else:
|
|
|
+ tree_path = path
|
|
|
|
|
|
try:
|
|
|
- target_tree.lookup_path(r.object_store.__getitem__, change)
|
|
|
+ index_entry = index[tree_path]
|
|
|
+ if isinstance(index_entry, ConflictedIndexEntry):
|
|
|
+ raise CheckoutError(
|
|
|
+ f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' has conflicts"
|
|
|
+ )
|
|
|
+ blob = r[index_entry.sha]
|
|
|
+ assert isinstance(blob, Blob), "Expected a Blob object"
|
|
|
+
|
|
|
+ full_path = os.path.join(os.fsencode(r.path), tree_path)
|
|
|
+ mode = index_entry.mode
|
|
|
+
|
|
|
+ # Use build_file_from_blob to write the file
|
|
|
+ build_file_from_blob(blob, mode, full_path)
|
|
|
except KeyError:
|
|
|
- # File doesn't exist in target tree - change can be preserved
|
|
|
- pass
|
|
|
- else:
|
|
|
- # File exists in target tree - would overwrite local changes
|
|
|
+ # Path doesn't exist in index
|
|
|
raise CheckoutError(
|
|
|
- f"Your local changes to '{change.decode()}' would be "
|
|
|
- "overwritten by checkout. Please commit or stash before switching."
|
|
|
+ f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' not in index"
|
|
|
)
|
|
|
+ return
|
|
|
|
|
|
- # Get configuration for working directory update
|
|
|
- config = r.get_config()
|
|
|
- honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
|
|
|
+ # source is not None at this point
|
|
|
+ assert source is not None
|
|
|
+ # Get the source tree
|
|
|
+ source_tree = parse_tree(r, treeish=source)
|
|
|
+
|
|
|
+ # Restore specified paths from source tree
|
|
|
+ for path in paths:
|
|
|
+ if isinstance(path, str):
|
|
|
+ tree_path = _fs_to_tree_path(path)
|
|
|
+ else:
|
|
|
+ tree_path = path
|
|
|
+
|
|
|
+ try:
|
|
|
+ # Look up the path in the source tree
|
|
|
+ mode, sha = source_tree.lookup_path(
|
|
|
+ r.object_store.__getitem__, tree_path
|
|
|
+ )
|
|
|
+ blob = r[sha]
|
|
|
+ assert isinstance(blob, Blob), "Expected a Blob object"
|
|
|
+ except KeyError:
|
|
|
+ # Path doesn't exist in source tree
|
|
|
+ raise CheckoutError(
|
|
|
+ f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' not found in source"
|
|
|
+ )
|
|
|
+
|
|
|
+ full_path = os.path.join(os.fsencode(r.path), tree_path)
|
|
|
+
|
|
|
+ if worktree:
|
|
|
+ # Use build_file_from_blob to restore to working tree
|
|
|
+ build_file_from_blob(blob, mode, full_path)
|
|
|
+
|
|
|
+ if staged:
|
|
|
+ # Update the index with the blob from source
|
|
|
+ from .index import IndexEntry
|
|
|
+
|
|
|
+ index = r.open_index()
|
|
|
+
|
|
|
+ # When only updating staged (not worktree), we want to reset the index
|
|
|
+ # to the source, but invalidate the stat cache so Git knows to check
|
|
|
+ # the worktree file. Use zeros for stat fields.
|
|
|
+ if not worktree:
|
|
|
+ # Invalidate stat cache by using zeros
|
|
|
+ new_entry = IndexEntry(
|
|
|
+ ctime=(0, 0),
|
|
|
+ mtime=(0, 0),
|
|
|
+ dev=0,
|
|
|
+ ino=0,
|
|
|
+ mode=mode,
|
|
|
+ uid=0,
|
|
|
+ gid=0,
|
|
|
+ size=0,
|
|
|
+ sha=sha,
|
|
|
+ )
|
|
|
+ else:
|
|
|
+ # If we also updated worktree, use actual stat
|
|
|
+ from .index import index_entry_from_stat
|
|
|
+
|
|
|
+ st = os.lstat(full_path)
|
|
|
+ new_entry = index_entry_from_stat(st, sha, mode)
|
|
|
+
|
|
|
+ index[tree_path] = new_entry
|
|
|
+ index.write()
|
|
|
|
|
|
- if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
|
|
|
- validate_path_element = validate_path_element_ntfs
|
|
|
- else:
|
|
|
- validate_path_element = validate_path_element_default
|
|
|
|
|
|
- if config.get_boolean(b"core", b"symlinks", True):
|
|
|
+def switch(
|
|
|
+ repo: str | os.PathLike[str] | Repo,
|
|
|
+ target: str | bytes | Commit | Tag,
|
|
|
+ create: str | bytes | None = None,
|
|
|
+ force: bool = False,
|
|
|
+ detach: bool = False,
|
|
|
+) -> None:
|
|
|
+ """Switch branches.
|
|
|
+
|
|
|
+ This is similar to 'git switch', allowing you to switch to a different
|
|
|
+ branch or commit, updating both HEAD and the working tree.
|
|
|
+
|
|
|
+ Args:
|
|
|
+ repo: Path to repository or repository object
|
|
|
+ target: Branch name, tag, or commit SHA to switch to
|
|
|
+ create: Create a new branch at target before switching (like git switch -c)
|
|
|
+ force: Force switch even if there are local changes
|
|
|
+ detach: Switch to a commit in detached HEAD state (like git switch --detach)
|
|
|
|
|
|
- def symlink_wrapper(
|
|
|
- source: str | bytes | os.PathLike[str],
|
|
|
- target: str | bytes | os.PathLike[str],
|
|
|
- ) -> None:
|
|
|
- symlink(source, target) # type: ignore[arg-type,unused-ignore]
|
|
|
+ Raises:
|
|
|
+ CheckoutError: If switch cannot be performed due to conflicts
|
|
|
+ KeyError: If the target reference cannot be found
|
|
|
+ ValueError: If both create and detach are specified
|
|
|
+ """
|
|
|
+ if create and detach:
|
|
|
+ raise ValueError("Cannot use both create and detach options")
|
|
|
+
|
|
|
+ with open_repo_closing(repo) as r:
|
|
|
+ # Store the original target for later reference checks
|
|
|
+ original_target = target
|
|
|
|
|
|
- symlink_fn = symlink_wrapper
|
|
|
+ if isinstance(target, str):
|
|
|
+ target_bytes = target.encode(DEFAULT_ENCODING)
|
|
|
+ elif isinstance(target, bytes):
|
|
|
+ target_bytes = target
|
|
|
else:
|
|
|
+ # For Commit/Tag objects, we'll use their SHA
|
|
|
+ target_bytes = target.id
|
|
|
|
|
|
- def symlink_fallback(
|
|
|
- source: str | bytes | os.PathLike[str],
|
|
|
- target: str | bytes | os.PathLike[str],
|
|
|
- ) -> None:
|
|
|
- mode = "w" + ("b" if isinstance(source, bytes) else "")
|
|
|
- with open(target, mode) as f:
|
|
|
- f.write(source)
|
|
|
+ if isinstance(create, str):
|
|
|
+ create = create.encode(DEFAULT_ENCODING)
|
|
|
|
|
|
- symlink_fn = symlink_fallback
|
|
|
+ # Parse the target to get the commit
|
|
|
+ target_commit = parse_commit(r, original_target)
|
|
|
+ target_tree_id = target_commit.tree
|
|
|
|
|
|
- # Get blob normalizer for line ending conversion
|
|
|
- blob_normalizer = r.get_blob_normalizer()
|
|
|
+ # Get current HEAD tree for comparison
|
|
|
+ current_tree_id = _get_current_head_tree(r)
|
|
|
+
|
|
|
+ # Check for uncommitted changes if not forcing
|
|
|
+ if current_tree_id is not None:
|
|
|
+ _check_uncommitted_changes(r, target_tree_id, force)
|
|
|
|
|
|
# Update working tree
|
|
|
- tree_change_iterator: Iterator[TreeChange] = tree_changes(
|
|
|
- r.object_store, current_tree_id, target_tree_id
|
|
|
- )
|
|
|
- update_working_tree(
|
|
|
- r,
|
|
|
- current_tree_id,
|
|
|
- target_tree_id,
|
|
|
- change_iterator=tree_change_iterator,
|
|
|
- honor_filemode=honor_filemode,
|
|
|
- validate_path_element=validate_path_element,
|
|
|
- symlink_fn=symlink_fn,
|
|
|
- force_remove_untracked=force,
|
|
|
- blob_normalizer=blob_normalizer,
|
|
|
- allow_overwrite_modified=force,
|
|
|
- )
|
|
|
+ _perform_tree_switch(r, current_tree_id, target_tree_id, force)
|
|
|
|
|
|
# Update HEAD
|
|
|
- if new_branch:
|
|
|
+ if create:
|
|
|
# Create new branch and switch to it
|
|
|
- branch_create(r, new_branch, objectish=target_commit.id.decode("ascii"))
|
|
|
- update_head(r, new_branch)
|
|
|
+ branch_create(r, create, objectish=target_commit.id.decode("ascii"))
|
|
|
+ update_head(r, create)
|
|
|
|
|
|
# Set up tracking if creating from a remote branch
|
|
|
from .refs import LOCAL_REMOTE_PREFIX, local_branch_name, parse_remote_ref
|
|
|
@@ -5340,11 +5635,14 @@ def checkout(
|
|
|
remote_name, branch_name = parse_remote_ref(target_bytes)
|
|
|
# Set tracking to refs/heads/<branch> on the remote
|
|
|
set_branch_tracking(
|
|
|
- r, new_branch, remote_name, local_branch_name(branch_name)
|
|
|
+ r, create, remote_name, local_branch_name(branch_name)
|
|
|
)
|
|
|
except ValueError:
|
|
|
# Invalid remote ref format, skip tracking setup
|
|
|
pass
|
|
|
+ elif detach:
|
|
|
+ # Detached HEAD mode
|
|
|
+ update_head(r, target_commit.id.decode("ascii"), detached=True)
|
|
|
else:
|
|
|
# Check if target is a branch name (with or without refs/heads/ prefix)
|
|
|
branch_ref = None
|
|
|
@@ -5368,8 +5666,12 @@ def checkout(
|
|
|
# It's a branch - update HEAD symbolically
|
|
|
update_head(r, branch_ref)
|
|
|
else:
|
|
|
- # It's a tag, other ref, or commit SHA - detached HEAD
|
|
|
- update_head(r, target_commit.id.decode("ascii"), detached=True)
|
|
|
+ # It's a tag, other ref, or commit SHA
|
|
|
+ # In git switch, this would be an error unless --detach is used
|
|
|
+ raise CheckoutError(
|
|
|
+ f"'{target_bytes.decode(DEFAULT_ENCODING)}' is not a branch. "
|
|
|
+ "Use detach=True to switch to a commit in detached HEAD state."
|
|
|
+ )
|
|
|
|
|
|
|
|
|
def reset_file(
|
|
|
@@ -6911,6 +7213,7 @@ def rebase(
|
|
|
Raises:
|
|
|
Error: If rebase fails or conflicts occur
|
|
|
"""
|
|
|
+ # TODO: Avoid importing from .cli
|
|
|
from .cli import launch_editor
|
|
|
from .rebase import (
|
|
|
RebaseConflict,
|
|
|
@@ -7038,7 +7341,7 @@ def annotate(
|
|
|
"""
|
|
|
if committish is None:
|
|
|
committish = "HEAD"
|
|
|
- from dulwich.annotate import annotate_lines
|
|
|
+ from .annotate import annotate_lines
|
|
|
|
|
|
with open_repo_closing(repo) as r:
|
|
|
commit_id = parse_commit(r, committish).id
|
|
|
@@ -7055,7 +7358,7 @@ def filter_branch(
|
|
|
repo: RepoPath = ".",
|
|
|
branch: str | bytes = "HEAD",
|
|
|
*,
|
|
|
- filter_fn: Callable[[Commit], Optional["CommitData"]] | None = None,
|
|
|
+ filter_fn: Callable[[Commit], "CommitData | None"] | None = None,
|
|
|
filter_author: Callable[[bytes], bytes | None] | None = None,
|
|
|
filter_committer: Callable[[bytes], bytes | None] | None = None,
|
|
|
filter_message: Callable[[bytes], bytes | None] | None = None,
|
|
|
@@ -8577,7 +8880,6 @@ def merge_base(
|
|
|
List of commit IDs that are merge bases
|
|
|
"""
|
|
|
from .graph import find_merge_base, find_octopus_base
|
|
|
- from .objects import Commit
|
|
|
from .objectspec import parse_object
|
|
|
|
|
|
if committishes is None or len(committishes) < 2:
|
|
|
@@ -8620,7 +8922,6 @@ def is_ancestor(
|
|
|
True if ancestor is an ancestor of descendant, False otherwise
|
|
|
"""
|
|
|
from .graph import find_merge_base
|
|
|
- from .objects import Commit
|
|
|
from .objectspec import parse_object
|
|
|
|
|
|
if ancestor is None or descendant is None:
|
|
|
@@ -8656,7 +8957,6 @@ def independent_commits(
|
|
|
List of commit IDs that are not ancestors of any other commits in the list
|
|
|
"""
|
|
|
from .graph import independent
|
|
|
- from .objects import Commit
|
|
|
from .objectspec import parse_object
|
|
|
|
|
|
if committishes is None or len(committishes) == 0:
|
|
|
@@ -8726,8 +9026,6 @@ def mailsplit(
|
|
|
keep_cr=keep_cr,
|
|
|
)
|
|
|
else:
|
|
|
- from typing import BinaryIO, cast
|
|
|
-
|
|
|
if input_path is None:
|
|
|
# Read from stdin
|
|
|
input_file: str | bytes | BinaryIO = sys.stdin.buffer
|
|
|
@@ -8788,8 +9086,6 @@ def mailinfo(
|
|
|
>>> print(f"Author: {result.author_name} <{result.author_email}>")
|
|
|
>>> print(f"Subject: {result.subject}")
|
|
|
"""
|
|
|
- from typing import BinaryIO, TextIO, cast
|
|
|
-
|
|
|
from .mbox import mailinfo as mbox_mailinfo
|
|
|
|
|
|
if input_path is None:
|
|
|
@@ -8848,15 +9144,13 @@ def rerere(repo: RepoPath = ".") -> tuple[list[tuple[bytes, str]], list[bytes]]:
|
|
|
- List of tuples (path, conflict_id) for recorded conflicts
|
|
|
- List of paths where resolutions were automatically applied
|
|
|
"""
|
|
|
- from dulwich.rerere import _has_conflict_markers, rerere_auto
|
|
|
+ from .rerere import _has_conflict_markers, rerere_auto
|
|
|
|
|
|
with open_repo_closing(repo) as r:
|
|
|
# Get conflicts from the index (if available)
|
|
|
index = r.open_index()
|
|
|
conflicts = []
|
|
|
|
|
|
- from dulwich.index import ConflictedIndexEntry
|
|
|
-
|
|
|
for path, entry in index.items():
|
|
|
if isinstance(entry, ConflictedIndexEntry):
|
|
|
conflicts.append(path)
|
|
|
@@ -8889,7 +9183,7 @@ def rerere_status(repo: RepoPath = ".") -> list[tuple[str, bool]]:
|
|
|
Returns:
|
|
|
List of tuples (conflict_id, has_resolution)
|
|
|
"""
|
|
|
- from dulwich.rerere import RerereCache
|
|
|
+ from .rerere import RerereCache
|
|
|
|
|
|
with open_repo_closing(repo) as r:
|
|
|
cache = RerereCache.from_repo(r)
|
|
|
@@ -8908,7 +9202,7 @@ def rerere_diff(
|
|
|
Returns:
|
|
|
List of tuples (conflict_id, preimage, postimage)
|
|
|
"""
|
|
|
- from dulwich.rerere import RerereCache
|
|
|
+ from .rerere import RerereCache
|
|
|
|
|
|
with open_repo_closing(repo) as r:
|
|
|
cache = RerereCache.from_repo(r)
|
|
|
@@ -8935,7 +9229,7 @@ def rerere_forget(repo: RepoPath = ".", pathspec: str | bytes | None = None) ->
|
|
|
repo: Path to the repository
|
|
|
pathspec: Path to forget (currently not implemented, forgets all)
|
|
|
"""
|
|
|
- from dulwich.rerere import RerereCache
|
|
|
+ from .rerere import RerereCache
|
|
|
|
|
|
with open_repo_closing(repo) as r:
|
|
|
cache = RerereCache.from_repo(r)
|
|
|
@@ -8955,7 +9249,7 @@ def rerere_clear(repo: RepoPath = ".") -> None:
|
|
|
Args:
|
|
|
repo: Path to the repository
|
|
|
"""
|
|
|
- from dulwich.rerere import RerereCache
|
|
|
+ from .rerere import RerereCache
|
|
|
|
|
|
with open_repo_closing(repo) as r:
|
|
|
cache = RerereCache.from_repo(r)
|
|
|
@@ -8969,7 +9263,7 @@ def rerere_gc(repo: RepoPath = ".", max_age_days: int = 60) -> None:
|
|
|
repo: Path to the repository
|
|
|
max_age_days: Maximum age in days for keeping resolutions
|
|
|
"""
|
|
|
- from dulwich.rerere import RerereCache
|
|
|
+ from .rerere import RerereCache
|
|
|
|
|
|
with open_repo_closing(repo) as r:
|
|
|
cache = RerereCache.from_repo(r)
|