浏览代码

Split out WorkTree from Repo

Jelmer Vernooij 1 月之前
父节点
当前提交
e9a0cbde16
共有 5 个文件被更改，包括 1104 次插入和 334 次删除
  1. 2 0
      NEWS
  2. 189 334
      dulwich/repo.py
  3. 502 0
      dulwich/worktree.py
  4. 1 0
      tests/__init__.py
  5. 410 0
      tests/test_worktree.py

+ 2 - 0
NEWS

@@ -1,5 +1,7 @@
 0.23.3	UNRELEASED
 0.23.3	UNRELEASED
 
 
+ * Split out ``WorkTree`` from ``Repo``. (Jelmer Vernooij)
+
  * Add support for ``-a`` argument to
  * Add support for ``-a`` argument to
    ``dulwich.cli.commit``. (Jelmer Vernooij)
    ``dulwich.cli.commit``. (Jelmer Vernooij)
 
 

+ 189 - 334
dulwich/repo.py

@@ -53,10 +53,10 @@ if TYPE_CHECKING:
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .config import ConditionMatcher, ConfigFile, StackedConfig
     from .index import Index
     from .index import Index
     from .notes import Notes
     from .notes import Notes
+    from .worktree import WorkTree
 
 
+from . import replace_me
 from .errors import (
 from .errors import (
-    CommitError,
-    HookError,
     NoIndexPresent,
     NoIndexPresent,
     NotBlobError,
     NotBlobError,
     NotCommitError,
     NotCommitError,
@@ -987,6 +987,20 @@ class BaseRepo:
         with f:
         with f:
             return [line.strip() for line in f.readlines() if line.strip()]
             return [line.strip() for line in f.readlines() if line.strip()]
 
 
+    def get_worktree(self) -> "WorkTree":
+        """Get the working tree for this repository.
+
+        Returns:
+            WorkTree instance for performing working tree operations
+
+        Raises:
+            NotImplementedError: If the repository doesn't support working trees
+        """
+        raise NotImplementedError(
+            "Working tree operations not supported by this repository type"
+        )
+
+    @replace_me()
     def do_commit(
     def do_commit(
         self,
         self,
         message: Optional[bytes] = None,
         message: Optional[bytes] = None,
@@ -1034,158 +1048,21 @@ class BaseRepo:
         Returns:
         Returns:
           New commit SHA1
           New commit SHA1
         """
         """
-        try:
-            if not no_verify:
-                self.hooks["pre-commit"].execute()
-        except HookError as exc:
-            raise CommitError(exc) from exc
-        except KeyError:  # no hook defined, silent fallthrough
-            pass
-
-        c = Commit()
-        if tree is None:
-            index = self.open_index()
-            c.tree = index.commit(self.object_store)
-        else:
-            if len(tree) != 40:
-                raise ValueError("tree must be a 40-byte hex sha string")
-            c.tree = tree
-
-        config = self.get_config_stack()
-        if merge_heads is None:
-            merge_heads = self._read_heads("MERGE_HEAD")
-        if committer is None:
-            committer = get_user_identity(config, kind="COMMITTER")
-        check_user_identity(committer)
-        c.committer = committer
-        if commit_timestamp is None:
-            # FIXME: Support GIT_COMMITTER_DATE environment variable
-            commit_timestamp = time.time()
-        c.commit_time = int(commit_timestamp)
-        if commit_timezone is None:
-            # FIXME: Use current user timezone rather than UTC
-            commit_timezone = 0
-        c.commit_timezone = commit_timezone
-        if author is None:
-            author = get_user_identity(config, kind="AUTHOR")
-        c.author = author
-        check_user_identity(author)
-        if author_timestamp is None:
-            # FIXME: Support GIT_AUTHOR_DATE environment variable
-            author_timestamp = commit_timestamp
-        c.author_time = int(author_timestamp)
-        if author_timezone is None:
-            author_timezone = commit_timezone
-        c.author_timezone = author_timezone
-        if encoding is None:
-            try:
-                encoding = config.get(("i18n",), "commitEncoding")
-            except KeyError:
-                pass  # No dice
-        if encoding is not None:
-            c.encoding = encoding
-        # Store original message (might be callable)
-        original_message = message
-        message = None  # Will be set later after parents are set
-
-        # Check if we should sign the commit
-        should_sign = sign
-        if sign is None:
-            # Check commit.gpgSign configuration when sign is not explicitly set
-            config = self.get_config_stack()
-            try:
-                should_sign = config.get_boolean((b"commit",), b"gpgSign")
-            except KeyError:
-                should_sign = False  # Default to not signing if no config
-        keyid = sign if isinstance(sign, str) else None
-
-        if ref is None:
-            # Create a dangling commit
-            c.parents = merge_heads
-        else:
-            try:
-                old_head = self.refs[ref]
-                c.parents = [old_head, *merge_heads]
-            except KeyError:
-                c.parents = merge_heads
-
-        # Handle message after parents are set
-        if callable(original_message):
-            message = original_message(self, c)
-            if message is None:
-                raise ValueError("Message callback returned None")
-        else:
-            message = original_message
-
-        if message is None:
-            # FIXME: Try to read commit message from .git/MERGE_MSG
-            raise ValueError("No commit message specified")
-
-        try:
-            if no_verify:
-                c.message = message
-            else:
-                c.message = self.hooks["commit-msg"].execute(message)
-                if c.message is None:
-                    c.message = message
-        except HookError as exc:
-            raise CommitError(exc) from exc
-        except KeyError:  # no hook defined, message not modified
-            c.message = message
-
-        if ref is None:
-            # Create a dangling commit
-            if should_sign:
-                c.sign(keyid)
-            self.object_store.add_object(c)
-        else:
-            try:
-                old_head = self.refs[ref]
-                if should_sign:
-                    c.sign(keyid)
-                self.object_store.add_object(c)
-                ok = self.refs.set_if_equals(
-                    ref,
-                    old_head,
-                    c.id,
-                    message=b"commit: " + message,
-                    committer=committer,
-                    timestamp=commit_timestamp,
-                    timezone=commit_timezone,
-                )
-            except KeyError:
-                c.parents = merge_heads
-                if should_sign:
-                    c.sign(keyid)
-                self.object_store.add_object(c)
-                ok = self.refs.add_if_new(
-                    ref,
-                    c.id,
-                    message=b"commit: " + message,
-                    committer=committer,
-                    timestamp=commit_timestamp,
-                    timezone=commit_timezone,
-                )
-            if not ok:
-                # Fail if the atomic compare-and-swap failed, leaving the
-                # commit and all its objects as garbage.
-                raise CommitError(f"{ref!r} changed during commit")
-
-        self._del_named_file("MERGE_HEAD")
-
-        try:
-            self.hooks["post-commit"].execute()
-        except HookError as e:  # silent failure
-            warnings.warn(f"post-commit hook failed: {e}", UserWarning)
-        except KeyError:  # no hook defined, silent fallthrough
-            pass
-
-        # Trigger auto GC if needed
-        from .gc import maybe_auto_gc
-
-        maybe_auto_gc(self)
-
-        return c.id
+        return self.get_worktree().commit(
+            message=message,
+            committer=committer,
+            author=author,
+            commit_timestamp=commit_timestamp,
+            commit_timezone=commit_timezone,
+            author_timestamp=author_timestamp,
+            author_timezone=author_timezone,
+            tree=tree,
+            encoding=encoding,
+            ref=ref,
+            merge_heads=merge_heads,
+            no_verify=no_verify,
+            sign=sign,
+        )
 
 
 
 
 def read_gitfile(f):
 def read_gitfile(f):
@@ -1350,6 +1227,16 @@ class Repo(BaseRepo):
         self.hooks["post-commit"] = PostCommitShellHook(self.controldir())
         self.hooks["post-commit"] = PostCommitShellHook(self.controldir())
         self.hooks["post-receive"] = PostReceiveShellHook(self.controldir())
         self.hooks["post-receive"] = PostReceiveShellHook(self.controldir())
 
 
+    def get_worktree(self) -> "WorkTree":
+        """Get the working tree for this repository.
+
+        Returns:
+            WorkTree instance for performing working tree operations
+        """
+        from .worktree import WorkTree
+
+        return WorkTree(self, self.path)
+
     def _write_reflog(
     def _write_reflog(
         self, ref, old_sha, new_sha, committer, timestamp, timezone, message
         self, ref, old_sha, new_sha, committer, timestamp, timezone, message
     ) -> None:
     ) -> None:
@@ -1547,6 +1434,7 @@ class Repo(BaseRepo):
         # missing index file, which is treated as empty.
         # missing index file, which is treated as empty.
         return not self.bare
         return not self.bare
 
 
+    @replace_me()
     def stage(
     def stage(
         self,
         self,
         fs_paths: Union[
         fs_paths: Union[
@@ -1558,117 +1446,16 @@ class Repo(BaseRepo):
         Args:
         Args:
           fs_paths: List of paths, relative to the repository path
           fs_paths: List of paths, relative to the repository path
         """
         """
-        root_path_bytes = os.fsencode(self.path)
-
-        if isinstance(fs_paths, (str, bytes, os.PathLike)):
-            fs_paths = [fs_paths]
-        fs_paths = list(fs_paths)
-
-        from .index import (
-            _fs_to_tree_path,
-            blob_from_path_and_stat,
-            index_entry_from_directory,
-            index_entry_from_stat,
-        )
-
-        index = self.open_index()
-        blob_normalizer = self.get_blob_normalizer()
-        for fs_path in fs_paths:
-            if not isinstance(fs_path, bytes):
-                fs_path = os.fsencode(fs_path)
-            if os.path.isabs(fs_path):
-                raise ValueError(
-                    f"path {fs_path!r} should be relative to "
-                    "repository root, not absolute"
-                )
-            tree_path = _fs_to_tree_path(fs_path)
-            full_path = os.path.join(root_path_bytes, fs_path)
-            try:
-                st = os.lstat(full_path)
-            except OSError:
-                # File no longer exists
-                try:
-                    del index[tree_path]
-                except KeyError:
-                    pass  # already removed
-            else:
-                if stat.S_ISDIR(st.st_mode):
-                    entry = index_entry_from_directory(st, full_path)
-                    if entry:
-                        index[tree_path] = entry
-                    else:
-                        try:
-                            del index[tree_path]
-                        except KeyError:
-                            pass
-                elif not stat.S_ISREG(st.st_mode) and not stat.S_ISLNK(st.st_mode):
-                    try:
-                        del index[tree_path]
-                    except KeyError:
-                        pass
-                else:
-                    blob = blob_from_path_and_stat(full_path, st)
-                    blob = blob_normalizer.checkin_normalize(blob, fs_path)
-                    self.object_store.add_object(blob)
-                    index[tree_path] = index_entry_from_stat(st, blob.id)
-        index.write()
+        return self.get_worktree().stage(fs_paths)
 
 
+    @replace_me()
     def unstage(self, fs_paths: list[str]) -> None:
     def unstage(self, fs_paths: list[str]) -> None:
         """Unstage specific file in the index
         """Unstage specific file in the index
         Args:
         Args:
           fs_paths: a list of files to unstage,
           fs_paths: a list of files to unstage,
             relative to the repository path.
             relative to the repository path.
         """
         """
-        from .index import IndexEntry, _fs_to_tree_path
-
-        index = self.open_index()
-        try:
-            tree_id = self[b"HEAD"].tree
-        except KeyError:
-            # no head mean no commit in the repo
-            for fs_path in fs_paths:
-                tree_path = _fs_to_tree_path(fs_path)
-                del index[tree_path]
-            index.write()
-            return
-
-        for fs_path in fs_paths:
-            tree_path = _fs_to_tree_path(fs_path)
-            try:
-                tree = self.object_store[tree_id]
-                assert isinstance(tree, Tree)
-                tree_entry = tree.lookup_path(self.object_store.__getitem__, tree_path)
-            except KeyError:
-                # if tree_entry didn't exist, this file was being added, so
-                # remove index entry
-                try:
-                    del index[tree_path]
-                    continue
-                except KeyError as exc:
-                    raise KeyError(f"file '{tree_path.decode()}' not in index") from exc
-
-            st = None
-            try:
-                st = os.lstat(os.path.join(self.path, fs_path))
-            except FileNotFoundError:
-                pass
-
-            index_entry = IndexEntry(
-                ctime=(self[b"HEAD"].commit_time, 0),
-                mtime=(self[b"HEAD"].commit_time, 0),
-                dev=st.st_dev if st else 0,
-                ino=st.st_ino if st else 0,
-                mode=tree_entry[0],
-                uid=st.st_uid if st else 0,
-                gid=st.st_gid if st else 0,
-                size=len(self[tree_entry[1]].data),
-                sha=tree_entry[1],
-                flags=0,
-                extended_flags=0,
-            )
-
-            index[tree_path] = index_entry
-        index.write()
+        return self.get_worktree().unstage(fs_paths)
 
 
     def clone(
     def clone(
         self,
         self,
@@ -1765,55 +1552,14 @@ class Repo(BaseRepo):
             raise
             raise
         return target
         return target
 
 
+    @replace_me()
     def reset_index(self, tree: Optional[bytes] = None):
     def reset_index(self, tree: Optional[bytes] = None):
         """Reset the index back to a specific tree.
         """Reset the index back to a specific tree.
 
 
         Args:
         Args:
           tree: Tree SHA to reset to, None for current HEAD tree.
           tree: Tree SHA to reset to, None for current HEAD tree.
         """
         """
-        from .index import (
-            build_index_from_tree,
-            symlink,
-            validate_path_element_default,
-            validate_path_element_hfs,
-            validate_path_element_ntfs,
-        )
-
-        if tree is None:
-            head = self[b"HEAD"]
-            if isinstance(head, Tag):
-                _cls, obj = head.object
-                head = self.get_object(obj)
-            tree = head.tree
-        config = self.get_config()
-        honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
-        if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
-            validate_path_element = validate_path_element_ntfs
-        elif config.get_boolean(b"core", b"core.protectHFS", sys.platform == "darwin"):
-            validate_path_element = validate_path_element_hfs
-        else:
-            validate_path_element = validate_path_element_default
-        if config.get_boolean(b"core", b"symlinks", True):
-            symlink_fn = symlink
-        else:
-
-            def symlink_fn(source, target) -> None:  # type: ignore
-                with open(
-                    target, "w" + ("b" if isinstance(source, bytes) else "")
-                ) as f:
-                    f.write(source)
-
-        blob_normalizer = self.get_blob_normalizer()
-        return build_index_from_tree(
-            self.path,
-            self.index_path(),
-            self.object_store,
-            tree,
-            honor_filemode=honor_filemode,
-            validate_path_element=validate_path_element,
-            symlink_fn=symlink_fn,
-            blob_normalizer=blob_normalizer,
-        )
+        return self.get_worktree().reset_index(tree)
 
 
     def _get_config_condition_matchers(self) -> dict[str, "ConditionMatcher"]:
     def _get_config_condition_matchers(self) -> dict[str, "ConditionMatcher"]:
         """Get condition matchers for includeIf conditions.
         """Get condition matchers for includeIf conditions.
@@ -2234,40 +1980,31 @@ class Repo(BaseRepo):
 
 
         return GitAttributes(patterns)
         return GitAttributes(patterns)
 
 
+    @replace_me()
     def _sparse_checkout_file_path(self) -> str:
     def _sparse_checkout_file_path(self) -> str:
         """Return the path of the sparse-checkout file in this repo's control dir."""
         """Return the path of the sparse-checkout file in this repo's control dir."""
-        return os.path.join(self.controldir(), "info", "sparse-checkout")
+        return self.get_worktree()._sparse_checkout_file_path()
 
 
+    @replace_me()
     def configure_for_cone_mode(self) -> None:
     def configure_for_cone_mode(self) -> None:
         """Ensure the repository is configured for cone-mode sparse-checkout."""
         """Ensure the repository is configured for cone-mode sparse-checkout."""
-        config = self.get_config()
-        config.set((b"core",), b"sparseCheckout", b"true")
-        config.set((b"core",), b"sparseCheckoutCone", b"true")
-        config.write_to_path()
+        return self.get_worktree().configure_for_cone_mode()
 
 
+    @replace_me()
     def infer_cone_mode(self) -> bool:
     def infer_cone_mode(self) -> bool:
         """Return True if 'core.sparseCheckoutCone' is set to 'true' in config, else False."""
         """Return True if 'core.sparseCheckoutCone' is set to 'true' in config, else False."""
-        config = self.get_config()
-        try:
-            sc_cone = config.get((b"core",), b"sparseCheckoutCone")
-            return sc_cone == b"true"
-        except KeyError:
-            # If core.sparseCheckoutCone is not set, default to False
-            return False
+        return self.get_worktree().infer_cone_mode()
 
 
+    @replace_me()
     def get_sparse_checkout_patterns(self) -> list[str]:
     def get_sparse_checkout_patterns(self) -> list[str]:
         """Return a list of sparse-checkout patterns from info/sparse-checkout.
         """Return a list of sparse-checkout patterns from info/sparse-checkout.
 
 
         Returns:
         Returns:
             A list of patterns. Returns an empty list if the file is missing.
             A list of patterns. Returns an empty list if the file is missing.
         """
         """
-        path = self._sparse_checkout_file_path()
-        try:
-            with open(path, encoding="utf-8") as f:
-                return [line.strip() for line in f if line.strip()]
-        except FileNotFoundError:
-            return []
+        return self.get_worktree().get_sparse_checkout_patterns()
 
 
+    @replace_me()
     def set_sparse_checkout_patterns(self, patterns: list[str]) -> None:
     def set_sparse_checkout_patterns(self, patterns: list[str]) -> None:
         """Write the given sparse-checkout patterns into info/sparse-checkout.
         """Write the given sparse-checkout patterns into info/sparse-checkout.
 
 
@@ -2276,14 +2013,9 @@ class Repo(BaseRepo):
         Args:
         Args:
             patterns: A list of gitignore-style patterns to store.
             patterns: A list of gitignore-style patterns to store.
         """
         """
-        info_dir = os.path.join(self.controldir(), "info")
-        os.makedirs(info_dir, exist_ok=True)
-
-        path = self._sparse_checkout_file_path()
-        with open(path, "w", encoding="utf-8") as f:
-            for pat in patterns:
-                f.write(pat + "\n")
+        return self.get_worktree().set_sparse_checkout_patterns(patterns)
 
 
+    @replace_me()
     def set_cone_mode_patterns(self, dirs: Union[list[str], None] = None) -> None:
     def set_cone_mode_patterns(self, dirs: Union[list[str], None] = None) -> None:
         """Write the given cone-mode directory patterns into info/sparse-checkout.
         """Write the given cone-mode directory patterns into info/sparse-checkout.
 
 
@@ -2291,14 +2023,7 @@ class Repo(BaseRepo):
         ``!/*/`` 'exclude' that re-includes that directory and everything under it.
         ``!/*/`` 'exclude' that re-includes that directory and everything under it.
         Never add the same line twice.
         Never add the same line twice.
         """
         """
-        patterns = ["/*", "!/*/"]
-        if dirs:
-            for d in dirs:
-                d = d.strip("/")
-                line = f"/{d}/"
-                if d and line not in patterns:
-                    patterns.append(line)
-        self.set_sparse_checkout_patterns(patterns)
+        return self.get_worktree().set_cone_mode_patterns(dirs)
 
 
 
 
 class MemoryRepo(BaseRepo):
 class MemoryRepo(BaseRepo):
@@ -2420,6 +2145,136 @@ class MemoryRepo(BaseRepo):
         # Return empty GitAttributes
         # Return empty GitAttributes
         return GitAttributes([])
         return GitAttributes([])
 
 
+    def do_commit(
+        self,
+        message: Optional[bytes] = None,
+        committer: Optional[bytes] = None,
+        author: Optional[bytes] = None,
+        commit_timestamp=None,
+        commit_timezone=None,
+        author_timestamp=None,
+        author_timezone=None,
+        tree: Optional[ObjectID] = None,
+        encoding: Optional[bytes] = None,
+        ref: Optional[Ref] = b"HEAD",
+        merge_heads: Optional[list[ObjectID]] = None,
+        no_verify: bool = False,
+        sign: bool = False,
+    ):
+        """Create a new commit.
+
+        This is a simplified implementation for in-memory repositories that
+        doesn't support worktree operations or hooks.
+
+        Args:
+          message: Commit message
+          committer: Committer fullname
+          author: Author fullname
+          commit_timestamp: Commit timestamp (defaults to now)
+          commit_timezone: Commit timestamp timezone (defaults to GMT)
+          author_timestamp: Author timestamp (defaults to commit timestamp)
+          author_timezone: Author timestamp timezone (defaults to commit timezone)
+          tree: SHA1 of the tree root to use
+          encoding: Encoding
+          ref: Optional ref to commit to (defaults to current branch).
+            If None, creates a dangling commit without updating any ref.
+          merge_heads: Merge heads
+          no_verify: Skip pre-commit and commit-msg hooks (ignored for MemoryRepo)
+          sign: GPG Sign the commit (ignored for MemoryRepo)
+
+        Returns:
+          New commit SHA1
+        """
+        import time
+
+        from .objects import Commit
+
+        if tree is None:
+            raise ValueError("tree must be specified for MemoryRepo")
+
+        c = Commit()
+        if len(tree) != 40:
+            raise ValueError("tree must be a 40-byte hex sha string")
+        c.tree = tree
+
+        config = self.get_config_stack()
+        if merge_heads is None:
+            merge_heads = []
+        if committer is None:
+            committer = get_user_identity(config, kind="COMMITTER")
+        check_user_identity(committer)
+        c.committer = committer
+        if commit_timestamp is None:
+            commit_timestamp = time.time()
+        c.commit_time = int(commit_timestamp)
+        if commit_timezone is None:
+            commit_timezone = 0
+        c.commit_timezone = commit_timezone
+        if author is None:
+            author = get_user_identity(config, kind="AUTHOR")
+        c.author = author
+        check_user_identity(author)
+        if author_timestamp is None:
+            author_timestamp = commit_timestamp
+        c.author_time = int(author_timestamp)
+        if author_timezone is None:
+            author_timezone = commit_timezone
+        c.author_timezone = author_timezone
+        if encoding is None:
+            try:
+                encoding = config.get(("i18n",), "commitEncoding")
+            except KeyError:
+                pass
+        if encoding is not None:
+            c.encoding = encoding
+
+        # Handle message (for MemoryRepo, we don't support callable messages)
+        if callable(message):
+            message = message(self, c)
+            if message is None:
+                raise ValueError("Message callback returned None")
+
+        if message is None:
+            raise ValueError("No commit message specified")
+
+        c.message = message
+
+        if ref is None:
+            # Create a dangling commit
+            c.parents = merge_heads
+            self.object_store.add_object(c)
+        else:
+            try:
+                old_head = self.refs[ref]
+                c.parents = [old_head, *merge_heads]
+                self.object_store.add_object(c)
+                ok = self.refs.set_if_equals(
+                    ref,
+                    old_head,
+                    c.id,
+                    message=b"commit: " + message,
+                    committer=committer,
+                    timestamp=commit_timestamp,
+                    timezone=commit_timezone,
+                )
+            except KeyError:
+                c.parents = merge_heads
+                self.object_store.add_object(c)
+                ok = self.refs.add_if_new(
+                    ref,
+                    c.id,
+                    message=b"commit: " + message,
+                    committer=committer,
+                    timestamp=commit_timestamp,
+                    timezone=commit_timezone,
+                )
+            if not ok:
+                from .errors import CommitError
+
+                raise CommitError(f"{ref!r} changed during commit")
+
+        return c.id
+
     @classmethod
     @classmethod
     def init_bare(cls, objects, refs, format: Optional[int] = None):
     def init_bare(cls, objects, refs, format: Optional[int] = None):
         """Create a new bare repository in memory.
         """Create a new bare repository in memory.

+ 502 - 0
dulwich/worktree.py

@@ -0,0 +1,502 @@
+# worktree.py -- Working tree operations for Git repositories
+# Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as published by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Working tree operations for Git repositories."""
+
+import os
+import stat
+import sys
+import time
+import warnings
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Optional, Union
+
+if TYPE_CHECKING:
+    from .repo import Repo
+
+from .errors import CommitError, HookError
+from .objects import Commit, ObjectID, Tag, Tree
+from .refs import Ref
+from .repo import check_user_identity, get_user_identity
+
+
+class WorkTree:
+    """Working tree operations for a Git repository.
+
+    This class provides methods for working with the working tree,
+    such as staging files, committing changes, and resetting the index.
+    """
+
+    def __init__(self, repo: "Repo", path: Union[str, bytes, os.PathLike]) -> None:
+        """Initialize a WorkTree for the given repository.
+
+        Args:
+            repo: The repository this working tree belongs to
+            path: Path to the working tree directory
+        """
+        self._repo = repo
+        raw_path = os.fspath(path)
+        if isinstance(raw_path, bytes):
+            self.path: str = os.fsdecode(raw_path)
+        else:
+            self.path = raw_path
+        self.path = os.path.abspath(self.path)
+
+    def stage(
+        self,
+        fs_paths: Union[
+            str, bytes, os.PathLike, Iterable[Union[str, bytes, os.PathLike]]
+        ],
+    ) -> None:
+        """Stage a set of paths.
+
+        Args:
+          fs_paths: List of paths, relative to the repository path
+        """
+        root_path_bytes = os.fsencode(self.path)
+
+        if isinstance(fs_paths, (str, bytes, os.PathLike)):
+            fs_paths = [fs_paths]
+        fs_paths = list(fs_paths)
+
+        from .index import (
+            _fs_to_tree_path,
+            blob_from_path_and_stat,
+            index_entry_from_directory,
+            index_entry_from_stat,
+        )
+
+        index = self._repo.open_index()
+        blob_normalizer = self._repo.get_blob_normalizer()
+        for fs_path in fs_paths:
+            if not isinstance(fs_path, bytes):
+                fs_path = os.fsencode(fs_path)
+            if os.path.isabs(fs_path):
+                raise ValueError(
+                    f"path {fs_path!r} should be relative to "
+                    "repository root, not absolute"
+                )
+            tree_path = _fs_to_tree_path(fs_path)
+            full_path = os.path.join(root_path_bytes, fs_path)
+            try:
+                st = os.lstat(full_path)
+            except OSError:
+                # File no longer exists
+                try:
+                    del index[tree_path]
+                except KeyError:
+                    pass  # already removed
+            else:
+                if stat.S_ISDIR(st.st_mode):
+                    entry = index_entry_from_directory(st, full_path)
+                    if entry:
+                        index[tree_path] = entry
+                    else:
+                        try:
+                            del index[tree_path]
+                        except KeyError:
+                            pass
+                elif not stat.S_ISREG(st.st_mode) and not stat.S_ISLNK(st.st_mode):
+                    try:
+                        del index[tree_path]
+                    except KeyError:
+                        pass
+                else:
+                    blob = blob_from_path_and_stat(full_path, st)
+                    blob = blob_normalizer.checkin_normalize(blob, fs_path)
+                    self._repo.object_store.add_object(blob)
+                    index[tree_path] = index_entry_from_stat(st, blob.id)
+        index.write()
+
+    def unstage(self, fs_paths: list[str]) -> None:
+        """Unstage the specified files in the index.
+        Args:
+          fs_paths: a list of files to unstage,
+            relative to the repository path.
+        """
+        from .index import IndexEntry, _fs_to_tree_path
+
+        index = self._repo.open_index()
+        try:
+            tree_id = self._repo[b"HEAD"].tree
+        except KeyError:
+            # no head means there are no commits in the repo
+            for fs_path in fs_paths:
+                tree_path = _fs_to_tree_path(fs_path)
+                del index[tree_path]
+            index.write()
+            return
+
+        for fs_path in fs_paths:
+            tree_path = _fs_to_tree_path(fs_path)
+            try:
+                tree = self._repo.object_store[tree_id]
+                assert isinstance(tree, Tree)
+                tree_entry = tree.lookup_path(
+                    self._repo.object_store.__getitem__, tree_path
+                )
+            except KeyError:
+                # if tree_entry didn't exist, this file was being added, so
+                # remove index entry
+                try:
+                    del index[tree_path]
+                    continue
+                except KeyError as exc:
+                    raise KeyError(f"file '{tree_path.decode()}' not in index") from exc
+
+            st = None
+            try:
+                st = os.lstat(os.path.join(self.path, fs_path))
+            except FileNotFoundError:
+                pass
+
+            index_entry = IndexEntry(
+                ctime=(self._repo[b"HEAD"].commit_time, 0),
+                mtime=(self._repo[b"HEAD"].commit_time, 0),
+                dev=st.st_dev if st else 0,
+                ino=st.st_ino if st else 0,
+                mode=tree_entry[0],
+                uid=st.st_uid if st else 0,
+                gid=st.st_gid if st else 0,
+                size=len(self._repo[tree_entry[1]].data),
+                sha=tree_entry[1],
+                flags=0,
+                extended_flags=0,
+            )
+
+            index[tree_path] = index_entry
+        index.write()
+
+    def commit(
+        self,
+        message: Optional[bytes] = None,
+        committer: Optional[bytes] = None,
+        author: Optional[bytes] = None,
+        commit_timestamp=None,
+        commit_timezone=None,
+        author_timestamp=None,
+        author_timezone=None,
+        tree: Optional[ObjectID] = None,
+        encoding: Optional[bytes] = None,
+        ref: Optional[Ref] = b"HEAD",
+        merge_heads: Optional[list[ObjectID]] = None,
+        no_verify: bool = False,
+        sign: bool = False,
+    ):
+        """Create a new commit.
+
+        If not specified, committer and author default to
+        get_user_identity(..., 'COMMITTER')
+        and get_user_identity(..., 'AUTHOR') respectively.
+
+        Args:
+          message: Commit message (bytes or callable that takes (repo, commit)
+            and returns bytes)
+          committer: Committer fullname
+          author: Author fullname
+          commit_timestamp: Commit timestamp (defaults to now)
+          commit_timezone: Commit timestamp timezone (defaults to GMT)
+          author_timestamp: Author timestamp (defaults to commit
+            timestamp)
+          author_timezone: Author timestamp timezone
+            (defaults to commit timestamp timezone)
+          tree: SHA1 of the tree root to use (if not specified the
+            current index will be committed).
+          encoding: Encoding
+          ref: Optional ref to commit to (defaults to current branch).
+            If None, creates a dangling commit without updating any ref.
+          merge_heads: Merge heads (defaults to .git/MERGE_HEAD)
+          no_verify: Skip pre-commit and commit-msg hooks
+          sign: GPG Sign the commit (bool, defaults to False,
+            pass True to use default GPG key,
+            pass a str containing Key ID to use a specific GPG key)
+
+        Returns:
+          New commit SHA1
+        """
+        try:
+            if not no_verify:
+                self._repo.hooks["pre-commit"].execute()
+        except HookError as exc:
+            raise CommitError(exc) from exc
+        except KeyError:  # no hook defined, silent fallthrough
+            pass
+
+        c = Commit()
+        if tree is None:
+            index = self._repo.open_index()
+            c.tree = index.commit(self._repo.object_store)
+        else:
+            if len(tree) != 40:
+                raise ValueError("tree must be a 40-byte hex sha string")
+            c.tree = tree
+
+        config = self._repo.get_config_stack()
+        if merge_heads is None:
+            merge_heads = self._repo._read_heads("MERGE_HEAD")
+        if committer is None:
+            committer = get_user_identity(config, kind="COMMITTER")
+        check_user_identity(committer)
+        c.committer = committer
+        if commit_timestamp is None:
+            # FIXME: Support GIT_COMMITTER_DATE environment variable
+            commit_timestamp = time.time()
+        c.commit_time = int(commit_timestamp)
+        if commit_timezone is None:
+            # FIXME: Use current user timezone rather than UTC
+            commit_timezone = 0
+        c.commit_timezone = commit_timezone
+        if author is None:
+            author = get_user_identity(config, kind="AUTHOR")
+        c.author = author
+        check_user_identity(author)
+        if author_timestamp is None:
+            # FIXME: Support GIT_AUTHOR_DATE environment variable
+            author_timestamp = commit_timestamp
+        c.author_time = int(author_timestamp)
+        if author_timezone is None:
+            author_timezone = commit_timezone
+        c.author_timezone = author_timezone
+        if encoding is None:
+            try:
+                encoding = config.get(("i18n",), "commitEncoding")
+            except KeyError:
+                pass  # No dice
+        if encoding is not None:
+            c.encoding = encoding
+        # Store original message (might be callable)
+        original_message = message
+        message = None  # Will be set later after parents are set
+
+        # Check if we should sign the commit
+        should_sign = sign
+        if sign is None:
+            # Check commit.gpgSign configuration when sign is not explicitly set
+            config = self._repo.get_config_stack()
+            try:
+                should_sign = config.get_boolean((b"commit",), b"gpgSign")
+            except KeyError:
+                should_sign = False  # Default to not signing if no config
+        keyid = sign if isinstance(sign, str) else None
+
+        if ref is None:
+            # Create a dangling commit
+            c.parents = merge_heads
+        else:
+            try:
+                old_head = self._repo.refs[ref]
+                c.parents = [old_head, *merge_heads]
+            except KeyError:
+                c.parents = merge_heads
+
+        # Handle message after parents are set
+        if callable(original_message):
+            message = original_message(self._repo, c)
+            if message is None:
+                raise ValueError("Message callback returned None")
+        else:
+            message = original_message
+
+        if message is None:
+            # FIXME: Try to read commit message from .git/MERGE_MSG
+            raise ValueError("No commit message specified")
+
+        try:
+            if no_verify:
+                c.message = message
+            else:
+                c.message = self._repo.hooks["commit-msg"].execute(message)
+                if c.message is None:
+                    c.message = message
+        except HookError as exc:
+            raise CommitError(exc) from exc
+        except KeyError:  # no hook defined, message not modified
+            c.message = message
+
+        if ref is None:
+            # Create a dangling commit
+            if should_sign:
+                c.sign(keyid)
+            self._repo.object_store.add_object(c)
+        else:
+            try:
+                old_head = self._repo.refs[ref]
+                if should_sign:
+                    c.sign(keyid)
+                self._repo.object_store.add_object(c)
+                ok = self._repo.refs.set_if_equals(
+                    ref,
+                    old_head,
+                    c.id,
+                    message=b"commit: " + message,
+                    committer=committer,
+                    timestamp=commit_timestamp,
+                    timezone=commit_timezone,
+                )
+            except KeyError:
+                c.parents = merge_heads
+                if should_sign:
+                    c.sign(keyid)
+                self._repo.object_store.add_object(c)
+                ok = self._repo.refs.add_if_new(
+                    ref,
+                    c.id,
+                    message=b"commit: " + message,
+                    committer=committer,
+                    timestamp=commit_timestamp,
+                    timezone=commit_timezone,
+                )
+            if not ok:
+                # Fail if the atomic compare-and-swap failed, leaving the
+                # commit and all its objects as garbage.
+                raise CommitError(f"{ref!r} changed during commit")
+
+        self._repo._del_named_file("MERGE_HEAD")
+
+        try:
+            self._repo.hooks["post-commit"].execute()
+        except HookError as e:  # silent failure
+            warnings.warn(f"post-commit hook failed: {e}", UserWarning)
+        except KeyError:  # no hook defined, silent fallthrough
+            pass
+
+        # Trigger auto GC if needed
+        from .gc import maybe_auto_gc
+
+        maybe_auto_gc(self._repo)
+
+        return c.id
+
+    def reset_index(self, tree: Optional[bytes] = None):
+        """Reset the index back to a specific tree.
+
+        Args:
+          tree: Tree SHA to reset to, None for current HEAD tree.
+        """
+        from .index import (
+            build_index_from_tree,
+            symlink,
+            validate_path_element_default,
+            validate_path_element_hfs,
+            validate_path_element_ntfs,
+        )
+
+        if tree is None:
+            head = self._repo[b"HEAD"]
+            if isinstance(head, Tag):
+                _cls, obj = head.object
+                head = self._repo.get_object(obj)
+            tree = head.tree
+        config = self._repo.get_config()
+        honor_filemode = config.get_boolean(b"core", b"filemode", os.name != "nt")
+        if config.get_boolean(b"core", b"core.protectNTFS", os.name == "nt"):
+            validate_path_element = validate_path_element_ntfs
+        elif config.get_boolean(b"core", b"core.protectHFS", sys.platform == "darwin"):
+            validate_path_element = validate_path_element_hfs
+        else:
+            validate_path_element = validate_path_element_default
+        if config.get_boolean(b"core", b"symlinks", True):
+            symlink_fn = symlink
+        else:
+
+            def symlink_fn(source, target) -> None:  # type: ignore
+                with open(
+                    target, "w" + ("b" if isinstance(source, bytes) else "")
+                ) as f:
+                    f.write(source)
+
+        blob_normalizer = self._repo.get_blob_normalizer()
+        return build_index_from_tree(
+            self.path,
+            self._repo.index_path(),
+            self._repo.object_store,
+            tree,
+            honor_filemode=honor_filemode,
+            validate_path_element=validate_path_element,
+            symlink_fn=symlink_fn,
+            blob_normalizer=blob_normalizer,
+        )
+
+    def _sparse_checkout_file_path(self) -> str:
+        """Return the path of the sparse-checkout file in this repo's control dir."""
+        return os.path.join(self._repo.controldir(), "info", "sparse-checkout")
+
+    def configure_for_cone_mode(self) -> None:
+        """Ensure the repository is configured for cone-mode sparse-checkout."""
+        config = self._repo.get_config()
+        config.set((b"core",), b"sparseCheckout", b"true")
+        config.set((b"core",), b"sparseCheckoutCone", b"true")
+        config.write_to_path()
+
+    def infer_cone_mode(self) -> bool:
+        """Return True if 'core.sparseCheckoutCone' is set to 'true' in config, else False."""
+        config = self._repo.get_config()
+        try:
+            sc_cone = config.get((b"core",), b"sparseCheckoutCone")
+            return sc_cone == b"true"
+        except KeyError:
+            # If core.sparseCheckoutCone is not set, default to False
+            return False
+
+    def get_sparse_checkout_patterns(self) -> list[str]:
+        """Return a list of sparse-checkout patterns from info/sparse-checkout.
+
+        Returns:
+            A list of patterns. Returns an empty list if the file is missing.
+        """
+        path = self._sparse_checkout_file_path()
+        try:
+            with open(path, encoding="utf-8") as f:
+                return [line.strip() for line in f if line.strip()]
+        except FileNotFoundError:
+            return []
+
+    def set_sparse_checkout_patterns(self, patterns: list[str]) -> None:
+        """Write the given sparse-checkout patterns into info/sparse-checkout.
+
+        Creates the info/ directory if it does not exist.
+
+        Args:
+            patterns: A list of gitignore-style patterns to store.
+        """
+        info_dir = os.path.join(self._repo.controldir(), "info")
+        os.makedirs(info_dir, exist_ok=True)
+
+        path = self._sparse_checkout_file_path()
+        with open(path, "w", encoding="utf-8") as f:
+            for pat in patterns:
+                f.write(pat + "\n")
+
+    def set_cone_mode_patterns(self, dirs: Union[list[str], None] = None) -> None:
+        """Write the given cone-mode directory patterns into info/sparse-checkout.
+
+        For each directory to include, add an inclusion line that "undoes" the
+        prior ``!/*/`` exclude, re-including that directory and everything
+        under it. Never add the same line twice.
+        """
+        patterns = ["/*", "!/*/"]
+        if dirs:
+            for d in dirs:
+                d = d.strip("/")
+                line = f"/{d}/"
+                if d and line not in patterns:
+                    patterns.append(line)
+        self.set_sparse_checkout_patterns(patterns)

+ 1 - 0
tests/__init__.py

@@ -173,6 +173,7 @@ def self_test_suite():
         "utils",
         "utils",
         "walk",
         "walk",
         "web",
         "web",
+        "worktree",
     ]
     ]
     module_names = ["tests.test_" + name for name in names]
     module_names = ["tests.test_" + name for name in names]
     loader = unittest.TestLoader()
     loader = unittest.TestLoader()

+ 410 - 0
tests/test_worktree.py

@@ -0,0 +1,410 @@
+# test_worktree.py -- Tests for dulwich.worktree
+# Copyright (C) 2024 Jelmer Vernooij <jelmer@jelmer.uk>
+#
+# SPDX-License-Identifier: Apache-2.0 OR GPL-2.0-or-later
+# Dulwich is dual-licensed under the Apache License, Version 2.0 and the GNU
+# General Public License as published by the Free Software Foundation; version 2.0
+# or (at your option) any later version. You can redistribute it and/or
+# modify it under the terms of either of these two licenses.
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# You should have received a copy of the licenses; if not, see
+# <http://www.gnu.org/licenses/> for a copy of the GNU General Public License
+# and <http://www.apache.org/licenses/LICENSE-2.0> for a copy of the Apache
+# License, Version 2.0.
+#
+
+"""Tests for dulwich.worktree."""
+
+import os
+import stat
+import tempfile
+from unittest import skipIf
+
+from dulwich import porcelain
+from dulwich.object_store import tree_lookup_path
+from dulwich.repo import Repo
+from dulwich.worktree import WorkTree
+
+from . import TestCase
+
+
+class WorkTreeTestCase(TestCase):
+    """Base test case for WorkTree tests."""
+
+    def setUp(self):
+        super().setUp()
+        self.test_dir = tempfile.mkdtemp()
+        self.repo = Repo.init(self.test_dir)
+
+        # Create initial commit with a file
+        with open(os.path.join(self.test_dir, "a"), "wb") as f:
+            f.write(b"contents of file a")
+        self.repo.stage(["a"])
+        self.root_commit = self.repo.do_commit(
+            b"Initial commit",
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+            commit_timestamp=12345,
+            commit_timezone=0,
+            author_timestamp=12345,
+            author_timezone=0,
+        )
+        self.worktree = self.repo.get_worktree()
+
+    def tearDown(self):
+        self.repo.close()
+        super().tearDown()
+
+
+class WorkTreeInitTests(TestCase):
+    """Tests for WorkTree initialization."""
+
+    def test_init_with_repo_path(self):
+        """Test WorkTree initialization with same path as repo."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            repo = Repo.init(tmpdir)
+            worktree = WorkTree(repo, tmpdir)
+
+            self.assertEqual(worktree.path, tmpdir)
+            self.assertEqual(worktree._repo, repo)
+            self.assertTrue(os.path.isabs(worktree.path))
+
+    def test_init_with_different_path(self):
+        """Test WorkTree initialization with different path from repo."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            repo_path = os.path.join(tmpdir, "repo")
+            worktree_path = os.path.join(tmpdir, "worktree")
+
+            os.makedirs(repo_path)
+            os.makedirs(worktree_path)
+
+            repo = Repo.init(repo_path)
+            worktree = WorkTree(repo, worktree_path)
+
+            self.assertNotEqual(worktree.path, repo.path)
+            self.assertEqual(worktree.path, worktree_path)
+            self.assertEqual(worktree._repo, repo)
+            self.assertTrue(os.path.isabs(worktree.path))
+
+    def test_init_with_bytes_path(self):
+        """Test WorkTree initialization with bytes path."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            repo = Repo.init(tmpdir)
+            worktree = WorkTree(repo, tmpdir.encode("utf-8"))
+
+            self.assertEqual(worktree.path, tmpdir)
+            self.assertIsInstance(worktree.path, str)
+
+
+class WorkTreeStagingTests(WorkTreeTestCase):
+    """Tests for WorkTree staging operations."""
+
+    def test_stage_absolute(self):
+        """Test that staging with absolute paths raises ValueError."""
+        r = self.repo
+        os.remove(os.path.join(r.path, "a"))
+        self.assertRaises(ValueError, self.worktree.stage, [os.path.join(r.path, "a")])
+
+    def test_stage_deleted(self):
+        """Test staging a deleted file."""
+        r = self.repo
+        os.remove(os.path.join(r.path, "a"))
+        self.worktree.stage(["a"])
+        self.worktree.stage(["a"])  # double-stage a deleted path
+        self.assertEqual([], list(r.open_index()))
+
+    def test_stage_directory(self):
+        """Test staging a directory."""
+        r = self.repo
+        os.mkdir(os.path.join(r.path, "c"))
+        self.worktree.stage(["c"])
+        self.assertEqual([b"a"], list(r.open_index()))
+
+    def test_stage_submodule(self):
+        """Test staging a submodule."""
+        r = self.repo
+        s = Repo.init(os.path.join(r.path, "sub"), mkdir=True)
+        s.do_commit(b"message")
+        self.worktree.stage(["sub"])
+        self.assertEqual([b"a", b"sub"], list(r.open_index()))
+
+
+class WorkTreeUnstagingTests(WorkTreeTestCase):
+    """Tests for WorkTree unstaging operations."""
+
+    def test_unstage_modify_file_with_dir(self):
+        """Test unstaging a modified file in a directory."""
+        os.mkdir(os.path.join(self.repo.path, "new_dir"))
+        full_path = os.path.join(self.repo.path, "new_dir", "foo")
+
+        with open(full_path, "w") as f:
+            f.write("hello")
+        porcelain.add(self.repo, paths=[full_path])
+        porcelain.commit(
+            self.repo,
+            message=b"unittest",
+            committer=b"Jane <jane@example.com>",
+            author=b"John <john@example.com>",
+        )
+        with open(full_path, "a") as f:
+            f.write("something new")
+        self.worktree.unstage(["new_dir/foo"])
+        status = list(porcelain.status(self.repo))
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": []}, [b"new_dir/foo"], []], status
+        )
+
+    def test_unstage_while_no_commit(self):
+        """Test unstaging when there are no commits."""
+        file = "foo"
+        full_path = os.path.join(self.repo.path, file)
+        with open(full_path, "w") as f:
+            f.write("hello")
+        porcelain.add(self.repo, paths=[full_path])
+        self.worktree.unstage([file])
+        status = list(porcelain.status(self.repo))
+        self.assertEqual([{"add": [], "delete": [], "modify": []}, [], ["foo"]], status)
+
+    def test_unstage_add_file(self):
+        """Test unstaging a newly added file."""
+        file = "foo"
+        full_path = os.path.join(self.repo.path, file)
+        porcelain.commit(
+            self.repo,
+            message=b"unittest",
+            committer=b"Jane <jane@example.com>",
+            author=b"John <john@example.com>",
+        )
+        with open(full_path, "w") as f:
+            f.write("hello")
+        porcelain.add(self.repo, paths=[full_path])
+        self.worktree.unstage([file])
+        status = list(porcelain.status(self.repo))
+        self.assertEqual([{"add": [], "delete": [], "modify": []}, [], ["foo"]], status)
+
+    def test_unstage_modify_file(self):
+        """Test unstaging a modified file."""
+        file = "foo"
+        full_path = os.path.join(self.repo.path, file)
+        with open(full_path, "w") as f:
+            f.write("hello")
+        porcelain.add(self.repo, paths=[full_path])
+        porcelain.commit(
+            self.repo,
+            message=b"unittest",
+            committer=b"Jane <jane@example.com>",
+            author=b"John <john@example.com>",
+        )
+        with open(full_path, "a") as f:
+            f.write("broken")
+        porcelain.add(self.repo, paths=[full_path])
+        self.worktree.unstage([file])
+        status = list(porcelain.status(self.repo))
+
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": []}, [b"foo"], []], status
+        )
+
+    def test_unstage_remove_file(self):
+        """Test unstaging a removed file."""
+        file = "foo"
+        full_path = os.path.join(self.repo.path, file)
+        with open(full_path, "w") as f:
+            f.write("hello")
+        porcelain.add(self.repo, paths=[full_path])
+        porcelain.commit(
+            self.repo,
+            message=b"unittest",
+            committer=b"Jane <jane@example.com>",
+            author=b"John <john@example.com>",
+        )
+        os.remove(full_path)
+        self.worktree.unstage([file])
+        status = list(porcelain.status(self.repo))
+        self.assertEqual(
+            [{"add": [], "delete": [], "modify": []}, [b"foo"], []], status
+        )
+
+
+class WorkTreeCommitTests(WorkTreeTestCase):
+    """Tests for WorkTree commit operations."""
+
+    def test_commit_modified(self):
+        """Test committing a modified file."""
+        r = self.repo
+        with open(os.path.join(r.path, "a"), "wb") as f:
+            f.write(b"new contents")
+        self.worktree.stage(["a"])
+        commit_sha = self.worktree.commit(
+            b"modified a",
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+            commit_timestamp=12395,
+            commit_timezone=0,
+            author_timestamp=12395,
+            author_timezone=0,
+        )
+        self.assertEqual([self.root_commit], r[commit_sha].parents)
+        a_mode, a_id = tree_lookup_path(r.get_object, r[commit_sha].tree, b"a")
+        self.assertEqual(stat.S_IFREG | 0o644, a_mode)
+        self.assertEqual(b"new contents", r[a_id].data)
+
+    @skipIf(not getattr(os, "symlink", None), "Requires symlink support")
+    def test_commit_symlink(self):
+        """Test committing a symlink."""
+        r = self.repo
+        os.symlink("a", os.path.join(r.path, "b"))
+        self.worktree.stage(["a", "b"])
+        commit_sha = self.worktree.commit(
+            b"Symlink b",
+            committer=b"Test Committer <test@nodomain.com>",
+            author=b"Test Author <test@nodomain.com>",
+            commit_timestamp=12395,
+            commit_timezone=0,
+            author_timestamp=12395,
+            author_timezone=0,
+        )
+        self.assertEqual([self.root_commit], r[commit_sha].parents)
+        b_mode, b_id = tree_lookup_path(r.get_object, r[commit_sha].tree, b"b")
+        self.assertEqual(stat.S_IFLNK, b_mode)
+        self.assertEqual(b"a", r[b_id].data)
+
+
+class WorkTreeResetTests(WorkTreeTestCase):
+    """Tests for WorkTree reset operations."""
+
+    def test_reset_index(self):
+        """Test resetting the index."""
+        # Make some changes and stage them
+        with open(os.path.join(self.repo.path, "a"), "wb") as f:
+            f.write(b"modified contents")
+        self.worktree.stage(["a"])
+
+        # Reset index should restore to HEAD
+        self.worktree.reset_index()
+
+        # Check that the working tree file was restored
+        with open(os.path.join(self.repo.path, "a"), "rb") as f:
+            contents = f.read()
+        self.assertEqual(b"contents of file a", contents)
+
+
+class WorkTreeSparseCheckoutTests(WorkTreeTestCase):
+    """Tests for WorkTree sparse checkout operations."""
+
+    def test_get_sparse_checkout_patterns_empty(self):
+        """Test getting sparse checkout patterns when file doesn't exist."""
+        patterns = self.worktree.get_sparse_checkout_patterns()
+        self.assertEqual([], patterns)
+
+    def test_set_sparse_checkout_patterns(self):
+        """Test setting sparse checkout patterns."""
+        patterns = ["*.py", "docs/"]
+        self.worktree.set_sparse_checkout_patterns(patterns)
+
+        # Read back the patterns
+        retrieved_patterns = self.worktree.get_sparse_checkout_patterns()
+        self.assertEqual(patterns, retrieved_patterns)
+
+    def test_configure_for_cone_mode(self):
+        """Test configuring repository for cone mode."""
+        self.worktree.configure_for_cone_mode()
+
+        config = self.repo.get_config()
+        self.assertEqual(b"true", config.get((b"core",), b"sparseCheckout"))
+        self.assertEqual(b"true", config.get((b"core",), b"sparseCheckoutCone"))
+
+    def test_infer_cone_mode_false(self):
+        """Test inferring cone mode when not configured."""
+        self.assertFalse(self.worktree.infer_cone_mode())
+
+    def test_infer_cone_mode_true(self):
+        """Test inferring cone mode when configured."""
+        self.worktree.configure_for_cone_mode()
+        self.assertTrue(self.worktree.infer_cone_mode())
+
+    def test_set_cone_mode_patterns(self):
+        """Test setting cone mode patterns."""
+        dirs = ["src", "tests"]
+        self.worktree.set_cone_mode_patterns(dirs)
+
+        patterns = self.worktree.get_sparse_checkout_patterns()
+        expected = ["/*", "!/*/", "/src/", "/tests/"]
+        self.assertEqual(expected, patterns)
+
+    def test_set_cone_mode_patterns_empty(self):
+        """Test setting cone mode patterns with empty list."""
+        self.worktree.set_cone_mode_patterns([])
+
+        patterns = self.worktree.get_sparse_checkout_patterns()
+        expected = ["/*", "!/*/"]
+        self.assertEqual(expected, patterns)
+
+    def test_set_cone_mode_patterns_duplicates(self):
+        """Test that duplicate patterns are not added."""
+        dirs = ["src", "src"]  # duplicate
+        self.worktree.set_cone_mode_patterns(dirs)
+
+        patterns = self.worktree.get_sparse_checkout_patterns()
+        expected = ["/*", "!/*/", "/src/"]
+        self.assertEqual(expected, patterns)
+
+    def test_sparse_checkout_file_path(self):
+        """Test getting the sparse checkout file path."""
+        expected_path = os.path.join(self.repo.controldir(), "info", "sparse-checkout")
+        actual_path = self.worktree._sparse_checkout_file_path()
+        self.assertEqual(expected_path, actual_path)
+
+
+class WorkTreeBackwardCompatibilityTests(WorkTreeTestCase):
+    """Tests for backward compatibility of deprecated Repo methods."""
+
+    def test_deprecated_stage_delegates_to_worktree(self):
+        """Test that deprecated Repo.stage delegates to WorkTree."""
+        with open(os.path.join(self.repo.path, "new_file"), "w") as f:
+            f.write("test content")
+
+        # This should show a deprecation warning but still work
+        import warnings
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.repo.stage(["new_file"])
+            self.assertTrue(len(w) > 0)
+            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
+
+    def test_deprecated_unstage_delegates_to_worktree(self):
+        """Test that deprecated Repo.unstage delegates to WorkTree."""
+        # This should show a deprecation warning but still work
+        import warnings
+
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.repo.unstage(["a"])
+            self.assertTrue(len(w) > 0)
+            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
+
+    def test_deprecated_sparse_checkout_methods(self):
+        """Test that deprecated sparse checkout methods delegate to WorkTree."""
+        import warnings
+
+        # Test get_sparse_checkout_patterns: must still return the (empty)
+        # pattern list while warning about the deprecation.
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            patterns = self.repo.get_sparse_checkout_patterns()
+            self.assertEqual([], patterns)
+            # NOTE(review): checking w[0] assumes the deprecation is the
+            # first warning recorded; an any()-scan would be more robust.
+            self.assertTrue(len(w) > 0)
+            self.assertTrue(issubclass(w[0].category, DeprecationWarning))
+
+        # Test set_sparse_checkout_patterns: must accept patterns while
+        # warning about the deprecation.
+        with warnings.catch_warnings(record=True) as w:
+            warnings.simplefilter("always")
+            self.repo.set_sparse_checkout_patterns(["*.py"])
+            self.assertTrue(len(w) > 0)
+            self.assertTrue(issubclass(w[0].category, DeprecationWarning))