Răsfoiți Sursa

Implement basic restore and switch commands

Fixes #1777
Jelmer Vernooij 1 lună în urmă
părinte
comite
fae481dd65
4 a modificat fișierele cu 628 adăugiri și 30 ștergeri
  1. 3 0
      NEWS
  2. 113 0
      dulwich/cli.py
  3. 361 30
      dulwich/porcelain.py
  4. 151 0
      tests/test_porcelain.py

+ 3 - 0
NEWS

@@ -32,6 +32,9 @@
    Dulwich version, and installed dependencies with their versions.
    (Jelmer Vernooij, #1835)
 
+ * Add basic ``dulwich restore`` and ``dulwich switch``
+   commands. (Jelmer Vernooij, #1777)
+
 0.24.10	2025-11-10
 
  * Fix compatibility with python 3.9. (Jelmer Vernooij, #1991)

+ 113 - 0
dulwich/cli.py

@@ -3841,6 +3841,117 @@ class cmd_checkout(Command):
         return 0
 
 
+class cmd_restore(Command):
+    """Restore working tree files."""
+
+    def run(self, args: Sequence[str]) -> int | None:
+        """Execute the restore command.
+
+        Args:
+            args: Command line arguments
+        """
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "paths",
+            nargs="+",
+            type=str,
+            help="Paths to restore",
+        )
+        parser.add_argument(
+            "-s",
+            "--source",
+            type=str,
+            help="Restore from a specific commit (default: HEAD for --staged, index for worktree)",
+        )
+        parser.add_argument(
+            "--staged",
+            action="store_true",
+            help="Restore files in the index",
+        )
+        parser.add_argument(
+            "--worktree",
+            action="store_true",
+            help="Restore files in the working tree",
+        )
+        parsed_args = parser.parse_args(args)
+
+        # If neither --staged nor --worktree is specified, default to --worktree
+        if not parsed_args.staged and not parsed_args.worktree:
+            worktree = True
+            staged = False
+        else:
+            worktree = parsed_args.worktree
+            staged = parsed_args.staged
+
+        try:
+            porcelain.restore(
+                ".",
+                paths=parsed_args.paths,
+                source=parsed_args.source,
+                staged=staged,
+                worktree=worktree,
+            )
+        except porcelain.CheckoutError as e:
+            sys.stderr.write(f"{e}\n")
+            return 1
+        return 0
+
+
+class cmd_switch(Command):
+    """Switch branches."""
+
+    def run(self, args: Sequence[str]) -> int | None:
+        """Execute the switch command.
+
+        Args:
+            args: Command line arguments
+        """
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "target",
+            type=str,
+            help="Branch or commit to switch to",
+        )
+        parser.add_argument(
+            "-c",
+            "--create",
+            type=str,
+            help="Create a new branch at the target and switch to it",
+        )
+        parser.add_argument(
+            "-f",
+            "--force",
+            action="store_true",
+            help="Force switch even if there are local changes",
+        )
+        parser.add_argument(
+            "-d",
+            "--detach",
+            action="store_true",
+            help="Switch to a commit in detached HEAD state",
+        )
+        parsed_args = parser.parse_args(args)
+
+        if not parsed_args.target:
+            logger.error(
+                "Usage: dulwich switch TARGET [-c NEW_BRANCH] [--force] [--detach]"
+            )
+            return 1
+
+        try:
+            porcelain.switch(
+                ".",
+                target=parsed_args.target,
+                create=parsed_args.create,
+                force=parsed_args.force,
+                detach=parsed_args.detach,
+            )
+        except porcelain.CheckoutError as e:
+            sys.stderr.write(f"{e}\n")
+            return 1
+        return 0
+
+
 class cmd_stash_list(Command):
     """List stash entries."""
 
@@ -6550,6 +6661,7 @@ commands = {
     "repack": cmd_repack,
     "replace": cmd_replace,
     "reset": cmd_reset,
+    "restore": cmd_restore,
     "revert": cmd_revert,
     "rev-list": cmd_rev_list,
     "rm": cmd_rm,
@@ -6561,6 +6673,7 @@ commands = {
     "status": cmd_status,
     "stripspace": cmd_stripspace,
     "shortlog": cmd_shortlog,
+    "switch": cmd_switch,
     "symbolic-ref": cmd_symbolic_ref,
     "submodule": cmd_submodule,
     "tag": cmd_tag,

+ 361 - 30
dulwich/porcelain.py

@@ -5082,6 +5082,63 @@ def check_ignore(
                 yield _quote_path(output_path) if quote_path else output_path
 
 
+def _check_uncommitted_changes(
+    repo: Repo, target_tree_id: bytes, force: bool = False
+) -> None:
+    """Check for uncommitted changes that would conflict with a checkout/switch.
+
+    Args:
+      repo: Repository object
+      target_tree_id: Tree ID to check conflicts against
+      force: If True, skip the check
+
+    Raises:
+      CheckoutError: If there are conflicting local changes
+    """
+    if force:
+        return
+
+    # Get current HEAD tree for comparison
+    try:
+        current_head = repo.refs[b"HEAD"]
+        current_commit = repo[current_head]
+        assert isinstance(current_commit, Commit), "Expected a Commit object"
+    except KeyError:
+        # No HEAD yet (empty repo)
+        return
+
+    status_report = status(repo)
+    changes = []
+    # staged is a dict with 'add', 'delete', 'modify' keys
+    if isinstance(status_report.staged, dict):
+        changes.extend(status_report.staged.get("add", []))
+        changes.extend(status_report.staged.get("delete", []))
+        changes.extend(status_report.staged.get("modify", []))
+    # unstaged is a list
+    changes.extend(status_report.unstaged)
+
+    if changes:
+        # Check if any changes would conflict with checkout
+        target_tree_obj = repo[target_tree_id]
+        assert isinstance(target_tree_obj, Tree), "Expected a Tree object"
+        target_tree = target_tree_obj
+        for change in changes:
+            if isinstance(change, str):
+                change = change.encode(DEFAULT_ENCODING)
+
+            try:
+                target_tree.lookup_path(repo.object_store.__getitem__, change)
+            except KeyError:
+                # File doesn't exist in target tree - change can be preserved
+                pass
+            else:
+                # File exists in target tree - would overwrite local changes
+                raise CheckoutError(
+                    f"Your local changes to '{change.decode()}' would be "
+                    "overwritten. Please commit or stash before switching."
+                )
+
+
 def update_head(
     repo: RepoPath,
     target: str | bytes,
@@ -5246,36 +5303,8 @@ def checkout(
             current_tree_id = None
 
         # Check for uncommitted changes if not forcing
-        if not force and current_tree_id is not None:
-            status_report = status(r)
-            changes = []
-            # staged is a dict with 'add', 'delete', 'modify' keys
-            if isinstance(status_report.staged, dict):
-                changes.extend(status_report.staged.get("add", []))
-                changes.extend(status_report.staged.get("delete", []))
-                changes.extend(status_report.staged.get("modify", []))
-            # unstaged is a list
-            changes.extend(status_report.unstaged)
-            if changes:
-                # Check if any changes would conflict with checkout
-                target_tree_obj = r[target_tree_id]
-                assert isinstance(target_tree_obj, Tree), "Expected a Tree object"
-                target_tree = target_tree_obj
-                for change in changes:
-                    if isinstance(change, str):
-                        change = change.encode(DEFAULT_ENCODING)
-
-                    try:
-                        target_tree.lookup_path(r.object_store.__getitem__, change)
-                    except KeyError:
-                        # File doesn't exist in target tree - change can be preserved
-                        pass
-                    else:
-                        # File exists in target tree - would overwrite local changes
-                        raise CheckoutError(
-                            f"Your local changes to '{change.decode()}' would be "
-                            "overwritten by checkout. Please commit or stash before switching."
-                        )
+        if current_tree_id is not None:
+            _check_uncommitted_changes(r, target_tree_id, force)
 
         # Get configuration for working directory update
         config = r.get_config()
@@ -5373,6 +5402,308 @@ def checkout(
                 update_head(r, target_commit.id.decode("ascii"), detached=True)
 
 
+def restore(
+    repo: str | os.PathLike[str] | Repo,
+    paths: list[bytes | str],
+    source: str | bytes | Commit | Tag | None = None,
+    staged: bool = False,
+    worktree: bool = True,
+) -> None:
+    """Restore working tree files.
+
+    This is similar to 'git restore', allowing you to restore specific files
+    from a commit or the index without changing HEAD.
+
+    Args:
+      repo: Path to repository or repository object
+      paths: List of specific paths to restore
+      source: Branch name, tag, or commit SHA to restore from. If None, restores
+              staged files from HEAD, or worktree files from index
+      staged: Restore files in the index (--staged)
+      worktree: Restore files in the working tree (default: True)
+
+    Raises:
+      CheckoutError: If restore cannot be performed
+      ValueError: If neither staged nor worktree is specified
+      KeyError: If the source reference cannot be found
+    """
+    if not staged and not worktree:
+        raise ValueError("At least one of staged or worktree must be True")
+
+    with open_repo_closing(repo) as r:
+        from .index import _fs_to_tree_path, build_file_from_blob
+
+        wt = r.get_worktree()
+
+        # Determine the source tree
+        if source is None:
+            if staged:
+                # Restoring staged files from HEAD
+                try:
+                    source = r.refs[b"HEAD"]
+                except KeyError:
+                    raise CheckoutError("No HEAD reference found")
+            elif worktree:
+                # Restoring worktree files from index
+                from .index import ConflictedIndexEntry, IndexEntry
+
+                index = r.open_index()
+                for path in paths:
+                    if isinstance(path, str):
+                        tree_path = _fs_to_tree_path(path)
+                    else:
+                        tree_path = path
+
+                    try:
+                        index_entry = index[tree_path]
+                        if isinstance(index_entry, ConflictedIndexEntry):
+                            raise CheckoutError(
+                                f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' has conflicts"
+                            )
+                        blob = r[index_entry.sha]
+                        assert isinstance(blob, Blob), "Expected a Blob object"
+
+                        full_path = os.path.join(os.fsencode(r.path), tree_path)
+                        mode = index_entry.mode
+
+                        # Use build_file_from_blob to write the file
+                        build_file_from_blob(blob, mode, full_path)
+                    except KeyError:
+                        # Path doesn't exist in index
+                        raise CheckoutError(
+                            f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' not in index"
+                        )
+                return
+
+        # source is not None at this point
+        assert source is not None
+        # Get the source tree
+        source_tree = parse_tree(r, treeish=source)
+
+        # Restore specified paths from source tree
+        for path in paths:
+            if isinstance(path, str):
+                tree_path = _fs_to_tree_path(path)
+            else:
+                tree_path = path
+
+            try:
+                # Look up the path in the source tree
+                mode, sha = source_tree.lookup_path(
+                    r.object_store.__getitem__, tree_path
+                )
+                blob = r[sha]
+                assert isinstance(blob, Blob), "Expected a Blob object"
+            except KeyError:
+                # Path doesn't exist in source tree
+                raise CheckoutError(
+                    f"Path '{path if isinstance(path, str) else path.decode(DEFAULT_ENCODING)}' not found in source"
+                )
+
+            full_path = os.path.join(os.fsencode(r.path), tree_path)
+
+            if worktree:
+                # Use build_file_from_blob to restore to working tree
+                build_file_from_blob(blob, mode, full_path)
+
+            if staged:
+                # Update the index with the blob from source
+                from .index import IndexEntry
+
+                index = r.open_index()
+
+                # When only updating staged (not worktree), we want to reset the index
+                # to the source, but invalidate the stat cache so Git knows to check
+                # the worktree file. Use zeros for stat fields.
+                if not worktree:
+                    # Invalidate stat cache by using zeros
+                    new_entry = IndexEntry(
+                        ctime=(0, 0),
+                        mtime=(0, 0),
+                        dev=0,
+                        ino=0,
+                        mode=mode,
+                        uid=0,
+                        gid=0,
+                        size=0,
+                        sha=sha,
+                    )
+                else:
+                    # If we also updated worktree, use actual stat
+                    from .index import index_entry_from_stat
+
+                    st = os.lstat(full_path)
+                    new_entry = index_entry_from_stat(st, sha, mode)
+
+                index[tree_path] = new_entry
+                index.write()
+
+
+def switch(
+    repo: str | os.PathLike[str] | Repo,
+    target: str | bytes | Commit | Tag,
+    create: str | bytes | None = None,
+    force: bool = False,
+    detach: bool = False,
+) -> None:
+    """Switch branches.
+
+    This is similar to 'git switch', allowing you to switch to a different
+    branch or commit, updating both HEAD and the working tree.
+
+    Args:
+      repo: Path to repository or repository object
+      target: Branch name, tag, or commit SHA to switch to
+      create: Create a new branch at target before switching (like git switch -c)
+      force: Force switch even if there are local changes
+      detach: Switch to a commit in detached HEAD state (like git switch --detach)
+
+    Raises:
+      CheckoutError: If switch cannot be performed due to conflicts
+      KeyError: If the target reference cannot be found
+      ValueError: If both create and detach are specified
+    """
+    if create and detach:
+        raise ValueError("Cannot use both create and detach options")
+
+    with open_repo_closing(repo) as r:
+        # Store the original target for later reference checks
+        original_target = target
+
+        if isinstance(target, str):
+            target_bytes = target.encode(DEFAULT_ENCODING)
+        elif isinstance(target, bytes):
+            target_bytes = target
+        else:
+            # For Commit/Tag objects, we'll use their SHA
+            target_bytes = target.id
+
+        if isinstance(create, str):
+            create = create.encode(DEFAULT_ENCODING)
+
+        # Parse the target to get the commit
+        target_commit = parse_commit(r, original_target)
+        target_tree_id = target_commit.tree
+
+        # Get current HEAD tree for comparison
+        try:
+            current_head = r.refs[b"HEAD"]
+            current_commit = r[current_head]
+            assert isinstance(current_commit, Commit), "Expected a Commit object"
+            current_tree_id = current_commit.tree
+        except KeyError:
+            # No HEAD yet (empty repo)
+            current_tree_id = None
+
+        # Check for uncommitted changes if not forcing
+        if current_tree_id is not None:
+            _check_uncommitted_changes(r, target_tree_id, force)
+
+        # Get configuration for working directory update
+        config = r.get_config()
+        honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
+
+        if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
+            validate_path_element = validate_path_element_ntfs
+        else:
+            validate_path_element = validate_path_element_default
+
+        if config.get_boolean(b"core", b"symlinks", True):
+
+            def symlink_wrapper(
+                source: str | bytes | os.PathLike[str],
+                target: str | bytes | os.PathLike[str],
+            ) -> None:
+                symlink(source, target)  # type: ignore[arg-type,unused-ignore]
+
+            symlink_fn = symlink_wrapper
+        else:
+
+            def symlink_fallback(
+                source: str | bytes | os.PathLike[str],
+                target: str | bytes | os.PathLike[str],
+            ) -> None:
+                mode = "w" + ("b" if isinstance(source, bytes) else "")
+                with open(target, mode) as f:
+                    f.write(source)
+
+            symlink_fn = symlink_fallback
+
+        # Get blob normalizer for line ending conversion
+        blob_normalizer = r.get_blob_normalizer()
+
+        # Update working tree
+        tree_change_iterator: Iterator[TreeChange] = tree_changes(
+            r.object_store, current_tree_id, target_tree_id
+        )
+        update_working_tree(
+            r,
+            current_tree_id,
+            target_tree_id,
+            change_iterator=tree_change_iterator,
+            honor_filemode=honor_filemode,
+            validate_path_element=validate_path_element,
+            symlink_fn=symlink_fn,
+            force_remove_untracked=force,
+            blob_normalizer=blob_normalizer,
+            allow_overwrite_modified=force,
+        )
+
+        # Update HEAD
+        if create:
+            # Create new branch and switch to it
+            branch_create(r, create, objectish=target_commit.id.decode("ascii"))
+            update_head(r, create)
+
+            # Set up tracking if creating from a remote branch
+            from .refs import LOCAL_REMOTE_PREFIX, local_branch_name, parse_remote_ref
+
+            if isinstance(original_target, bytes) and target_bytes.startswith(
+                LOCAL_REMOTE_PREFIX
+            ):
+                try:
+                    remote_name, branch_name = parse_remote_ref(target_bytes)
+                    # Set tracking to refs/heads/<branch> on the remote
+                    set_branch_tracking(
+                        r, create, remote_name, local_branch_name(branch_name)
+                    )
+                except ValueError:
+                    # Invalid remote ref format, skip tracking setup
+                    pass
+        elif detach:
+            # Detached HEAD mode
+            update_head(r, target_commit.id.decode("ascii"), detached=True)
+        else:
+            # Check if target is a branch name (with or without refs/heads/ prefix)
+            branch_ref = None
+            if (
+                isinstance(original_target, (str, bytes))
+                and target_bytes in r.refs.keys()
+            ):
+                if target_bytes.startswith(LOCAL_BRANCH_PREFIX):
+                    branch_ref = target_bytes
+            else:
+                # Try adding refs/heads/ prefix
+                potential_branch = (
+                    _make_branch_ref(target_bytes)
+                    if isinstance(original_target, (str, bytes))
+                    else None
+                )
+                if potential_branch in r.refs.keys():
+                    branch_ref = potential_branch
+
+            if branch_ref:
+                # It's a branch - update HEAD symbolically
+                update_head(r, branch_ref)
+            else:
+                # It's a tag, other ref, or commit SHA
+                # In git switch, this would be an error unless --detach is used
+                raise CheckoutError(
+                    f"'{target_bytes.decode(DEFAULT_ENCODING)}' is not a branch. "
+                    "Use detach=True to switch to a commit in detached HEAD state."
+                )
+
+
 def reset_file(
     repo: Repo,
     file_path: str,

+ 151 - 0
tests/test_porcelain.py

@@ -4559,6 +4559,157 @@ class CheckoutTests(PorcelainTestCase):
         remote_repo.close()
 
 
+class RestoreTests(PorcelainTestCase):
+    """Tests for the restore command."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._sha, self._foo_path = _commit_file_with_content(
+            self.repo, "foo", "original\n"
+        )
+
+    def test_restore_worktree_from_index(self) -> None:
+        # Modify the working tree file
+        with open(self._foo_path, "w") as f:
+            f.write("modified\n")
+
+        # Restore from index (should restore to original)
+        porcelain.restore(self.repo, paths=["foo"])
+
+        with open(self._foo_path) as f:
+            content = f.read()
+        self.assertEqual("original\n", content)
+
+    def test_restore_worktree_from_head(self) -> None:
+        # Modify and stage the file
+        with open(self._foo_path, "w") as f:
+            f.write("staged\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Now modify it again in worktree
+        with open(self._foo_path, "w") as f:
+            f.write("worktree\n")
+
+        # Restore from HEAD (should restore to original, not staged)
+        porcelain.restore(self.repo, paths=["foo"], source="HEAD")
+
+        with open(self._foo_path) as f:
+            content = f.read()
+        self.assertEqual("original\n", content)
+
+    def test_restore_staged_from_head(self) -> None:
+        # Modify and stage the file
+        with open(self._foo_path, "w") as f:
+            f.write("staged\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Verify it's staged
+        status = list(porcelain.status(self.repo))
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": [b"foo"]}, [], []], status
+        )
+
+        # Restore staged from HEAD
+        porcelain.restore(self.repo, paths=["foo"], staged=True, worktree=False)
+
+        # Verify it's no longer staged
+        status = list(porcelain.status(self.repo))
+        # Now it should show as unstaged modification
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": []}, [b"foo"], []], status
+        )
+
+    def test_restore_both_staged_and_worktree(self) -> None:
+        # Modify and stage the file
+        with open(self._foo_path, "w") as f:
+            f.write("staged\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Now modify it again in worktree
+        with open(self._foo_path, "w") as f:
+            f.write("worktree\n")
+
+        # Restore both from HEAD
+        porcelain.restore(self.repo, paths=["foo"], staged=True, worktree=True)
+
+        # Verify content is restored
+        with open(self._foo_path) as f:
+            content = f.read()
+        self.assertEqual("original\n", content)
+
+        # Verify nothing is staged
+        status = list(porcelain.status(self.repo))
+        self.assertEqual([{"add": [], "delete": [], "modify": []}, [], []], status)
+
+    def test_restore_nonexistent_path(self) -> None:
+        with self.assertRaises(porcelain.CheckoutError):
+            porcelain.restore(self.repo, paths=["nonexistent"])
+
+
+class SwitchTests(PorcelainTestCase):
+    """Tests for the switch command."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self._sha, self._foo_path = _commit_file_with_content(
+            self.repo, "foo", "hello\n"
+        )
+        porcelain.branch_create(self.repo, "dev")
+
+    def test_switch_to_existing_branch(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+        porcelain.switch(self.repo, "dev")
+        self.assertEqual(b"dev", porcelain.active_branch(self.repo))
+
+    def test_switch_to_non_existing_branch(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+
+        with self.assertRaises(KeyError):
+            porcelain.switch(self.repo, "nonexistent")
+
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+
+    def test_switch_with_create(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+        porcelain.switch(self.repo, "master", create="feature")
+        self.assertEqual(b"feature", porcelain.active_branch(self.repo))
+
+    def test_switch_with_detach(self) -> None:
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+        porcelain.switch(self.repo, self._sha.decode(), detach=True)
+        # In detached HEAD state, active_branch raises IndexError
+        with self.assertRaises(IndexError):
+            porcelain.active_branch(self.repo)
+
+    def test_switch_with_uncommitted_changes(self) -> None:
+        # Modify the file
+        with open(self._foo_path, "a") as f:
+            f.write("new content\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Switch should fail due to uncommitted changes
+        with self.assertRaises(porcelain.CheckoutError):
+            porcelain.switch(self.repo, "dev")
+
+        # Should still be on master
+        self.assertEqual(b"master", porcelain.active_branch(self.repo))
+
+    def test_switch_with_force(self) -> None:
+        # Modify the file
+        with open(self._foo_path, "a") as f:
+            f.write("new content\n")
+        porcelain.add(self.repo, paths=[self._foo_path])
+
+        # Force switch should work
+        porcelain.switch(self.repo, "dev", force=True)
+        self.assertEqual(b"dev", porcelain.active_branch(self.repo))
+
+    def test_switch_to_commit_without_detach(self) -> None:
+        # Switching to a commit SHA without --detach should fail
+        with self.assertRaises(porcelain.CheckoutError):
+            porcelain.switch(self.repo, self._sha.decode())
+
+
 class GeneralCheckoutTests(PorcelainTestCase):
     """Tests for the general checkout function that handles branches, tags, and commits."""