ソースを参照

Implement restore and switch commands (#2003)

Fixes #1777
Jelmer Vernooij 2 ヶ月 前
コミット
6527bee5e6
4 ファイル変更677 行追加114 行削除
  1. 3 0
      NEWS
  2. 113 0
      dulwich/cli.py
  3. 410 114
      dulwich/porcelain.py
  4. 151 0
      tests/test_porcelain.py

+ 3 - 0
NEWS

@@ -32,6 +32,9 @@
    Dulwich version, and installed dependencies with their versions.
    (Jelmer Vernooij, #1835)
 
+ * Add basic ``dulwich restore`` and ``dulwich switch``
+   commands. (Jelmer Vernooij, #1777)
+
 0.24.10	2025-11-10
 
  * Fix compatibility with python 3.9. (Jelmer Vernooij, #1991)

+ 113 - 0
dulwich/cli.py

@@ -3841,6 +3841,117 @@ class cmd_checkout(Command):
         return 0
 
 
+class cmd_restore(Command):
+    """Restore working tree files."""
+
+    def run(self, args: Sequence[str]) -> int | None:
+        """Execute the restore command.
+
+        Args:
+            args: Command line arguments
+        """
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "paths",
+            nargs="+",
+            type=str,
+            help="Paths to restore",
+        )
+        parser.add_argument(
+            "-s",
+            "--source",
+            type=str,
+            help="Restore from a specific commit (default: HEAD for --staged, index for worktree)",
+        )
+        parser.add_argument(
+            "--staged",
+            action="store_true",
+            help="Restore files in the index",
+        )
+        parser.add_argument(
+            "--worktree",
+            action="store_true",
+            help="Restore files in the working tree",
+        )
+        parsed_args = parser.parse_args(args)
+
+        # If neither --staged nor --worktree is specified, default to --worktree
+        if not parsed_args.staged and not parsed_args.worktree:
+            worktree = True
+            staged = False
+        else:
+            worktree = parsed_args.worktree
+            staged = parsed_args.staged
+
+        try:
+            porcelain.restore(
+                ".",
+                paths=parsed_args.paths,
+                source=parsed_args.source,
+                staged=staged,
+                worktree=worktree,
+            )
+        except porcelain.CheckoutError as e:
+            sys.stderr.write(f"{e}\n")
+            return 1
+        return 0
+
+
+class cmd_switch(Command):
+    """Switch branches."""
+
+    def run(self, args: Sequence[str]) -> int | None:
+        """Execute the switch command.
+
+        Args:
+            args: Command line arguments
+        """
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "target",
+            type=str,
+            help="Branch or commit to switch to",
+        )
+        parser.add_argument(
+            "-c",
+            "--create",
+            type=str,
+            help="Create a new branch at the target and switch to it",
+        )
+        parser.add_argument(
+            "-f",
+            "--force",
+            action="store_true",
+            help="Force switch even if there are local changes",
+        )
+        parser.add_argument(
+            "-d",
+            "--detach",
+            action="store_true",
+            help="Switch to a commit in detached HEAD state",
+        )
+        parsed_args = parser.parse_args(args)
+
+        if not parsed_args.target:
+            logger.error(
+                "Usage: dulwich switch TARGET [-c NEW_BRANCH] [--force] [--detach]"
+            )
+            return 1
+
+        try:
+            porcelain.switch(
+                ".",
+                target=parsed_args.target,
+                create=parsed_args.create,
+                force=parsed_args.force,
+                detach=parsed_args.detach,
+            )
+        except porcelain.CheckoutError as e:
+            sys.stderr.write(f"{e}\n")
+            return 1
+        return 0
+
+
 class cmd_stash_list(Command):
     """List stash entries."""
 
@@ -6550,6 +6661,7 @@ commands = {
     "repack": cmd_repack,
     "replace": cmd_replace,
     "reset": cmd_reset,
+    "restore": cmd_restore,
     "revert": cmd_revert,
     "rev-list": cmd_rev_list,
     "rm": cmd_rm,
@@ -6561,6 +6673,7 @@ commands = {
     "status": cmd_status,
     "stripspace": cmd_stripspace,
     "shortlog": cmd_shortlog,
+    "switch": cmd_switch,
     "symbolic-ref": cmd_symbolic_ref,
     "submodule": cmd_submodule,
     "tag": cmd_tag,

+ 410 - 114
dulwich/porcelain.py

@@ -2845,42 +2845,6 @@ def reset(
 
         elif mode == "hard":
             # Hard reset: update HEAD, index, and working tree
-            # Get configuration for working directory update
-            config = r.get_config()
-            honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
-
-            if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
-                validate_path_element = validate_path_element_ntfs
-            elif config.get_boolean(
-                b"core", b"core.protectHFS", sys.platform == "darwin"
-            ):
-                validate_path_element = validate_path_element_hfs
-            else:
-                validate_path_element = validate_path_element_default
-
-            if config.get_boolean(b"core", b"symlinks", True):
-
-                def symlink_wrapper(
-                    source: str | bytes | os.PathLike[str],
-                    target: str | bytes | os.PathLike[str],
-                ) -> None:
-                    symlink(source, target)  # type: ignore[arg-type,unused-ignore]
-
-                symlink_fn = symlink_wrapper
-            else:
-
-                def symlink_fallback(
-                    source: str | bytes | os.PathLike[str],
-                    target: str | bytes | os.PathLike[str],
-                ) -> None:
-                    mode = "w" + ("b" if isinstance(source, bytes) else "")
-                    with open(target, mode) as f:
-                        f.write(source)
-
-                symlink_fn = symlink_fallback
-
-            # Update working tree and index
-            blob_normalizer = r.get_blob_normalizer()
             # For reset --hard, use current index tree as old tree to get proper deletions
             index = r.open_index()
             if len(index) > 0:
@@ -2889,6 +2853,12 @@ def reset(
                 # Empty index
                 index_tree_id = None
 
+            # Get configuration for working tree updates
+            honor_filemode, validate_path_element, symlink_fn = (
+                _get_worktree_update_config(r)
+            )
+
+            blob_normalizer = r.get_blob_normalizer()
             changes = tree_changes(
                 r.object_store, index_tree_id, tree.id, want_unchanged=True
             )
@@ -5082,6 +5052,168 @@ def check_ignore(
                 yield _quote_path(output_path) if quote_path else output_path
 
 
+def _get_current_head_tree(repo: Repo) -> bytes | None:
+    """Get the current HEAD tree ID.
+
+    Args:
+      repo: Repository object
+
+    Returns:
+      Tree ID of current HEAD, or None if no HEAD exists (empty repo)
+    """
+    try:
+        current_head = repo.refs[b"HEAD"]
+        current_commit = repo[current_head]
+        assert isinstance(current_commit, Commit), "Expected a Commit object"
+        tree_id: bytes = current_commit.tree
+        return tree_id
+    except KeyError:
+        # No HEAD yet (empty repo)
+        return None
+
+
+def _check_uncommitted_changes(
+    repo: Repo, target_tree_id: bytes, force: bool = False
+) -> None:
+    """Check for uncommitted changes that would conflict with a checkout/switch.
+
+    Args:
+      repo: Repository object
+      target_tree_id: Tree ID to check conflicts against
+      force: If True, skip the check
+
+    Raises:
+      CheckoutError: If there are conflicting local changes
+    """
+    if force:
+        return
+
+    # Get current HEAD tree for comparison
+    current_tree_id = _get_current_head_tree(repo)
+    if current_tree_id is None:
+        # No HEAD yet (empty repo)
+        return
+
+    status_report = status(repo)
+    changes = []
+    # staged is a dict with 'add', 'delete', 'modify' keys
+    if isinstance(status_report.staged, dict):
+        changes.extend(status_report.staged.get("add", []))
+        changes.extend(status_report.staged.get("delete", []))
+        changes.extend(status_report.staged.get("modify", []))
+    # unstaged is a list
+    changes.extend(status_report.unstaged)
+
+    if changes:
+        # Check if any changes would conflict with checkout
+        target_tree_obj = repo[target_tree_id]
+        assert isinstance(target_tree_obj, Tree), "Expected a Tree object"
+        target_tree = target_tree_obj
+        for change in changes:
+            if isinstance(change, str):
+                change = change.encode(DEFAULT_ENCODING)
+
+            try:
+                target_tree.lookup_path(repo.object_store.__getitem__, change)
+            except KeyError:
+                # File doesn't exist in target tree - change can be preserved
+                pass
+            else:
+                # File exists in target tree - would overwrite local changes
+                raise CheckoutError(
+                    f"Your local changes to '{change.decode()}' would be "
+                    "overwritten. Please commit or stash before switching."
+                )
+
+
+def _get_worktree_update_config(
+    repo: Repo,
+) -> tuple[
+    bool,
+    Callable[[bytes], bool],
+    Callable[[str | bytes | os.PathLike[str], str | bytes | os.PathLike[str]], None],
+]:
+    """Get configuration for working tree updates.
+
+    Args:
+      repo: Repository object
+
+    Returns:
+      Tuple of (honor_filemode, validate_path_element, symlink_fn)
+    """
+    config = repo.get_config()
+    honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
+
+    if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
+        validate_path_element = validate_path_element_ntfs
+    elif config.get_boolean(b"core", b"core.protectHFS", sys.platform == "darwin"):
+        validate_path_element = validate_path_element_hfs
+    else:
+        validate_path_element = validate_path_element_default
+
+    if config.get_boolean(b"core", b"symlinks", True):
+
+        def symlink_wrapper(
+            source: str | bytes | os.PathLike[str],
+            target: str | bytes | os.PathLike[str],
+        ) -> None:
+            symlink(source, target)  # type: ignore[arg-type,unused-ignore]
+
+        symlink_fn = symlink_wrapper
+    else:
+
+        def symlink_fallback(
+            source: str | bytes | os.PathLike[str],
+            target: str | bytes | os.PathLike[str],
+        ) -> None:
+            mode = "w" + ("b" if isinstance(source, bytes) else "")
+            with open(target, mode) as f:
+                f.write(source)
+
+        symlink_fn = symlink_fallback
+
+    return honor_filemode, validate_path_element, symlink_fn
+
+
+def _perform_tree_switch(
+    repo: Repo,
+    current_tree_id: bytes | None,
+    target_tree_id: bytes,
+    force: bool = False,
+) -> None:
+    """Perform the actual working tree switch.
+
+    Args:
+      repo: Repository object
+      current_tree_id: Current tree ID (or None for empty repo)
+      target_tree_id: Target tree ID to switch to
+      force: If True, force removal of untracked files and allow overwriting modified files
+    """
+    honor_filemode, validate_path_element, symlink_fn = _get_worktree_update_config(
+        repo
+    )
+
+    # Get blob normalizer for line ending conversion
+    blob_normalizer = repo.get_blob_normalizer()
+
+    # Update working tree
+    tree_change_iterator: Iterator[TreeChange] = tree_changes(
+        repo.object_store, current_tree_id, target_tree_id
+    )
+    update_working_tree(
+        repo,
+        current_tree_id,
+        target_tree_id,
+        change_iterator=tree_change_iterator,
+        honor_filemode=honor_filemode,
+        validate_path_element=validate_path_element,
+        symlink_fn=symlink_fn,
+        force_remove_untracked=force,
+        blob_normalizer=blob_normalizer,
+        allow_overwrite_modified=force,
+    )
+
+
 def update_head(
     repo: RepoPath,
     target: str | bytes,
@@ -5236,104 +5368,261 @@ def checkout(
         target_tree_id = target_commit.tree
 
         # Get current HEAD tree for comparison
-        try:
-            current_head = r.refs[b"HEAD"]
-            current_commit = r[current_head]
-            assert isinstance(current_commit, Commit), "Expected a Commit object"
-            current_tree_id = current_commit.tree
-        except KeyError:
-            # No HEAD yet (empty repo)
-            current_tree_id = None
+        current_tree_id = _get_current_head_tree(r)
 
         # Check for uncommitted changes if not forcing
-        if not force and current_tree_id is not None:
-            status_report = status(r)
-            changes = []
-            # staged is a dict with 'add', 'delete', 'modify' keys
-            if isinstance(status_report.staged, dict):
-                changes.extend(status_report.staged.get("add", []))
-                changes.extend(status_report.staged.get("delete", []))
-                changes.extend(status_report.staged.get("modify", []))
-            # unstaged is a list
-            changes.extend(status_report.unstaged)
-            if changes:
-                # Check if any changes would conflict with checkout
-                target_tree_obj = r[target_tree_id]
-                assert isinstance(target_tree_obj, Tree), "Expected a Tree object"
-                target_tree = target_tree_obj
-                for change in changes:
-                    if isinstance(change, str):
-                        change = change.encode(DEFAULT_ENCODING)
+        if current_tree_id is not None:
+            _check_uncommitted_changes(r, target_tree_id, force)
+
+        # Update working tree
+        _perform_tree_switch(r, current_tree_id, target_tree_id, force)
+
+        # Update HEAD
+        if new_branch:
+            # Create new branch and switch to it
+            branch_create(r, new_branch, objectish=target_commit.id.decode("ascii"))
+            update_head(r, new_branch)
+
+            # Set up tracking if creating from a remote branch
+            if isinstance(original_target, bytes) and target_bytes.startswith(
+                LOCAL_REMOTE_PREFIX
+            ):
+                try:
+                    remote_name, branch_name = parse_remote_ref(target_bytes)
+                    # Set tracking to refs/heads/<branch> on the remote
+                    set_branch_tracking(
+                        r, new_branch, remote_name, local_branch_name(branch_name)
+                    )
+                except ValueError:
+                    # Invalid remote ref format, skip tracking setup
+                    pass
+        else:
+            # Check if target is a branch name (with or without refs/heads/ prefix)
+            branch_ref = None
+            if (
+                isinstance(original_target, (str, bytes))
+                and target_bytes in r.refs.keys()
+            ):
+                if target_bytes.startswith(LOCAL_BRANCH_PREFIX):
+                    branch_ref = target_bytes
+            else:
+                # Try adding refs/heads/ prefix
+                potential_branch = (
+                    _make_branch_ref(target_bytes)
+                    if isinstance(original_target, (str, bytes))
+                    else None
+                )
+                if potential_branch in r.refs.keys():
+                    branch_ref = potential_branch
+
+            if branch_ref:
+                # It's a branch - update HEAD symbolically
+                update_head(r, branch_ref)
+            else:
+                # It's a tag, other ref, or commit SHA - detached HEAD
+                update_head(r, target_commit.id.decode("ascii"), detached=True)
+
+
+def restore(
+    repo: str | os.PathLike[str] | Repo,
+    paths: list[bytes | str],
+    source: str | bytes | Commit | Tag | None = None,
+    staged: bool = False,
+    worktree: bool = True,
+) -> None:
+    """Restore working tree files.
+
+    This is similar to 'git restore', allowing you to restore specific files
+    from a commit or the index without changing HEAD.
+
+    Args:
+      repo: Path to repository or repository object
+      paths: List of specific paths to restore
+      source: Branch name, tag, or commit SHA to restore from. If None, restores
+              staged files from HEAD, or worktree files from index
+      staged: Restore files in the index (--staged)
+      worktree: Restore files in the working tree (default: True)
+
+    Raises:
+      CheckoutError: If restore cannot be performed
+      ValueError: If neither staged nor worktree is specified
+      KeyError: If the source reference cannot be found
+    """
+    if not staged and not worktree:
+        raise ValueError("At least one of staged or worktree must be True")
+
+    with open_repo_closing(repo) as r:
+        from .index import _fs_to_tree_path, build_file_from_blob
+
+        # Determine the source tree
+        if source is None:
+            if staged:
+                # Restoring staged files from HEAD
+                try:
+                    source = r.refs[b"HEAD"]
+                except KeyError:
+                    raise CheckoutError("No HEAD reference found")
+            elif worktree:
+                # Restoring worktree files from index
+                from .index import ConflictedIndexEntry, IndexEntry
+
+                index = r.open_index()
+                for path in paths:
+                    if isinstance(path, str):
+                        tree_path = _fs_to_tree_path(path)
+                    else:
+                        tree_path = path
 
                     try:
-                        target_tree.lookup_path(r.object_store.__getitem__, change)
+                        index_entry = index[tree_path]
+                        if isinstance(index_entry, ConflictedIndexEntry):
+                            raise CheckoutError(
+                                f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' has conflicts"
+                            )
+                        blob = r[index_entry.sha]
+                        assert isinstance(blob, Blob), "Expected a Blob object"
+
+                        full_path = os.path.join(os.fsencode(r.path), tree_path)
+                        mode = index_entry.mode
+
+                        # Use build_file_from_blob to write the file
+                        build_file_from_blob(blob, mode, full_path)
                     except KeyError:
-                        # File doesn't exist in target tree - change can be preserved
-                        pass
-                    else:
-                        # File exists in target tree - would overwrite local changes
+                        # Path doesn't exist in index
                         raise CheckoutError(
-                            f"Your local changes to '{change.decode()}' would be "
-                            "overwritten by checkout. Please commit or stash before switching."
+                            f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' not in index"
                         )
+                return
 
-        # Get configuration for working directory update
-        config = r.get_config()
-        honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
+        # source is not None at this point
+        assert source is not None
+        # Get the source tree
+        source_tree = parse_tree(r, treeish=source)
 
-        if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
-            validate_path_element = validate_path_element_ntfs
-        else:
-            validate_path_element = validate_path_element_default
+        # Restore specified paths from source tree
+        for path in paths:
+            if isinstance(path, str):
+                tree_path = _fs_to_tree_path(path)
+            else:
+                tree_path = path
 
-        if config.get_boolean(b"core", b"symlinks", True):
+            try:
+                # Look up the path in the source tree
+                mode, sha = source_tree.lookup_path(
+                    r.object_store.__getitem__, tree_path
+                )
+                blob = r[sha]
+                assert isinstance(blob, Blob), "Expected a Blob object"
+            except KeyError:
+                # Path doesn't exist in source tree
+                raise CheckoutError(
+                    f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' not found in source"
+                )
+
+            full_path = os.path.join(os.fsencode(r.path), tree_path)
+
+            if worktree:
+                # Use build_file_from_blob to restore to working tree
+                build_file_from_blob(blob, mode, full_path)
+
+            if staged:
+                # Update the index with the blob from source
+                from .index import IndexEntry
+
+                index = r.open_index()
+
+                # When only updating staged (not worktree), we want to reset the index
+                # to the source, but invalidate the stat cache so Git knows to check
+                # the worktree file. Use zeros for stat fields.
+                if not worktree:
+                    # Invalidate stat cache by using zeros
+                    new_entry = IndexEntry(
+                        ctime=(0, 0),
+                        mtime=(0, 0),
+                        dev=0,
+                        ino=0,
+                        mode=mode,
+                        uid=0,
+                        gid=0,
+                        size=0,
+                        sha=sha,
+                    )
+                else:
+                    # If we also updated worktree, use actual stat
+                    from .index import index_entry_from_stat
+
+                    st = os.lstat(full_path)
+                    new_entry = index_entry_from_stat(st, sha, mode)
+
+                index[tree_path] = new_entry
+                index.write()
+
+
+def switch(
+    repo: str | os.PathLike[str] | Repo,
+    target: str | bytes | Commit | Tag,
+    create: str | bytes | None = None,
+    force: bool = False,
+    detach: bool = False,
+) -> None:
+    """Switch branches.
 
-            def symlink_wrapper(
-                source: str | bytes | os.PathLike[str],
-                target: str | bytes | os.PathLike[str],
-            ) -> None:
-                symlink(source, target)  # type: ignore[arg-type,unused-ignore]
+    This is similar to 'git switch', allowing you to switch to a different
+    branch or commit, updating both HEAD and the working tree.
+
+    Args:
+      repo: Path to repository or repository object
+      target: Branch name, tag, or commit SHA to switch to
+      create: Create a new branch at target before switching (like git switch -c)
+      force: Force switch even if there are local changes
+      detach: Switch to a commit in detached HEAD state (like git switch --detach)
+
+    Raises:
+      CheckoutError: If switch cannot be performed due to conflicts
+      KeyError: If the target reference cannot be found
+      ValueError: If both create and detach are specified
+    """
+    if create and detach:
+        raise ValueError("Cannot use both create and detach options")
+
+    with open_repo_closing(repo) as r:
+        # Store the original target for later reference checks
+        original_target = target
 
-            symlink_fn = symlink_wrapper
+        if isinstance(target, str):
+            target_bytes = target.encode(DEFAULT_ENCODING)
+        elif isinstance(target, bytes):
+            target_bytes = target
         else:
+            # For Commit/Tag objects, we'll use their SHA
+            target_bytes = target.id
 
-            def symlink_fallback(
-                source: str | bytes | os.PathLike[str],
-                target: str | bytes | os.PathLike[str],
-            ) -> None:
-                mode = "w" + ("b" if isinstance(source, bytes) else "")
-                with open(target, mode) as f:
-                    f.write(source)
+        if isinstance(create, str):
+            create = create.encode(DEFAULT_ENCODING)
 
-            symlink_fn = symlink_fallback
+        # Parse the target to get the commit
+        target_commit = parse_commit(r, original_target)
+        target_tree_id = target_commit.tree
 
-        # Get blob normalizer for line ending conversion
-        blob_normalizer = r.get_blob_normalizer()
+        # Get current HEAD tree for comparison
+        current_tree_id = _get_current_head_tree(r)
+
+        # Check for uncommitted changes if not forcing
+        if current_tree_id is not None:
+            _check_uncommitted_changes(r, target_tree_id, force)
 
         # Update working tree
-        tree_change_iterator: Iterator[TreeChange] = tree_changes(
-            r.object_store, current_tree_id, target_tree_id
-        )
-        update_working_tree(
-            r,
-            current_tree_id,
-            target_tree_id,
-            change_iterator=tree_change_iterator,
-            honor_filemode=honor_filemode,
-            validate_path_element=validate_path_element,
-            symlink_fn=symlink_fn,
-            force_remove_untracked=force,
-            blob_normalizer=blob_normalizer,
-            allow_overwrite_modified=force,
-        )
+        _perform_tree_switch(r, current_tree_id, target_tree_id, force)
 
         # Update HEAD
-        if new_branch:
+        if create:
             # Create new branch and switch to it
-            branch_create(r, new_branch, objectish=target_commit.id.decode("ascii"))
-            update_head(r, new_branch)
+            branch_create(r, create, objectish=target_commit.id.decode("ascii"))
+            update_head(r, create)
 
             # Set up tracking if creating from a remote branch
+            from .refs import LOCAL_REMOTE_PREFIX, local_branch_name, parse_remote_ref
+
             if isinstance(original_target, bytes) and target_bytes.startswith(
                 LOCAL_REMOTE_PREFIX
             ):
@@ -5341,11 +5630,14 @@ def checkout(
                     remote_name, branch_name = parse_remote_ref(target_bytes)
                     # Set tracking to refs/heads/<branch> on the remote
                     set_branch_tracking(
-                        r, new_branch, remote_name, local_branch_name(branch_name)
+                        r, create, remote_name, local_branch_name(branch_name)
                     )
                 except ValueError:
                     # Invalid remote ref format, skip tracking setup
                     pass
+        elif detach:
+            # Detached HEAD mode
+            update_head(r, target_commit.id.decode("ascii"), detached=True)
         else:
             # Check if target is a branch name (with or without refs/heads/ prefix)
             branch_ref = None
@@ -5369,8 +5661,12 @@ def checkout(
                 # It's a branch - update HEAD symbolically
                 update_head(r, branch_ref)
             else:
-                # It's a tag, other ref, or commit SHA - detached HEAD
-                update_head(r, target_commit.id.decode("ascii"), detached=True)
+                # It's a tag, other ref, or commit SHA
+                # In git switch, this would be an error unless --detach is used
+                raise CheckoutError(
+                    f"'{target_bytes.decode(DEFAULT_ENCODING)}' is not a branch. "
+                    "Use detach=True to switch to a commit in detached HEAD state."
+                )
 
 
 def reset_file(

+ 151 - 0
tests/test_porcelain.py

@@ -4559,6 +4559,157 @@ class CheckoutTests(PorcelainTestCase):
         remote_repo.close()
 
 
+class RestoreTests(PorcelainTestCase):
+    """Tests for the restore command."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._sha, self._foo_path = _commit_file_with_content(
+            self.repo, "foo", "original\n"
+        )
+
+    def test_restore_worktree_from_index(self) -> None:
+        # Modify the working tree file
+        with open(self._foo_path, "w") as f:
+            f.write("modified\n")
+
+        # Restore from index (should restore to original)
+        porcelain.restore(self.repo, paths=["foo"])
+
+        with open(self._foo_path) as f:
+            content = f.read()
+        self.assertEqual("original\n", content)
+
+    def test_restore_worktree_from_head(self) -> None:
+        # Modify and stage the file
+        with open(self._foo_path, "w") as f:
+            f.write("staged\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Now modify it again in worktree
+        with open(self._foo_path, "w") as f:
+            f.write("worktree\n")
+
+        # Restore from HEAD (should restore to original, not staged)
+        porcelain.restore(self.repo, paths=["foo"], source="HEAD")
+
+        with open(self._foo_path) as f:
+            content = f.read()
+        self.assertEqual("original\n", content)
+
+    def test_restore_staged_from_head(self) -> None:
+        # Modify and stage the file
+        with open(self._foo_path, "w") as f:
+            f.write("staged\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Verify it's staged
+        status = list(porcelain.status(self.repo))
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": [b"foo"]}, [], []], status
+        )
+
+        # Restore staged from HEAD
+        porcelain.restore(self.repo, paths=["foo"], staged=True, worktree=False)
+
+        # Verify it's no longer staged
+        status = list(porcelain.status(self.repo))
+        # Now it should show as unstaged modification
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": []}, [b"foo"], []], status
+        )
+
+    def test_restore_both_staged_and_worktree(self) -> None:
+        # Modify and stage the file
+        with open(self._foo_path, "w") as f:
+            f.write("staged\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Now modify it again in worktree
+        with open(self._foo_path, "w") as f:
+            f.write("worktree\n")
+
+        # Restore both from HEAD
+        porcelain.restore(self.repo, paths=["foo"], staged=True, worktree=True)
+
+        # Verify content is restored
+        with open(self._foo_path) as f:
+            content = f.read()
+        self.assertEqual("original\n", content)
+
+        # Verify nothing is staged
+        status = list(porcelain.status(self.repo))
+        self.assertEqual([{"add": [], "delete": [], "modify": []}, [], []], status)
+
+    def test_restore_nonexistent_path(self) -> None:
+        with self.assertRaises(porcelain.CheckoutError):
+            porcelain.restore(self.repo, paths=["nonexistent"])
+
+
+class SwitchTests(PorcelainTestCase):
+    """Tests for the switch command."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._sha, self._foo_path = _commit_file_with_content(
+            self.repo, "foo", "hello\n"
+        )
+        porcelain.branch_create(self.repo, "dev")
+
+    def test_switch_to_existing_branch(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+        porcelain.switch(self.repo, "dev")
+        self.assertEqual(b"dev", porcelain.active_branch(self.repo))
+
+    def test_switch_to_non_existing_branch(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+
+        with self.assertRaises(KeyError):
+            porcelain.switch(self.repo, "nonexistent")
+
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+
+    def test_switch_with_create(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+        porcelain.switch(self.repo, "master", create="feature")
+        self.assertEqual(b"feature", porcelain.active_branch(self.repo))
+
+    def test_switch_with_detach(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+        porcelain.switch(self.repo, self._sha.decode(), detach=True)
+        # In detached HEAD state, active_branch raises IndexError
+        with self.assertRaises(IndexError):
+            porcelain.active_branch(self.repo)
+
+    def test_switch_with_uncommitted_changes(self) -> None:
+        # Modify the file
+        with open(self._foo_path, "a") as f:
+            f.write("new content\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Switch should fail due to uncommitted changes
+        with self.assertRaises(porcelain.CheckoutError):
+            porcelain.switch(self.repo, "dev")
+
+        # Should still be on master
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+
+    def test_switch_with_force(self) -> None:
+        # Modify the file
+        with open(self._foo_path, "a") as f:
+            f.write("new content\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Force switch should work
+        porcelain.switch(self.repo, "dev", force=True)
+        self.assertEqual(b"dev", porcelain.active_branch(self.repo))
+
+    def test_switch_to_commit_without_detach(self) -> None:
+        # Switching to a commit SHA without --detach should fail
+        with self.assertRaises(porcelain.CheckoutError):
+            porcelain.switch(self.repo, self._sha.decode())
+
+
 class GeneralCheckoutTests(PorcelainTestCase):
     """Tests for the general checkout function that handles branches, tags, and commits."""